# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Orbax checkpoint writer and a loader.
- CheckpointWriter: saves checkpoints with preservation/decision policies
- CheckpointLoader: restores checkpoints (latest or specific step)
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Tuple, cast, TYPE_CHECKING
from functools import partial
import inspect
import orbax.checkpoint as ocp
from orbax.checkpoint.checkpoint_managers import (
preservation_policy as preservation_policy_lib,
)
from orbax.checkpoint.checkpoint_managers import (
save_decision_policy as save_decision_policy_lib,
)
from ..state import State
from ..system import System
if TYPE_CHECKING:
from ..rl.models import Model
from ..utils import decode_callable
[docs]
@dataclass(slots=True, weakref_slot=True)
class CheckpointWriter:
"""
Thin wrapper around Orbax checkpoint saving.
"""
directory: Path | str = Path("./checkpoints")
"""
The base directory where checkpoints will be saved.
"""
max_to_keep: int | None = None
"""
Keep the last max_to_keep checkpoints. If None, everything is save.
"""
save_every: int = 1
"""
How often to write; writes on every ``save_every``-th call to :meth:`save`.
"""
checkpointer: ocp.CheckpointManager = field(init=False)
"""
Orbax checkpoint manager for saving the checkpoints.
"""
def __post_init__(self):
self.directory = Path(self.directory).resolve()
self.directory = cast(
Path, ocp.test_utils.erase_and_create_empty(self.directory)
)
self.save_every = int(self.save_every)
self.max_to_keep = (
int(self.max_to_keep) if self.max_to_keep is not None else None
)
options = ocp.CheckpointManagerOptions(
save_decision_policy=save_decision_policy_lib.FixedIntervalPolicy(
self.save_every
),
preservation_policy=preservation_policy_lib.LatestN(self.max_to_keep),
)
self.checkpointer = ocp.CheckpointManager(
self.directory,
options=options,
)
[docs]
@partial(jax.named_call, name="CheckpointWriter.save")
def save(self, state: "State", system: "System") -> None:
"""
Save a checkpoint for the provided state/system at a given step.
Parameters
----------
state : State
The current state of the simulation.
system : System
The current system configuration.
"""
system_metadata = dict(
dim=state.dim,
integrator_type=system.integrator.type_name,
collider_type=system.collider.type_name,
domain_type=system.domain.type_name,
force_model_type=system.force_model.type_name,
)
self.checkpointer.save(
int(system.step_count),
args=ocp.args.Composite(
state=ocp.args.StandardSave(state),
system=ocp.args.StandardSave(system),
state_metadata=ocp.args.JsonSave(dict(shape=tuple(state.pos.shape))),
system_metadata=ocp.args.JsonSave(system_metadata),
),
)
[docs]
@partial(jax.named_call, name="CheckpointWriter.block_until_ready")
def block_until_ready(self):
"""
Wait for the checkpointer to finish.
"""
self.checkpointer.wait_until_finished()
[docs]
@partial(jax.named_call, name="CheckpointWriter.close")
def close(self) -> None:
"""
Wait for the checkpointer to finish and close it.
"""
try:
self.checkpointer.wait_until_finished()
finally:
self.checkpointer.close()
def __del__(self):
try:
self.close()
except Exception:
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
return False
[docs]
@dataclass(slots=True)
class CheckpointLoader:
"""
Thin wrapper around Orbax checkpoint restoring for jaxdem.state and jaxdem.system.
"""
directory: Path = Path("./checkpoints")
"""
The base directory where checkpoints will be saved.
"""
checkpointer: ocp.CheckpointManager = field(init=False)
"""
Orbax checkpoint manager for saving the checkpoints.
"""
def __post_init__(self):
self.directory = Path(self.directory).resolve()
options = ocp.CheckpointManagerOptions()
self.checkpointer = ocp.CheckpointManager(
self.directory,
options=options,
)
[docs]
@partial(jax.named_call, name="CheckpointLoader.load")
def load(
self,
step: Optional[int] = None,
) -> Tuple[State, System]:
"""
Restore a checkpoint.
Parameters
----------
step : Optional[int]
- If None, load the latest checkpoint.
- Otherwise, load the specified step.
Returns
-------
Tuple[State, System]
A tuple containing the restored `State` and `System`.
"""
if step is None:
step = self.checkpointer.latest_step()
if step is None:
raise FileNotFoundError(f"No checkpoints found in: {self.directory}")
if step not in self.checkpointer.all_steps():
raise FileNotFoundError(
f"step={step} checkpoints not found in: {self.directory}. Available steps: {self.checkpointer.all_steps()}"
)
metadata = self.checkpointer.restore(
step,
args=ocp.args.Composite(
state_metadata=ocp.args.JsonRestore(),
system_metadata=ocp.args.JsonRestore(),
),
)
state_target = State.create(jnp.zeros(tuple(metadata.state_metadata["shape"])))
system_target = System.create(**metadata.system_metadata)
result = self.checkpointer.restore(
step,
args=ocp.args.Composite(
state=ocp.args.StandardRestore(state_target),
system=ocp.args.StandardRestore(system_target),
),
)
return result.state, result.system
[docs]
@partial(jax.named_call, name="CheckpointLoader.block_until_ready")
def block_until_ready(self):
self.checkpointer.wait_until_finished()
[docs]
@partial(jax.named_call, name="CheckpointLoader.close")
def close(self) -> None:
try:
self.checkpointer.wait_until_finished()
finally:
self.checkpointer.close()
def __del__(self):
try:
self.close()
except Exception:
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
return False
[docs]
@dataclass(slots=True)
class CheckpointModelWriter:
"""
Thin wrapper around Orbax checkpoint saving for jaxdem.rl.models.Model.
"""
directory: Path | str = Path("./checkpoints")
"""
The base directory where checkpoints will be saved.
"""
max_to_keep: int | None = None
"""
Keep the last max_to_keep checkpoints. If None, everything is save.
"""
save_every: int = 1
"""
How often to write; writes on every ``save_every``-th call to :meth:`save`.
"""
checkpointer: ocp.CheckpointManager = field(init=False)
"""
Orbax checkpoint manager for saving the checkpoints.
"""
clean: bool = True
"""
Wether to clean the directory
"""
def __post_init__(self):
self.directory = Path(self.directory).resolve()
self.directory.mkdir(parents=True, exist_ok=True)
if self.clean:
self.directory = cast(
Path, ocp.test_utils.erase_and_create_empty(self.directory)
)
self.save_every = int(self.save_every)
self.max_to_keep = (
int(self.max_to_keep) if self.max_to_keep is not None else None
)
options = ocp.CheckpointManagerOptions(
save_decision_policy=save_decision_policy_lib.FixedIntervalPolicy(
self.save_every
),
preservation_policy=preservation_policy_lib.LatestN(self.max_to_keep),
)
self.checkpointer = ocp.CheckpointManager(
self.directory,
options=options,
)
[docs]
@partial(jax.named_call, name="CheckpointModelWriter.save")
def save(self, model: "Model", step: int) -> None:
"""
Save model at a step: stores model_state and JSON metadata.
Assumes model.metadata includes JSON-serializable fields. We add model_type.
"""
from flax import nnx
model_metadata = model.metadata
model_metadata["model_type"] = model.type_name
graphdef, state = nnx.split(model)
self.checkpointer.save(
int(step),
args=ocp.args.Composite(
model_state=ocp.args.StandardSave(state),
model_metadata=ocp.args.JsonSave(model_metadata),
),
)
[docs]
@partial(jax.named_call, name="CheckpointModelWriter.block_until_ready")
def block_until_ready(self) -> None:
self.checkpointer.wait_until_finished()
[docs]
@partial(jax.named_call, name="CheckpointModelWriter.close")
def close(self) -> None:
try:
self.checkpointer.wait_until_finished()
finally:
self.checkpointer.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
return False
def __del__(self):
try:
self.close()
except Exception:
pass
[docs]
@dataclass(slots=True)
class CheckpointModelLoader:
"""
Thin wrapper around Orbax checkpoint restoring for jaxdem.rl.models.Model.
"""
directory: Path = Path("./checkpoints")
"""
The base directory where checkpoints will be saved.
"""
checkpointer: ocp.CheckpointManager = field(init=False)
"""
Orbax checkpoint manager for saving the checkpoints.
"""
def __post_init__(self):
self.directory = Path(self.directory).resolve()
options = ocp.CheckpointManagerOptions()
self.checkpointer = ocp.CheckpointManager(
self.directory,
options=options,
)
[docs]
@partial(jax.named_call, name="CheckpointModelLoader.load")
def load(self, step: int | None = None):
"""
Load a model from a given step (or the latest if None).
"""
from flax import nnx
from ..rl.models import Model
from ..rl.actionSpaces import ActionSpace
if step is None:
step = self.checkpointer.latest_step()
if step is None:
raise FileNotFoundError(f"No checkpoints found in: {self.directory}")
if step not in self.checkpointer.all_steps():
raise FileNotFoundError(
f"step={step} checkpoints not found in: {self.directory}. Available steps: {self.checkpointer.all_steps()}"
)
model_metadata = self.checkpointer.restore(
step,
args=ocp.args.Composite(
model_metadata=ocp.args.JsonRestore(),
),
)
model_metadata = model_metadata.model_metadata
action_space = ActionSpace.create(
model_metadata["action_space_type"], **model_metadata["action_space_kws"]
)
used_keys = [
"action_space_type",
"action_space_kws",
"reset_shape",
"activation",
"model_type",
]
rngs = nnx.Rngs(0)
activation = decode_callable(model_metadata["activation"])
model_type = model_metadata["model_type"]
reset_shape = model_metadata.get("reset_shape", (1,))
model_metadata = {
key: value for key, value in model_metadata.items() if key not in used_keys
}
if "cell_type" in model_metadata:
model_metadata["cell_type"] = decode_callable(model_metadata["cell_type"])
model = Model.create(
model_type,
**model_metadata,
key=rngs,
action_space=action_space,
activation=activation,
)
graphdef, state = nnx.split(model)
result = self.checkpointer.restore(
step,
args=ocp.args.Composite(
model_state=ocp.args.StandardRestore(state),
),
)
state = result.model_state
return nnx.merge(graphdef, state)
[docs]
@partial(jax.named_call, name="CheckpointModelLoader.latest_step")
def latest_step(self) -> Optional[int]:
return self.checkpointer.latest_step()
[docs]
@partial(jax.named_call, name="CheckpointModelLoader.block_until_ready")
def block_until_ready(self) -> None:
self.checkpointer.wait_until_finished()
[docs]
@partial(jax.named_call, name="CheckpointModelLoader.close")
def close(self) -> None:
try:
self.checkpointer.wait_until_finished()
finally:
self.checkpointer.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
return False
def __del__(self):
try:
self.close()
except Exception:
pass
# Names exported by ``from <module> import *``: the four checkpoint
# writer/loader classes defined in this file.
__all__ = [
"CheckpointWriter",
"CheckpointLoader",
"CheckpointModelWriter",
"CheckpointModelLoader",
]