Checkpoint Save and Load
This guide introduces JaxDEM's checkpointing utilities, including CheckpointWriter for saving snapshots and CheckpointLoader for restoring them.
Checkpoints are useful for long simulations, reproducibility, and restarting from intermediate steps.
import tempfile
from pathlib import Path
import jax.numpy as jnp
import jaxdem as jdem
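Saving simulation checkpoints
We create a small two-particle simulation, advance it in chunks of five steps, and save a snapshot after each chunk. With max_to_keep=2, only the two most recent checkpoints are kept.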
state = jdem.State.create(pos=jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]))
system = jdem.System.create(state.shape, dt=1e-3)
tmp_dir = Path(tempfile.gettempdir()) / "simulation"
with jdem.CheckpointWriter(directory=tmp_dir, max_to_keep=2) as writer:
    writer.save(state, system)  # step 0
    state, system = system.step(state, system, n=5)
    writer.save(state, system)  # step 5
    state, system = system.step(state, system, n=5)
    writer.save(state, system)  # step 10
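Loading latest and specific checkpoints
load() returns a (state, system) tuple. Called without arguments, it loads the most recent checkpoint. Pass step=... to load a specific one. The latest saved step can be queried with latest_step().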
with jdem.CheckpointLoader(directory=tmp_dir) as loader:
    print("Available steps:", loader.checkpointer.all_steps())
    print("Latest step:", loader.latest_step())
    state_latest, system_latest = loader.load()
    print("Loaded latest step_count:", int(system_latest.step_count))
    state_step_5, system_step_5 = loader.load(step=5)
    print("Loaded step_count=5:", int(system_step_5.step_count))
    print("State shape at step 5:", state_step_5.pos.shape)
Available steps: [5, 10]
Latest step: 10
Loaded latest step_count: 10
Loaded step_count=5: 5
State shape at step 5: (2, 3)
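The listing above reports only steps 5 and 10. This is consistent with max_to_keep=2 retaining just the two most recent checkpoints. The pure-Python sketch below gives a rough illustration of that retention rule and of how load() falls back to the latest step. It is not JaxDEM's implementation: it keeps checkpoints in a dict. The class InMemoryCheckpoints is hypothetical and only mirrors the names all_steps(), latest_step() and load(step=...) used above.

```python
class InMemoryCheckpoints:
    """Hypothetical in-memory stand-in; real checkpoints are written to disk."""

    def __init__(self, max_to_keep=None):
        self.max_to_keep = max_to_keep
        self._store = {}  # step -> saved payload

    def save(self, step, payload):
        self._store[step] = payload
        if self.max_to_keep is not None and len(self._store) > self.max_to_keep:
            # Delete the oldest steps so that only max_to_keep remain.
            excess = len(self._store) - self.max_to_keep
            for old_step in sorted(self._store)[:excess]:
                del self._store[old_step]

    def all_steps(self):
        return sorted(self._store)

    def latest_step(self):
        steps = self.all_steps()
        return steps[-1] if steps else None

    def load(self, step=None):
        # Without an explicit step, fall back to the most recent checkpoint.
        if step is None:
            step = self.latest_step()
        if step not in self._store:
            raise KeyError(f"no checkpoint for step {step}")
        return self._store[step]


ckpt = InMemoryCheckpoints(max_to_keep=2)
for step in (0, 5, 10):
    ckpt.save(step, {"step_count": step})

print("Available steps:", ckpt.all_steps())  # Available steps: [5, 10]
print("Latest step:", ckpt.latest_step())  # Latest step: 10
print("Step 5:", ckpt.load(step=5))  # Step 5: {'step_count': 5}
```

Saving step 10 pushes the store past two entries, so step 0 is evicted first. This matches the "Available steps" line printed by the real loader.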
Resource management note:
CheckpointWriter and CheckpointLoader
can be used either as context managers (with ... as ...) or managed manually.
Using with is recommended because, on exit, it automatically waits for pending
asynchronous writes and closes resources.
If you manage them manually, remember to wait and close explicitly:
writer = jdem.CheckpointWriter(directory=tmp_dir, max_to_keep=2)
writer.save(state, system)  # asynchronous save
writer.block_until_ready()  # wait until all pending writes are finished
writer.close()  # release resources
Checkpoint saving is asynchronous, so call block_until_ready() before
program exit (and before close() when managing manually) to guarantee
files are fully written.
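The save / block_until_ready() / close() contract described above is a common asynchronous-write pattern. The sketch below illustrates it with the standard library's concurrent.futures. AsyncJsonWriter is a hypothetical stand-in that writes JSON files. It does not show how JaxDEM's writer is implemented internally.

```python
import concurrent.futures
import json
import tempfile
from pathlib import Path


class AsyncJsonWriter:
    """Hypothetical writer: save() queues work, block_until_ready() waits."""

    def __init__(self, directory):
        self.directory = Path(directory)
        self.directory.mkdir(parents=True, exist_ok=True)
        self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self._pending = []

    def save(self, step, data):
        # Returns immediately; the file is written on a background thread.
        path = self.directory / f"step_{step}.json"
        self._pending.append(self._pool.submit(path.write_text, json.dumps(data)))

    def block_until_ready(self):
        # Wait for every queued write; .result() re-raises any write error.
        for future in self._pending:
            future.result()
        self._pending.clear()

    def close(self):
        self.block_until_ready()
        self._pool.shutdown()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()  # the context manager always waits, then releases resources
        return False


out_dir = Path(tempfile.mkdtemp())

# Manual management: wait explicitly, then close.
writer = AsyncJsonWriter(out_dir)
writer.save(0, {"step_count": 0})
writer.block_until_ready()
writer.close()

# Context manager: __exit__ performs the wait and the close.
with AsyncJsonWriter(out_dir) as w:
    w.save(5, {"step_count": 5})

print(sorted(p.name for p in out_dir.iterdir()))  # ['step_0.json', 'step_5.json']
```

Making close() call block_until_ready() means a forgotten wait cannot silently drop data. Calling the wait explicitly, as recommended above, still makes the synchronization point visible in your script.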
Bonded-force model checkpointing
Checkpointing also supports systems with
BondedForceModel instances, such as
DeformableParticleModel.
Note
Bonded-force models register their own force and energy functions with
the ForceManager automatically
during create(). These internal functions
are not serialized like user-supplied custom forces — they are
reconstructed from the bonded model type and parameters at load time.
You only need to worry about serialization for custom force functions
that you add via force_manager_kw.
vertices_2d = jnp.array(
    [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]],
    dtype=float,
)
elements_2d = jnp.array([[0, 1], [1, 2], [2, 3], [3, 0]], dtype=int)
adjacency_2d = jnp.array([[0, 1], [1, 2], [2, 3], [3, 0]], dtype=int)
bonded_model = jdem.BondedForceModel.create(
    "deformableparticlemodel",
    vertices=vertices_2d,
    elements=elements_2d,
    edges=elements_2d,
    element_adjacency=adjacency_2d,
    em=1.0,
)
state_bonded = jdem.State.create(pos=vertices_2d)
system_bonded = jdem.System.create(
    state_bonded.shape,
    bonded_force_model=bonded_model,
)
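# Use a fresh directory so the checkpoints saved earlier (steps 5 and 10)
# are not picked up as the "latest" checkpoint when loading below.
bonded_dir = Path(tempfile.gettempdir()) / "bonded_simulation"
with jdem.CheckpointWriter(directory=bonded_dir) as writer:
    writer.save(state_bonded, system_bonded)
    writer.block_until_ready()
with jdem.CheckpointLoader(directory=bonded_dir) as loader:
    _, system_restored = loader.load()
print(
    "Restored bonded model:",
    (
        None
        if system_restored.bonded_force_model is None
        else system_restored.bonded_force_model.type_name
    ),
)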
Restored bonded model: deformableparticlemodel
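The Note above explains that bonded models are rebuilt from their type and parameters rather than by serializing their internal force functions. The pure-Python sketch below illustrates that idea with a name-to-class registry. It is in the spirit of the name-based BondedForceModel.create(...) call, but MODEL_REGISTRY, SpringRing, to_checkpoint and from_checkpoint are all hypothetical and not part of JaxDEM.

```python
from dataclasses import asdict, dataclass

# Hypothetical registry mapping a type name to a model class.
MODEL_REGISTRY = {}


def register(name):
    def decorator(cls):
        cls.type_name = name  # plain class attribute, not a dataclass field
        MODEL_REGISTRY[name] = cls
        return cls

    return decorator


@register("springring")
@dataclass
class SpringRing:
    """Toy bonded model whose force function is derived from its parameters."""

    stiffness: float
    n_vertices: int

    def force_fn(self):
        # Rebuilt from parameters on demand, so the function is never serialized.
        k = self.stiffness
        return lambda x: -k * x


def to_checkpoint(model):
    # Store only plain data: the type name plus constructor parameters.
    return {"type_name": model.type_name, "params": asdict(model)}


def from_checkpoint(record):
    # Look the class up by name and re-run its constructor.
    cls = MODEL_REGISTRY[record["type_name"]]
    return cls(**record["params"])


record = to_checkpoint(SpringRing(stiffness=2.0, n_vertices=4))
restored = from_checkpoint(record)
print(record)  # {'type_name': 'springring', 'params': {'stiffness': 2.0, 'n_vertices': 4}}
print(restored.force_fn()(1.5))  # -3.0
```

Because the checkpoint holds only a name and numbers, it stays valid even if the force function's code changes, provided the constructor signature stays the same.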
Custom force functions and checkpointing
Custom force functions passed via force_manager_kw are serialized by
their fully-qualified module path (e.g. mypackage.forces.harmonic_trap).
Warning
Functions defined in the top-level script (__main__) cannot be
restored from a different script. A warning is emitted at save time when
this situation is detected. If a function cannot be resolved during
loading, it is skipped rather than raising an error, and a warning is logged.
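Serializing a function by its module path, warning about __main__, and skipping unresolvable functions can all be sketched with importlib. The helpers function_path and resolve_function below are hypothetical illustrations of the behaviour described in the Warning, not JaxDEM's code. For simplicity, they only handle top-level functions, not methods or nested functions.

```python
import importlib
import logging
import math
import warnings

logger = logging.getLogger("checkpoint_sketch")


def function_path(fn):
    """Encode a top-level function as 'module.name' (hypothetical helper)."""
    if fn.__module__ == "__main__":
        # Another script has a different __main__, so this path won't resolve there.
        warnings.warn(
            f"{fn.__qualname__} is defined in __main__ and may not be "
            "restorable from another script",
            stacklevel=2,
        )
    return f"{fn.__module__}.{fn.__qualname__}"


def resolve_function(path):
    """Import a function back from its path; log and return None on failure."""
    module_name, _, attr = path.rpartition(".")
    try:
        return getattr(importlib.import_module(module_name), attr)
    except (ImportError, AttributeError, ValueError):
        logger.warning("could not resolve %s; skipping it", path)
        return None


path = function_path(math.hypot)
print(path)  # math.hypot
print(resolve_function(path)(3.0, 4.0))  # 5.0
print(resolve_function("no_such_module_xyz.trap"))  # None (a warning is logged)
```

Returning None instead of raising lets a loader restore everything else in the checkpoint and drop only the missing function, which matches the skip-and-warn behaviour described above.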
To ensure that checkpoints are portable across scripts, define your custom force (and energy) functions in a separate importable module:
# my_forces.py  <-- importable module
import jax.numpy as jnp
def harmonic_trap(pos, state, system):
    # Harmonic trap centred at the origin: force F = -k * pos, zero torque.
    k = 1.0
    return -k * pos, jnp.zeros_like(state.torque)
def harmonic_trap_energy(pos, state, system):
    # Per-particle trap energy E = 0.5 * k * |pos|^2.
    k = 1.0
    return 0.5 * k * jnp.sum(pos ** 2, axis=-1)
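A force function and its paired energy function should be consistent: the force is the negative gradient of the total energy, F = -∇E. The sketch below checks that for the harmonic trap using NumPy central differences. It uses standalone copies named trap_force, trap_energy and numerical_force, which are hypothetical. The state and system arguments are dropped because the trap only uses state for its zero torque.

```python
import numpy as np

K = 1.0  # same stiffness as in my_forces.py


def trap_force(pos):
    # Mirrors harmonic_trap: F = -k * x for every particle.
    return -K * pos


def trap_energy(pos):
    # Mirrors harmonic_trap_energy: E_i = 0.5 * k * |x_i|^2 per particle.
    return 0.5 * K * np.sum(pos**2, axis=-1)


def numerical_force(pos, eps=1e-6):
    """Central-difference estimate of -dE_total/dx for every coordinate."""
    grad = np.zeros_like(pos)
    for idx in np.ndindex(pos.shape):
        step = np.zeros_like(pos)
        step[idx] = eps
        e_plus = trap_energy(pos + step).sum()
        e_minus = trap_energy(pos - step).sum()
        grad[idx] = (e_plus - e_minus) / (2 * eps)
    return -grad


pos = np.array([[0.5, -1.0, 2.0], [1.5, 0.0, -0.25]])
print(np.allclose(trap_force(pos), numerical_force(pos), atol=1e-6))  # True
```

Running a check like this on your own force/energy pairs before a long run can catch sign or factor-of-two mistakes that would otherwise only show up as energy drift.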
Then use them in your simulation:
from my_forces import harmonic_trap, harmonic_trap_energy
system = jdem.System.create(
    state.shape,
    force_manager_kw=dict(
        force_functions=[(harmonic_trap, harmonic_trap_energy)],
    ),
)
Checkpoints saved this way can be loaded from any script that has
my_forces on its Python path.
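Being "on the Python path" means the module's directory is listed in sys.path, for example because the package is installed, PYTHONPATH includes it, or the directory is added at runtime. The sketch below illustrates the runtime option with a throwaway module. The module name demo_forces and the function harmonic_trap_energy_1d are hypothetical stand-ins for my_forces.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Write a throwaway module to a temporary directory (stands in for my_forces.py).
module_dir = Path(tempfile.mkdtemp())
(module_dir / "demo_forces.py").write_text(
    "def harmonic_trap_energy_1d(x, k=1.0):\n"
    "    return 0.5 * k * x * x\n"
)

# Once the directory is on sys.path, any script can import the module by name.
sys.path.insert(0, str(module_dir))
importlib.invalidate_caches()  # make sure the freshly written file is found
module = importlib.import_module("demo_forces")
fn = getattr(module, "harmonic_trap_energy_1d")
print(fn(2.0))  # 2.0
```

In practice, installing the package that contains your force module (or setting PYTHONPATH) is usually more robust than editing sys.path inside each script.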