jaxdem.writers.checkpoints

Orbax-based checkpoint writers and loaders.

  • CheckpointWriter: saves checkpoints with preservation/decision policies

  • CheckpointLoader: restores checkpoints (latest or specific step)

Classes

BaseCheckpointManager([directory])

Base class providing context management and boilerplate for Orbax checkpointers.

CheckpointLoader([directory])

Thin wrapper around Orbax checkpoint restoring for jaxdem.state and jaxdem.system.

CheckpointModelLoader([directory])

Thin wrapper around Orbax checkpoint restoring for jaxdem.rl.models.Model.

CheckpointModelWriter([directory, ...])

Thin wrapper around Orbax checkpoint saving for jaxdem.rl.models.Model.

CheckpointWriter([directory, max_to_keep, ...])

Thin wrapper around Orbax checkpoint saving.

class jaxdem.writers.checkpoints.CheckpointLoader(directory: Path | str = PosixPath('checkpoints'))

Bases: BaseCheckpointManager

Thin wrapper around Orbax checkpoint restoring for jaxdem.state and jaxdem.system.

load(step: int | None = None) → tuple[State, System]

Restore a checkpoint.

Parameters:

step (Optional[int]) –

  • If None, load the latest checkpoint.

  • Otherwise, load the specified step.

Returns:

A tuple containing the restored State and System.

Return type:

Tuple[State, System]
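The step-selection rule described above can be sketched in plain Python. This is an illustration only: `resolve_step` and `steps_on_disk` are hypothetical names, not part of jaxdem or Orbax, and the errors raised for an empty directory or a missing step are assumptions, since this page does not document that behavior.

```python
from typing import Optional, Sequence


def resolve_step(available_steps: Sequence[int], step: Optional[int] = None) -> int:
    """Pick the step to restore: the latest if step is None, else the requested one."""
    if not available_steps:
        # Assumption: an empty checkpoint directory is treated as an error.
        raise FileNotFoundError("no checkpoints found")
    if step is None:
        return max(available_steps)
    if step not in available_steps:
        # Assumption: requesting a step that was never saved is an error.
        raise ValueError(f"step {step} not found; available: {sorted(available_steps)}")
    return step


steps_on_disk = [0, 100, 200]                 # hypothetical saved steps
latest = resolve_step(steps_on_disk)          # step=None selects the newest checkpoint
specific = resolve_step(steps_on_disk, 100)   # an explicit step is returned as-is
```

With the real loader, the equivalent calls would presumably be `load()` and `load(step=100)` on a CheckpointLoader instance.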

latest_step() → int | None

class jaxdem.writers.checkpoints.CheckpointModelLoader(directory: Path | str = PosixPath('checkpoints'))

Bases: BaseCheckpointManager

Thin wrapper around Orbax checkpoint restoring for jaxdem.rl.models.Model.

load(step: int | None = None) → Model

Load the model saved at the given step, or at the latest step if step is None.

latest_step() → int | None

class jaxdem.writers.checkpoints.CheckpointModelWriter(directory: Path | str = PosixPath('checkpoints'), max_to_keep: int | None = None, save_every: int = 1, clean: bool = True)

Bases: BaseCheckpointManager

Thin wrapper around Orbax checkpoint saving for jaxdem.rl.models.Model.

max_to_keep: int | None = None

Keep only the last max_to_keep checkpoints. If None, all checkpoints are kept.

save_every: int = 1

How often to write: a checkpoint is written on every save_every-th call to save().

clean: bool = True

Whether to clean the checkpoint directory.

save(model: Model, step: int) → None

Save the model at the given step. Stores model_state and JSON metadata. Assumes model.metadata contains JSON-serializable fields; a model_type field is added on save.
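To illustrate the metadata requirement, the following minimal sketch builds the JSON that would accompany a saved model. It does not reproduce the writer's internals: `build_metadata` and the example field names are hypothetical, and the only points taken from the description above are that the metadata must be JSON-serializable and that a model_type entry is added.

```python
import json
from typing import Any, Dict


def build_metadata(model_metadata: Dict[str, Any], model_type: str) -> str:
    """Return JSON text combining the model's metadata with its type name."""
    metadata = dict(model_metadata)        # copy so the caller's dict is untouched
    metadata["model_type"] = model_type    # extra field added at save time
    # json.dumps raises TypeError if any value is not JSON-serializable.
    return json.dumps(metadata, sort_keys=True)


# Hypothetical metadata for illustration.
encoded = build_metadata({"hidden_size": 64, "activation": "tanh"}, "MLP")
decoded = json.loads(encoded)
```

Values such as arrays or callables would fail in `json.dumps`, which is why the metadata should hold only plain numbers, strings, booleans, lists, and dicts.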

class jaxdem.writers.checkpoints.CheckpointWriter(directory: Path | str = PosixPath('checkpoints'), max_to_keep: int | None = None, save_every: int = 1)

Bases: BaseCheckpointManager

Thin wrapper around Orbax checkpoint saving.

Notes

Custom force functions passed via force_manager_kw are serialized by their fully-qualified module path (e.g. mypackage.forces.trap). Functions defined in the top-level script (__main__) cannot be restored from a different script. A warning is emitted at save time if any force function lives in __main__. To ensure portability, define force functions in an importable module.
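The idea behind path-based serialization can be sketched with the standard library alone. This is not jaxdem's implementation: `qualified_name` and `resolve` are hypothetical helpers, and `math.sqrt` merely stands in for a force function defined in an importable module.

```python
import importlib
import math
import warnings
from typing import Callable


def qualified_name(fn: Callable) -> str:
    """Build 'module.qualname', warning if the function lives in __main__."""
    if fn.__module__ == "__main__":
        warnings.warn(
            f"{fn.__qualname__} is defined in __main__ and cannot be "
            "restored from a different script"
        )
    return f"{fn.__module__}.{fn.__qualname__}"


def resolve(path: str) -> Callable:
    """Import the module part of 'pkg.mod.name' and fetch the named attribute."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


# math.sqrt stands in for a force function in an importable module.
path = qualified_name(math.sqrt)
restored = resolve(path)
```

A function defined in the top-level script would be recorded under `__main__`, and a different script's `__main__` does not contain it, so the lookup in `resolve` would fail there. Moving the function into a module such as mypackage.forces avoids this.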

max_to_keep: int | None = None

Keep only the last max_to_keep checkpoints. If None, all checkpoints are kept.

save_every: int = 1

How often to write: a checkpoint is written on every save_every-th call to save().
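How max_to_keep and save_every interact can be sketched with a small simulation. It assumes that calls to save() are counted from 1 and that the oldest checkpoint is dropped first. The real writer delegates these policies to Orbax and may count differently, and `simulate` is a hypothetical helper.

```python
from typing import List, Optional


def simulate(num_calls: int, save_every: int = 1,
             max_to_keep: Optional[int] = None) -> List[int]:
    """Return the 1-based save() call indices whose checkpoints remain on disk."""
    kept: List[int] = []
    for call in range(1, num_calls + 1):
        if call % save_every != 0:
            continue                      # skipped: not a save_every-th call
        kept.append(call)
        if max_to_keep is not None and len(kept) > max_to_keep:
            kept.pop(0)                   # drop the oldest checkpoint
    return kept


all_kept = simulate(10, save_every=2)                  # [2, 4, 6, 8, 10]
last_two = simulate(10, save_every=2, max_to_keep=2)   # [8, 10]
```

Under these assumptions, save_every controls how often anything is written, while max_to_keep bounds how many of those writes survive.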

save(state: State, system: System) → None

Save a checkpoint for the provided state/system at a given step.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The current system configuration.