jaxdem.writers

Interface for defining data writers.

This module provides the high-level VTKWriter and CheckpointWriter frontends, the VTKBaseWriter plugin interface, and concrete writers (e.g., VTKSpheresWriter, VTKDomainWriter) for exporting JAX-based simulation snapshots to VTK files.

Classes

VTKBaseWriter()

Abstract base class for writers that output simulation data.

class jaxdem.writers.CheckpointLoader(directory: Path | str = PosixPath('checkpoints'))

Bases: BaseCheckpointManager

Thin wrapper around Orbax checkpoint restoring for jaxdem.state and jaxdem.system.

load(step: int | None = None) -> tuple[State, System]

Restore a checkpoint.

Parameters:

step (Optional[int]) –

  • If None, load the latest checkpoint.

  • Otherwise, load the specified step.

Returns:

A tuple containing the restored State and System.

Return type:

Tuple[State, System]

latest_step() -> int | None

class jaxdem.writers.CheckpointModelLoader(directory: Path | str = PosixPath('checkpoints'))

Bases: BaseCheckpointManager

Thin wrapper around Orbax checkpoint restoring for jaxdem.rl.models.Model.

load(step: int | None = None) -> Model

Load a model from a given step (or the latest if None).

latest_step() -> int | None

class jaxdem.writers.CheckpointModelWriter(directory: Path | str = PosixPath('checkpoints'), max_to_keep: int | None = None, save_every: int = 1, clean: bool = True)

Bases: BaseCheckpointManager

Thin wrapper around Orbax checkpoint saving for jaxdem.rl.models.Model.

max_to_keep: int | None = None

Keep only the last max_to_keep checkpoints. If None, all checkpoints are kept.

save_every: int = 1

How often to write; writes on every save_every-th call to save().

clean: bool = True

Whether to clean the directory.

save(model: Model, step: int) -> None

Save the model at the given step, storing model_state and JSON metadata. This assumes that model.metadata contains only JSON-serializable fields; model_type is added automatically.

class jaxdem.writers.CheckpointWriter(directory: Path | str = PosixPath('checkpoints'), max_to_keep: int | None = None, save_every: int = 1)

Bases: BaseCheckpointManager

Thin wrapper around Orbax checkpoint saving.

Notes

Custom force functions passed via force_manager_kw are serialized by their fully-qualified module path (e.g. mypackage.forces.trap). Functions defined in the top-level script (__main__) cannot be restored from a different script. A warning is emitted at save time if any force function lives in __main__. To ensure portability, define force functions in an importable module.
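The portability rule in the note above can be illustrated with a minimal standard-library sketch. It assumes that a function is identified by its `__module__` plus `__qualname__` and re-imported with importlib; the helpers `qualified_name`, `resolve`, and `lives_in_main` are hypothetical names for this illustration and are not part of jaxdem, whose actual serialization code may differ.

```python
import importlib
import json


def qualified_name(func):
    """Return 'module.qualname' for a function (hypothetical helper)."""
    return f"{func.__module__}.{func.__qualname__}"


def resolve(path):
    """Import the module part of 'module.name' and fetch the attribute.

    Only handles top-level functions, which is enough for this sketch.
    """
    module_name, _, attr = path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr)


def lives_in_main(func):
    """True if the function was defined in the top-level script."""
    return func.__module__ == "__main__"


# json.dumps lives in an importable module, so it round-trips by name.
name = qualified_name(json.dumps)
restored = resolve(name)
```

A function defined in a script run as `__main__` would yield a name starting with `__main__.`, which a different script cannot import; that is the case the warning is meant to flag.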

max_to_keep: int | None = None

Keep only the last max_to_keep checkpoints. If None, all checkpoints are kept.

save_every: int = 1

How often to write; writes on every save_every-th call to save().
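How max_to_keep and save_every interact can be sketched with a small in-memory model. `ToyCheckpointPolicy` is a hypothetical class for illustration only; it does not touch Orbax or the filesystem, its `save` takes a bare step label rather than a state/system pair, and it assumes that writes happen on calls number save_every, 2*save_every, and so on.

```python
from collections import deque


class ToyCheckpointPolicy:
    """Hypothetical in-memory model of save_every / max_to_keep."""

    def __init__(self, max_to_keep=None, save_every=1):
        self.save_every = save_every
        # deque(maxlen=None) is unbounded, so None keeps every checkpoint.
        self.kept = deque(maxlen=max_to_keep)
        self.calls = 0

    def save(self, step):
        """Record `step` if this call is a save_every-th call."""
        self.calls += 1
        # Assumption: the write happens on calls save_every, 2*save_every, ...
        if self.calls % self.save_every != 0:
            return False
        self.kept.append(step)  # the oldest entry is dropped at maxlen
        return True


policy = ToyCheckpointPolicy(max_to_keep=2, save_every=3)
written = [step for step in range(1, 11) if policy.save(step)]
```

With these settings, ten calls produce three writes, and only the two most recent of them are retained.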

save(state: State, system: System) -> None

Save a checkpoint for the provided state/system at a given step.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The current system configuration.

class jaxdem.writers.VTKBaseWriter

Bases: Factory, ABC

Abstract base class for writers that output simulation data.

Concrete subclasses implement the write method to specify how a given snapshot (jaxdem.State, jaxdem.System pair) is converted into a specific file format.

Example:

To define a custom VTK writer, inherit from VTKBaseWriter and implement its abstract methods:

>>> from dataclasses import dataclass
>>> from jaxdem.writers import VTKBaseWriter
>>> @VTKBaseWriter.register("my_custom_vtk_writer")
... @dataclass(slots=True)
... class MyCustomVTKWriter(VTKBaseWriter):
...     ...

classmethod is_active(state: State, system: System) -> bool

Check whether this writer has data to write for the given state and system.

abstractmethod classmethod write(state: State, system: System, filename: Path, binary: bool) -> None

Write information from a simulation snapshot to a VTK PolyData file.

This abstract method is the core interface for all concrete VTK writers. Implementations should assume that all JAX arrays have been converted to NumPy arrays before write is called.

Parameters:
  • state (State) – The simulation jaxdem.State snapshot to be written.

  • system (System) – The simulation jaxdem.System configuration.

  • filename (Path) – Target path where the VTK file should be saved. The caller guarantees that it exists.

  • binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.

class jaxdem.writers.VTKDeformableEdgeAdjacenciesWriter

Bases: VTKBaseWriter

classmethod is_active(state: State, system: System) -> bool

Check whether this writer has data to write for the given state and system.

classmethod write(state: State, system: System, filename: Path, binary: bool) -> None

Write information from a simulation snapshot to a VTK PolyData file.

This method overrides VTKBaseWriter.write. All JAX arrays are assumed to have been converted to NumPy arrays before write is called.

Parameters:
  • state (State) – The simulation jaxdem.State snapshot to be written.

  • system (System) – The simulation jaxdem.System configuration.

  • filename (Path) – Target path where the VTK file should be saved. The caller guarantees that it exists.

  • binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.

class jaxdem.writers.VTKDeformableEdgesWriter

Bases: VTKBaseWriter

classmethod is_active(state: State, system: System) -> bool

Check whether this writer has data to write for the given state and system.

classmethod write(state: State, system: System, filename: Path, binary: bool) -> None

Write information from a simulation snapshot to a VTK PolyData file.

This method overrides VTKBaseWriter.write. All JAX arrays are assumed to have been converted to NumPy arrays before write is called.

Parameters:
  • state (State) – The simulation jaxdem.State snapshot to be written.

  • system (System) – The simulation jaxdem.System configuration.

  • filename (Path) – Target path where the VTK file should be saved. The caller guarantees that it exists.

  • binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.

class jaxdem.writers.VTKDeformableElementsWriter

Bases: VTKBaseWriter

classmethod is_active(state: State, system: System) -> bool

Check whether this writer has data to write for the given state and system.

classmethod write(state: State, system: System, filename: Path, binary: bool) -> None

Write information from a simulation snapshot to a VTK PolyData file.

This method overrides VTKBaseWriter.write. All JAX arrays are assumed to have been converted to NumPy arrays before write is called.

Parameters:
  • state (State) – The simulation jaxdem.State snapshot to be written.

  • system (System) – The simulation jaxdem.System configuration.

  • filename (Path) – Target path where the VTK file should be saved. The caller guarantees that it exists.

  • binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.

class jaxdem.writers.VTKDomainWriter

Bases: VTKBaseWriter

A VTKBaseWriter that writes the simulation domain as a VTK geometric primitive.

The domain is represented as an axis-aligned cuboid (3D) or rectangle (2D), using a vtkCubeSource. If input arrays are 2D, they are padded to 3D as required by VTK.
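The padding step for a 2D domain can be sketched with NumPy alone. This is a minimal illustration, not jaxdem's implementation: the function `cube_parameters` and its arguments `anchor` (taken to be the lower corner) and `box_size` (the edge lengths) are assumptions, as is the choice of zero as the padding value.

```python
import numpy as np


def cube_parameters(anchor, box_size):
    """Return (center, lengths) in 3D for an axis-aligned box.

    `anchor` is taken to be the lower corner and `box_size` the edge
    lengths; both names are assumptions made for this sketch.
    """
    anchor = np.asarray(anchor, dtype=float)
    box_size = np.asarray(box_size, dtype=float)
    pad = 3 - anchor.shape[0]
    # Zero-pad 2D input: a rectangle becomes a cuboid of zero thickness.
    anchor3 = np.pad(anchor, (0, pad))
    size3 = np.pad(box_size, (0, pad))
    center = anchor3 + 0.5 * size3
    return center, size3


center, lengths = cube_parameters(anchor=[0.0, 0.0], box_size=[4.0, 2.0])
```

In VTK, a center and three edge lengths of this kind are what vtkCubeSource accepts through SetCenter and SetXLength/SetYLength/SetZLength; how VTKDomainWriter configures its source is not stated here.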

classmethod write(state: State, system: System, filename: Path, binary: bool) -> None

Write information from a simulation snapshot to a VTK PolyData file.

This method overrides VTKBaseWriter.write. All JAX arrays are assumed to have been converted to NumPy arrays before write is called.

Parameters:
  • state (State) – The simulation jaxdem.State snapshot to be written.

  • system (System) – The simulation jaxdem.System configuration.

  • filename (Path) – Target path where the VTK file should be saved. The caller guarantees that it exists.

  • binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.

class jaxdem.writers.VTKSpheresWriter

Bases: VTKBaseWriter

A VTKBaseWriter that writes particle centers as VTK points and attaches per-particle State fields as PointData attributes.

For each particle, its position is written as a point. Relevant per-particle fields (e.g., vel, rad, mass) are exported as arrays. Positions and 2-component vectors are padded to 3D as required by VTK.
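Padding 2-component arrays to 3D, as described above, might look like the following NumPy sketch. The helper `pad_to_3d` is a hypothetical name, and padding with zeros is an assumption; the text only states that the arrays are padded.

```python
import numpy as np


def pad_to_3d(arr):
    """Append zero columns so an (N, 2) array becomes (N, 3)."""
    arr = np.asarray(arr)
    missing = 3 - arr.shape[-1]
    if missing <= 0:
        return arr  # already 3D, nothing to do
    # Pad only the last axis, and only on its trailing side.
    pad_width = [(0, 0)] * (arr.ndim - 1) + [(0, missing)]
    return np.pad(arr, pad_width)  # default mode pads with zeros


pos_2d = np.array([[0.0, 1.0], [2.0, 3.0]])
vel_2d = np.array([[0.5, -0.5], [1.5, 0.0]])
points = pad_to_3d(pos_2d)
velocities = pad_to_3d(vel_2d)
```

The same helper serves positions and per-particle vectors alike, while arrays that already have three components pass through unchanged.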

classmethod write(state: State, system: System, filename: Path, binary: bool) -> None

Write information from a simulation snapshot to a VTK PolyData file.

This method overrides VTKBaseWriter.write. All JAX arrays are assumed to have been converted to NumPy arrays before write is called.

Parameters:
  • state (State) – The simulation jaxdem.State snapshot to be written.

  • system (System) – The simulation jaxdem.System configuration.

  • filename (Path) – Target path where the VTK file should be saved. The caller guarantees that it exists.

  • binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.

class jaxdem.writers.VTKWriter(directory: Path = PosixPath('frames'), save_every: int = 1, clean: bool = True, max_workers: int = 4, writers: list[str] = <factory>, binary: bool = True)

Bases: BaseAsyncWriter

High-level front end for writing simulation data to VTK files.

This class orchestrates the conversion of JAX-based jaxdem.State and jaxdem.System pytrees into VTK files, handling batches, trajectories, and dispatch to registered jaxdem.VTKBaseWriter subclasses.

How leading axes are interpreted

Let particle positions have shape (..., N, dim), where N is the number of particles and dim is 2 or 3. Define L = state.pos_c.ndim - 2, i.e., the number of leading axes before (N, dim).

  • L == 0 — single snapshot

    The input is one frame. It is written directly into frames/batch_00000000/ (no batching, no trajectory).

  • trajectory=False (default)

    All leading axes are treated as batch axes (not time). If multiple batch axes are present, they are flattened into a single batch axis: (B, N, dim) with B = prod(shape[:L]). Each batch b is written as a single snapshot under its own subdirectory frames/batch_XXXXXXXX/. No trajectory is implied.

    • Example: (B, N, dim) → B separate directories with one frame each.

    • Example: (B1, B2, N, dim) → flatten to (B1*B2, N, dim) and treat as above.

  • trajectory=True

    The axis given by trajectory_axis is swapped to the front (axis 0) and interpreted as time T. Any remaining leading axes are batch axes. If more than one non-time leading axis exists, they are flattened into a single batch axis so the data becomes (T, B, N, dim) with B = prod(other leading axes).

    • If there is only time (L == 1): (T, N, dim). A single batch directory frames/batch_00000000/ contains a time series with T frames.

    • If there is time plus batching (L >= 2): (T, B, N, dim). Each batch b gets its own directory frames/batch_XXXXXXXX/ containing a time series (T frames) for that batch.

After these swaps/reshapes, dispatch is:

  • (N, dim) → single snapshot

  • (B, N, dim) → batches (no time)

  • (T, N, dim) → single batch with a trajectory

  • (T, B, N, dim) → per-batch trajectories

Concrete writers receive per-frame NumPy arrays; leaves in System are sliced/broadcast consistently with the current frame/batch.
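The axis rules above can be sketched for a bare position array with NumPy. This is an illustration under stated assumptions, not the VTKWriter code: `normalize_leading_axes` and its string labels are hypothetical, and it handles only one array rather than full State and System pytrees.

```python
import numpy as np


def normalize_leading_axes(pos, trajectory=False, trajectory_axis=0):
    """Reshape (..., N, dim) to one of the four dispatch layouts.

    Returns (array, kind), where kind is "snapshot", "batches",
    "trajectory", or "batch_trajectories".
    """
    pos = np.asarray(pos)
    n_leading = pos.ndim - 2  # L in the text above
    tail = pos.shape[-2:]  # (N, dim)
    if n_leading == 0:
        return pos, "snapshot"
    if not trajectory:
        # Every leading axis is a batch axis; flatten them into one.
        return pos.reshape((-1,) + tail), "batches"
    # Swap the time axis to the front, then flatten the remaining batch axes.
    pos = np.swapaxes(pos, 0, trajectory_axis)
    if n_leading == 1:
        return pos, "trajectory"
    n_steps = pos.shape[0]
    return pos.reshape((n_steps, -1) + tail), "batch_trajectories"


stacked = np.zeros((3, 4, 5, 2))  # two leading axes, N=5, dim=2
batches, kind_b = normalize_leading_axes(stacked)
series, kind_t = normalize_leading_axes(
    stacked, trajectory=True, trajectory_axis=1
)
```

Here the same input becomes 12 independent batches when trajectory is False, and 4 time steps for each of 3 batches when axis 1 is treated as time.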

writers: list[str]

A list of strings specifying which registered VTKBaseWriter subclasses should be used for writing. If empty, all available subclasses will be used.

binary: bool = True

If True, VTK files will be written in binary format. If False, files will be written in ASCII format.

save(state: State, system: System, *, trajectory: bool = False, trajectory_axis: int = 0, batch0: int = 0) -> None

Schedule writing of a jaxdem.State / jaxdem.System pair to VTK files.

This public entry point interprets leading axes (batch vs. trajectory), performs any required axis swapping and flattening, and then pushes the data to the background writer queue.

Parameters:
  • state (State) – The simulation jaxdem.State object to be saved.

  • system (System) – The jaxdem.System object corresponding to state.

  • trajectory (bool, optional) – If True, interpret trajectory_axis as time.

  • trajectory_axis (int, optional) – The axis in state/system to treat as the trajectory axis.

  • batch0 (int, optional) – The starting batch index for the input data.
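The batch directory names used throughout this page (batch_XXXXXXXX, eight digits) can be generated as in the sketch below. It assumes that batch0 simply offsets the batch index; `batch_directories` is a hypothetical helper, not a jaxdem function.

```python
from pathlib import Path


def batch_directories(root, n_batches, batch0=0):
    """Return one directory per batch, numbered from batch0 (assumed offset)."""
    return [Path(root) / f"batch_{batch0 + b:08d}" for b in range(n_batches)]


dirs = batch_directories("frames", n_batches=2, batch0=3)
```

Under this assumption, a nonzero batch0 lets successive save() calls on different chunks of a larger batch write into distinct directories.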

directory: Path

The root directory where simulation frames will be saved.

save_every: int

Frequency of saving. A frame is pushed to the queue every save_every calls to the save() method.

clean: bool

If True, the directory is deleted and recreated upon initialization. Basic safety checks are performed to prevent deleting the current working directory or the system root.
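A safety check of the kind described above could be sketched as follows. The functions `is_safe_to_clean` and `clean_directory` are hypothetical and cover only the two cases named in the text (current working directory and filesystem root); the checks jaxdem actually performs may differ.

```python
import shutil
import tempfile
from pathlib import Path


def is_safe_to_clean(directory):
    """Reject the current working directory and the filesystem root."""
    target = Path(directory).resolve()
    cwd = Path.cwd().resolve()
    root = Path(target.anchor)  # "/" on POSIX, the drive root on Windows
    return target != cwd and target != root


def clean_directory(directory):
    """Delete and recreate `directory` if the safety check passes."""
    directory = Path(directory)
    if not is_safe_to_clean(directory):
        raise ValueError(f"Refusing to delete {directory.resolve()}")
    shutil.rmtree(directory, ignore_errors=True)
    directory.mkdir(parents=True, exist_ok=True)


# Demonstrate on a throwaway directory so nothing real is removed.
scratch = Path(tempfile.mkdtemp()) / "frames"
scratch.mkdir()
(scratch / "stale.txt").write_text("stale")
clean_directory(scratch)
```

Resolving the path before comparing matters: a relative path such as "." would otherwise slip past a naive string comparison.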

max_workers: int

The number of background worker threads to use for parallel I/O.
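The effect of max_workers can be illustrated with the standard library's thread pool. This is a sketch only: using concurrent.futures.ThreadPoolExecutor is an assumption about the mechanism, and `write_frame` is a hypothetical stand-in that writes small text files instead of VTK data.

```python
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


def write_frame(path, payload):
    """Stand-in for an expensive file write; returns the path written."""
    path.write_text(payload)
    return path


out_dir = Path(tempfile.mkdtemp())
frames = {out_dir / f"frame_{i:03d}.txt": f"frame {i}" for i in range(8)}

# max_workers caps how many writes run concurrently.
with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(write_frame, p, text) for p, text in frames.items()]
    written = [f.result() for f in futures]  # re-raises any worker exception
```

Because file writes are I/O-bound, threads can overlap them even under the GIL, which is why a small pool is often enough to keep the simulation loop from waiting on disk.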

Modules

async_base

Defines the base infrastructure for asynchronous data writing.

checkpoints

Orbax checkpoint writer and loader.

vtkDeformableParticleWriter

VTK writers for deformable particles.

vtkDomainWriter

VTK writer for domain geometry.

vtkSpheresWriter

VTK writer that exports sphere data.

vtkWriter

Implementation of the high-level VTKWriter frontend.