jaxdem.writers.checkpoints

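Orbax-based checkpoint writers and loaders.

  • CheckpointWriter: saves State/System checkpoints with retention (max_to_keep) and frequency (save_every) policies

  • CheckpointLoader: restores State/System checkpoints (latest or specific step)

  • CheckpointModelWriter: saves jaxdem.rl.models.Model checkpoints (model state plus JSON metadata)

  • CheckpointModelLoader: restores jaxdem.rl.models.Model checkpoints (latest or specific step)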

Classes

CheckpointLoader([directory])

Thin wrapper around Orbax checkpoint restoring for jaxdem.state and jaxdem.system.

CheckpointModelLoader([directory])

Thin wrapper around Orbax checkpoint restoring for jaxdem.rl.models.Model.

CheckpointModelWriter([directory, ...])

Thin wrapper around Orbax checkpoint saving for jaxdem.rl.models.Model.

CheckpointWriter([directory, max_to_keep, ...])

Thin wrapper around Orbax checkpoint saving for jaxdem.state and jaxdem.system.

class jaxdem.writers.checkpoints.CheckpointWriter(directory: Path | str = PosixPath('checkpoints'), max_to_keep: int | None = None, save_every: int = 1)

Bases: object

Thin wrapper around Orbax checkpoint saving for jaxdem.state and jaxdem.system.

directory: Path | str

The base directory where checkpoints will be saved.

max_to_keep: int | None

Number of most recent checkpoints to keep; older ones are deleted. If None, all checkpoints are kept.

save_every: int

How often to write: a checkpoint is written on every save_every-th call to save().

checkpointer: CheckpointManager

Orbax checkpoint manager for saving the checkpoints.
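The interaction between save_every and max_to_keep can be illustrated with a small pure-Python stand-in. This is a hypothetical model of the two documented policies, not the wrapper's implementation: the class name CheckpointPolicy is invented, and it assumes that the save_every-th, 2·save_every-th, … calls are the ones that write.

```python
from collections import deque
from typing import Deque, Optional


class CheckpointPolicy:
    """Hypothetical stand-in for the writer's save_every / max_to_keep logic."""

    def __init__(self, max_to_keep: Optional[int] = None, save_every: int = 1):
        if save_every < 1:
            raise ValueError("save_every must be a positive integer")
        self.save_every = save_every
        # deque(maxlen=None) is unbounded, so max_to_keep=None keeps everything.
        self.kept: Deque[int] = deque(maxlen=max_to_keep)
        self.calls = 0

    def save(self, step: int) -> bool:
        """Count one call to save(); return True if this call writes a checkpoint."""
        self.calls += 1
        # Assumption: calls save_every, 2*save_every, ... are the ones that write.
        if self.calls % self.save_every != 0:
            return False
        # Appending to a full deque silently drops the oldest checkpoint.
        self.kept.append(step)
        return True


policy = CheckpointPolicy(max_to_keep=2, save_every=3)
written = [step for step in range(10) if policy.save(step)]
print(written)            # [2, 5, 8]
print(list(policy.kept))  # [5, 8]
```

With max_to_keep=2 and save_every=3, ten calls produce three writes, and only the two newest checkpoints remain afterwards.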

save(state: State, system: System) → None

Save a checkpoint of the provided state and system, subject to the save_every policy.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The current system configuration.

block_until_ready()

Wait for the checkpointer to finish all pending writes.

close() → None

Wait for the checkpointer to finish all pending writes, then close it.
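block_until_ready() and close() matter when saving is asynchronous: Orbax's CheckpointManager can save in the background, so save() may return before the data is on disk. The sketch below shows that pattern with concurrent.futures instead of Orbax; AsyncWriterSketch and its JSON-file layout are hypothetical and only illustrate why a flush is needed before reading checkpoints or exiting.

```python
import json
import tempfile
import time
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Union


class AsyncWriterSketch:
    """Hypothetical writer: save() schedules work; block_until_ready() waits for it."""

    def __init__(self, directory: Union[Path, str]):
        self.directory = Path(directory)
        self.directory.mkdir(parents=True, exist_ok=True)
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending: List[Future] = []

    def _write(self, step: int, payload: Dict[str, int]) -> None:
        time.sleep(0.01)  # simulate slow storage
        (self.directory / f"{step}.json").write_text(json.dumps(payload))

    def save(self, step: int, payload: Dict[str, int]) -> None:
        # Queue the write and return immediately; the worker thread does the I/O.
        self._pending.append(self._pool.submit(self._write, step, payload))

    def block_until_ready(self) -> None:
        # Wait for every queued write; result() re-raises any worker exception.
        for future in self._pending:
            future.result()
        self._pending.clear()

    def close(self) -> None:
        # Flush pending writes first, then release the worker thread.
        self.block_until_ready()
        self._pool.shutdown(wait=True)


with tempfile.TemporaryDirectory() as tmp:
    writer = AsyncWriterSketch(tmp)
    for step in range(3):
        writer.save(step, {"step": step})
    writer.close()  # without this, the directory could be deleted mid-write
    written_files = sorted(p.name for p in Path(tmp).iterdir())

print(written_files)  # ['0.json', '1.json', '2.json']
```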

class jaxdem.writers.checkpoints.CheckpointLoader(directory: Path = PosixPath('checkpoints'))

Bases: object

Thin wrapper around Orbax checkpoint restoring for jaxdem.state and jaxdem.system.

directory: Path

The base directory from which checkpoints are restored.

checkpointer: CheckpointManager

Orbax checkpoint manager used to restore the checkpoints.

load(step: int | None = None) → Tuple[State, System]

Restore a checkpoint.

Parameters:
  • step (Optional[int]) – The step to restore. If None, load the latest checkpoint; otherwise, load the specified step.

Returns:

A tuple containing the restored State and System.

Return type:

Tuple[State, System]
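The two cases of the step argument can be sketched with the standard library alone. This assumes, for illustration, that each checkpoint lives in a sub-directory named after its integer step (similar to Orbax's default step naming); available_steps and resolve_step are hypothetical helpers, not part of jaxdem.

```python
import tempfile
from pathlib import Path
from typing import List, Optional, Union

PathLike = Union[Path, str]


def available_steps(directory: PathLike) -> List[int]:
    """Hypothetical layout: each checkpoint is a sub-directory named by its step."""
    return sorted(
        int(p.name) for p in Path(directory).iterdir() if p.is_dir() and p.name.isdigit()
    )


def resolve_step(directory: PathLike, step: Optional[int] = None) -> int:
    """Pick the step to restore: the latest if step is None, else step itself."""
    steps = available_steps(directory)
    if not steps:
        raise FileNotFoundError(f"no checkpoints in {directory}")
    if step is None:
        return steps[-1]  # numeric sort, so 100 beats 20 and 5
    if step not in steps:
        raise ValueError(f"step {step} not found; available: {steps}")
    return step


with tempfile.TemporaryDirectory() as tmp:
    for s in (5, 20, 100):
        (Path(tmp) / str(s)).mkdir()
    latest = resolve_step(tmp)        # 100
    specific = resolve_step(tmp, 20)  # 20
```

Sorting the step names as integers rather than strings matters here: lexicographically, "5" would sort after "100".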

block_until_ready()

Wait for the checkpointer to finish any pending operations.

close() → None

Wait for the checkpointer to finish and close it.
class jaxdem.writers.checkpoints.CheckpointModelWriter(directory: Path | str = PosixPath('checkpoints'), max_to_keep: int | None = None, save_every: int = 1, clean: bool = True)

Bases: object

Thin wrapper around Orbax checkpoint saving for jaxdem.rl.models.Model.

directory: Path | str

The base directory where checkpoints will be saved.

max_to_keep: int | None

Number of most recent checkpoints to keep; older ones are deleted. If None, all checkpoints are kept.

save_every: int

How often to write: a checkpoint is written on every save_every-th call to save().

checkpointer: CheckpointManager

Orbax checkpoint manager used to save the checkpoints.

clean: bool

Whether to clean the checkpoint directory.

save(model: Model, step: int) → None

Save the model at the given step, storing its model_state together with JSON metadata. model.metadata is assumed to contain only JSON-serializable fields; a model_type field is added to the stored metadata.
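The JSON-metadata requirement can be made concrete with a short sketch. build_metadata is a hypothetical helper and "ExampleModel" a placeholder type name; the sketch only shows the encoding step (copy the metadata, add model_type, serialize), not how the wrapper stores model_state through Orbax.

```python
import json
from typing import Any, Dict


def build_metadata(metadata: Dict[str, Any], model_type: str) -> str:
    """Hypothetical helper: JSON-encode model metadata with a model_type tag added."""
    payload = dict(metadata)  # shallow copy; the caller's dict is not modified
    payload["model_type"] = model_type
    # json.dumps raises TypeError on values it cannot encode (e.g. sets or
    # NumPy/JAX arrays), which is why the metadata must be JSON-serializable.
    return json.dumps(payload, sort_keys=True)


meta = {"hidden_features": [64, 64], "activation": "tanh"}
encoded = build_metadata(meta, "ExampleModel")  # "ExampleModel" is a placeholder
print(encoded)
# {"activation": "tanh", "hidden_features": [64, 64], "model_type": "ExampleModel"}
```

Array-valued hyperparameters would need converting to lists (for example with .tolist()) before they can be stored this way.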

block_until_ready() → None

Wait for the checkpointer to finish all pending writes.

close() → None

Wait for the checkpointer to finish all pending writes, then close it.
class jaxdem.writers.checkpoints.CheckpointModelLoader(directory: Path = PosixPath('checkpoints'))

Bases: object

Thin wrapper around Orbax checkpoint restoring for jaxdem.rl.models.Model.

directory: Path

The base directory from which checkpoints are restored.

checkpointer: CheckpointManager

Orbax checkpoint manager used to restore the checkpoints.

load(step: int | None = None)

Load the model saved at step, or from the latest checkpoint if step is None.
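One plausible use of the stored model_type tag on the loading side is to decide which class to rebuild. The sketch below shows that idea with a hypothetical registry (MODEL_REGISTRY, register, ExampleModel and rebuild_from_metadata are all invented for illustration); it is not CheckpointModelLoader's implementation, and restoring the model_state weights is omitted.

```python
import json
from typing import Any, Callable, Dict, Type

# Hypothetical registry mapping a stored model_type string to a model class.
MODEL_REGISTRY: Dict[str, Type[Any]] = {}


def register(name: str) -> Callable[[Type[Any]], Type[Any]]:
    def decorator(cls: Type[Any]) -> Type[Any]:
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


@register("ExampleModel")
class ExampleModel:
    def __init__(self, hidden_features: list, activation: str):
        self.hidden_features = hidden_features
        self.activation = activation


def rebuild_from_metadata(encoded: str) -> Any:
    """Decode the JSON metadata and construct the class named by model_type."""
    fields = json.loads(encoded)
    model_type = fields.pop("model_type")
    try:
        cls = MODEL_REGISTRY[model_type]
    except KeyError:
        raise ValueError(f"unknown model_type {model_type!r}") from None
    return cls(**fields)


model = rebuild_from_metadata(
    '{"activation": "tanh", "hidden_features": [64, 64], "model_type": "ExampleModel"}'
)
print(type(model).__name__, model.hidden_features)  # ExampleModel [64, 64]
```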

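latest_step() → int | None

Return the most recent saved step, or None if no checkpoint exists.

block_until_ready() → None

Wait for the checkpointer to finish any pending operations.

close() → None

Wait for the checkpointer to finish and close it.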