Checkpoint Save and Load

This guide introduces the JaxDEM checkpointing utilities, CheckpointWriter and CheckpointLoader.

Checkpoints are useful for long simulations, reproducibility, and restarting from intermediate steps.

import tempfile
from pathlib import Path
import jax.numpy as jnp
import jaxdem as jdem

Saving simulation checkpoints

We create a small simulation, run it in chunks, and save snapshots.

state = jdem.State.create(pos=jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]))
system = jdem.System.create(state.shape, dt=1e-3)

tmp_dir = Path(tempfile.gettempdir()) / "simulation"
with jdem.CheckpointWriter(directory=tmp_dir, max_to_keep=2) as writer:
    writer.save(state, system)  # step 0

    state, system = system.step(state, system, n=5)
    writer.save(state, system)  # step 5

    state, system = system.step(state, system, n=5)
    writer.save(state, system)  # step 10

Loading latest and specific checkpoints

load() returns a (state, system) tuple. Called without arguments it loads the most recent checkpoint; load(step=...) loads a specific one. latest_step() reports the most recent saved step.

with jdem.CheckpointLoader(directory=tmp_dir) as loader:
    print("Available steps:", loader.checkpointer.all_steps())
    print("Latest step:", loader.latest_step())

    state_latest, system_latest = loader.load()
    print("Loaded latest step_count:", int(system_latest.step_count))

    state_step_5, system_step_5 = loader.load(step=5)
    print("Loaded step_count=5:", int(system_step_5.step_count))
    print("State shape at step 5:", state_step_5.pos.shape)
Available steps: [5, 10]
Latest step: 10
Loaded latest step_count: 10
Loaded step_count=5: 5
State shape at step 5: (2, 3)
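
Step 0 is missing from the available steps because the writer was created with max_to_keep=2 and three checkpoints were saved. The following is a minimal stdlib sketch of that retention behaviour, assuming max_to_keep simply keeps the most recent N saved steps. The helper retained_steps is hypothetical and is not part of JaxDEM.

```python
from collections import deque


def retained_steps(saved_steps, max_to_keep):
    """Return the steps that survive when only the newest `max_to_keep` are kept.

    Hypothetical helper for illustration; not JaxDEM's implementation.
    """
    kept = deque(maxlen=max_to_keep)  # the oldest entry is dropped once full
    for step in saved_steps:
        kept.append(step)
    return list(kept)


print(retained_steps([0, 5, 10], max_to_keep=2))  # [5, 10]
```

If an early snapshot must stay loadable, a larger max_to_keep would be needed under this assumption.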

Resource management note: CheckpointWriter and CheckpointLoader can be used either with a context manager (with ... as ...) or manually without with. Using with is recommended because it automatically waits for pending async writes and closes resources on exit.

If you manage them manually, remember to wait for pending writes and close the writer yourself:

writer = jdem.CheckpointWriter(directory=tmp_dir, max_to_keep=2)
writer.save(state, system)   # async
writer.block_until_ready()   # ensure writes are finished
writer.close()               # release resources

Checkpoint saving is asynchronous, so call block_until_ready() before program exit (and before close() when managing manually) to guarantee files are fully written.
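
When managing a writer manually, one way to make sure close() still runs if a save fails is a try/finally block. The sketch below shows only the control flow. StandInWriter and save_and_close are hypothetical names that mimic the three methods mentioned above (save, block_until_ready, close), so the example runs without JaxDEM installed; it does not reproduce the real writer.

```python
class StandInWriter:
    """Hypothetical stand-in for a checkpoint writer; records which methods ran."""

    def __init__(self):
        self.calls = []

    def save(self, state, system):
        self.calls.append("save")

    def block_until_ready(self):
        self.calls.append("block_until_ready")

    def close(self):
        self.calls.append("close")


def save_and_close(writer, state, system):
    """Save once, wait for the write to finish, and always release resources."""
    try:
        writer.save(state, system)
        writer.block_until_ready()  # wait for the asynchronous write
    finally:
        writer.close()  # runs even if save() or the wait raises


writer = StandInWriter()
save_and_close(writer, state=None, system=None)
print(writer.calls)  # ['save', 'block_until_ready', 'close']
```

The with form remains the simpler choice; this pattern is only relevant when a context manager does not fit the surrounding code.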

Bonded-force model checkpointing

Checkpointing also supports systems with BondedForceModel instances, such as DeformableParticleModel.

Note

Bonded-force models register their own force and energy functions with the ForceManager automatically during create(). These internal functions are not serialized like user-supplied custom forces — they are reconstructed from the bonded model type and parameters at load time. You only need to worry about serialization for custom force functions that you add via force_manager_kw.

vertices_2d = jnp.array(
    [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]],
    dtype=float,
)
elements_2d = jnp.array([[0, 1], [1, 2], [2, 3], [3, 0]], dtype=int)
adjacency_2d = jnp.array([[0, 1], [1, 2], [2, 3], [3, 0]], dtype=int)

bonded_model = jdem.BondedForceModel.create(
    "deformableparticlemodel",
    vertices=vertices_2d,
    elements=elements_2d,
    edges=elements_2d,
    element_adjacency=adjacency_2d,
    em=1.0,
)

state_bonded = jdem.State.create(pos=vertices_2d)
system_bonded = jdem.System.create(
    state_bonded.shape,
    bonded_force_model=bonded_model,
)

with jdem.CheckpointWriter(directory=tmp_dir) as writer:
    writer.save(state_bonded, system_bonded)
    writer.block_until_ready()

with jdem.CheckpointLoader(directory=tmp_dir) as loader:
    _, system_restored = loader.load()
    print(
        "Restored bonded model:",
        (
            None
            if system_restored.bonded_force_model is None
            else system_restored.bonded_force_model.type_name
        ),
    )
Restored bonded model: deformableparticlemodel

Custom force functions and checkpointing

Custom force functions passed via force_manager_kw are serialized by their fully-qualified module path (e.g. mypackage.forces.harmonic_trap).

Warning

Functions defined in the top-level script (__main__) cannot be restored from a different script, because __main__ refers to whichever script is currently running. A warning is emitted at save time when this situation is detected. If a function cannot be resolved during loading, it is skipped and a warning is logged.
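
To see why the module path matters, here is a minimal stdlib sketch of referring to a function by a dotted path and resolving it again by import. It illustrates the general idea only and is not JaxDEM's serialization code. The helpers qualified_name and resolve are hypothetical, and json.dumps stands in for a custom force function.

```python
import importlib
import json
import warnings


def qualified_name(fn):
    """Return 'module.qualname' for a module-level function (illustrative helper)."""
    return f"{fn.__module__}.{fn.__qualname__}"


def resolve(path):
    """Import the module part of `path` and look up the function.

    Returns None and emits a warning if the module or attribute is missing.
    """
    module_name, _, attr = path.rpartition(".")
    try:
        return getattr(importlib.import_module(module_name), attr)
    except (ImportError, AttributeError, ValueError):
        warnings.warn(f"could not resolve {path!r}; skipping")
        return None


path = qualified_name(json.dumps)
print(path)                        # json.dumps
print(resolve(path) is json.dumps)  # True
```

A function defined in a script run directly has __module__ equal to "__main__", so a path built this way only resolves inside that same script. That is the portability problem described in the warning above.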

To ensure that checkpoints are portable across scripts, define your custom force (and energy) functions in a separate importable module:

# my_forces.py  <-- importable module
import jax.numpy as jnp

def harmonic_trap(pos, state, system):
    k = 1.0
    return -k * pos, jnp.zeros_like(state.torque)

def harmonic_trap_energy(pos, state, system):
    k = 1.0
    return 0.5 * k * jnp.sum(pos ** 2, axis=-1)

Then use them in your simulation:

from my_forces import harmonic_trap, harmonic_trap_energy

system = jdem.System.create(
    state.shape,
    force_manager_kw=dict(
        force_functions=[(harmonic_trap, harmonic_trap_energy)],
    ),
)

Checkpoints saved this way can be loaded from any script that has my_forces on its Python path.
