jaxdem.writers
Interface for defining data writers.
This module provides a high-level VTKWriter and CheckpointWriter frontend, a VTKBaseWriter plugin interface, and concrete writers (e.g., VTKSpheresWriter, VTKDomainWriter) for exporting JAX-based simulation snapshots to VTK files.
Classes
- VTKBaseWriter: Abstract base class for writers that output simulation data.
- class jaxdem.writers.CheckpointLoader(directory: Path | str = PosixPath('checkpoints'))
Bases: BaseCheckpointManager
Thin wrapper around Orbax checkpoint restoring for jaxdem.state and jaxdem.system.
- class jaxdem.writers.CheckpointModelLoader(directory: Path | str = PosixPath('checkpoints'))
Bases: BaseCheckpointManager
Thin wrapper around Orbax checkpoint restoring for jaxdem.rl.models.Model.
- class jaxdem.writers.CheckpointModelWriter(directory: Path | str = PosixPath('checkpoints'), max_to_keep: int | None = None, save_every: int = 1, clean: bool = True)
Bases: BaseCheckpointManager
Thin wrapper around Orbax checkpoint saving for jaxdem.rl.models.Model.
- max_to_keep: int | None = None
Keep only the most recent max_to_keep checkpoints. If None, all checkpoints are kept.
- clean: bool = True
Whether to clean the directory.
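The keep-last-N rule described for max_to_keep can be pictured with a small stand-alone sketch. It does not call Orbax. `retained_steps` is a hypothetical helper, and the sketch assumes checkpoints are identified by an increasing integer step and only handles positive max_to_keep values.

```python
from typing import Optional


def retained_steps(saved_steps: list, max_to_keep: Optional[int]) -> list:
    """Return the checkpoint steps that survive a keep-last-N policy."""
    ordered = sorted(saved_steps)
    if max_to_keep is None:
        # None disables pruning: every checkpoint is kept.
        return ordered
    if max_to_keep < 1:
        # Assumption of this sketch: only positive limits are meaningful.
        raise ValueError("max_to_keep must be None or a positive integer")
    # Keep only the most recent max_to_keep steps.
    return ordered[-max_to_keep:]


print(retained_steps([0, 10, 20, 30], max_to_keep=2))     # [20, 30]
print(retained_steps([0, 10, 20, 30], max_to_keep=None))  # [0, 10, 20, 30]
```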
- class jaxdem.writers.CheckpointWriter(directory: Path | str = PosixPath('checkpoints'), max_to_keep: int | None = None, save_every: int = 1)
Bases: BaseCheckpointManager
Thin wrapper around Orbax checkpoint saving.
Notes
Custom force functions passed via force_manager_kw are serialized by their fully-qualified module path (e.g. mypackage.forces.trap). Functions defined in the top-level script (__main__) cannot be restored from a different script. A warning is emitted at save time if any force function lives in __main__. To ensure portability, define force functions in an importable module.
- max_to_keep: int | None = None
Keep only the most recent max_to_keep checkpoints. If None, all checkpoints are kept.
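The portability note for CheckpointWriter rests on a general Python mechanism: a function stored by its module path can only be found again if that module is importable. The sketch below illustrates the idea with the standard library only. `qualified_name` and `resolve` are hypothetical helpers, not jaxdem functions, and `json.dumps` stands in for a user-defined force function such as mypackage.forces.trap. The sketch handles module-level functions only.

```python
import importlib
import json
import warnings
from typing import Callable


def qualified_name(fn: Callable) -> str:
    """Return 'module.name' for fn, warning if fn lives in __main__."""
    if fn.__module__ == "__main__":
        warnings.warn(
            f"{fn.__qualname__} is defined in __main__ and cannot be "
            "restored from a different script",
            stacklevel=2,
        )
    return f"{fn.__module__}.{fn.__qualname__}"


def resolve(path: str) -> Callable:
    """Import the module part of the path and look up the function in it."""
    module_name, _, attr = path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr)


path = qualified_name(json.dumps)
print(path)                         # json.dumps
print(resolve(path) is json.dumps)  # True
```

A function defined in the script being run has `__module__ == "__main__"`, so a different script resolving that path would import its own `__main__` and fail to find the function. This is why the note recommends an importable module.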
- class jaxdem.writers.VTKBaseWriter
Bases: Factory, ABC
Abstract base class for writers that output simulation data.
Concrete subclasses implement the write method to specify how a given snapshot (jaxdem.State, jaxdem.System pair) is converted into a specific file format.
Example:
To define a custom VTK writer, inherit from VTKBaseWriter and implement its abstract methods:
>>> @VTKBaseWriter.register("my_custom_vtk_writer")
... @dataclass(slots=True)
... class MyCustomVTKWriter(VTKBaseWriter):
...     ...
- classmethod is_active(state: State, system: System) → bool
Check whether this writer has data to write for the given state and system.
- abstractmethod classmethod write(state: State, system: System, filename: Path, binary: bool) → None
Write information from a simulation snapshot to a VTK PolyData file.
This abstract method is the core interface for all concrete VTK writers. Implementations should assume that all JAX arrays are converted to NumPy arrays before write is called.
- Parameters:
state (State) – The simulation jaxdem.State snapshot to be written.
system (System) – The simulation jaxdem.System configuration.
filename (Path) – Target path where the VTK file should be saved. The caller guarantees that it exists.
binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.
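The register-by-name pattern used by VTKBaseWriter, together with an abstract write classmethod, can be illustrated without jaxdem. In the sketch below, `WriterBase`, its `_registry` dictionary, `select_writers`, and `ParticleCountWriter` are stand-ins invented for illustration; jaxdem's Factory class may be implemented quite differently. The "empty list selects every writer" rule mirrors the documented behavior of VTKWriter.writers.

```python
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path


class WriterBase(ABC):
    """Stand-in for a registry-backed abstract writer (not jaxdem's Factory)."""

    _registry: dict = {}

    @classmethod
    def register(cls, name: str):
        """Return a class decorator that stores the subclass under `name`."""
        def decorator(subclass):
            WriterBase._registry[name] = subclass
            return subclass
        return decorator

    @classmethod
    @abstractmethod
    def write(cls, state, system, filename: Path, binary: bool) -> None:
        """Convert one snapshot into a file at `filename`."""


def select_writers(names: list) -> list:
    """Look up writers by name; an empty list selects every registered writer."""
    if not names:
        return list(WriterBase._registry.values())
    return [WriterBase._registry[n] for n in names]


@WriterBase.register("particle_count")
class ParticleCountWriter(WriterBase):
    @classmethod
    def write(cls, state, system, filename, binary):
        # A real writer would build VTK PolyData; this stand-in writes text.
        Path(filename).write_text(f"{len(state)} particles\n")


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "frame.txt"
    for writer in select_writers([]):
        writer.write([(0.0, 0.0), (1.0, 0.0)], None, target, binary=False)
    print(target.read_text().strip())  # 2 particles
```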
- class jaxdem.writers.VTKDeformableEdgeAdjacenciesWriter
Bases: VTKBaseWriter
- classmethod is_active(state: State, system: System) → bool
Check whether this writer has data to write for the given state and system.
- classmethod write(state: State, system: System, filename: Path, binary: bool) → None
Write information from a simulation snapshot to a VTK PolyData file.
This abstract method is the core interface for all concrete VTK writers. Implementations should assume that all JAX arrays are converted to NumPy arrays before write is called.
- Parameters:
state (State) – The simulation jaxdem.State snapshot to be written.
system (System) – The simulation jaxdem.System configuration.
filename (Path) – Target path where the VTK file should be saved. The caller guarantees that it exists.
binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.
- class jaxdem.writers.VTKDeformableEdgesWriter
Bases: VTKBaseWriter
- classmethod is_active(state: State, system: System) → bool
Check whether this writer has data to write for the given state and system.
- classmethod write(state: State, system: System, filename: Path, binary: bool) → None
Write information from a simulation snapshot to a VTK PolyData file.
This abstract method is the core interface for all concrete VTK writers. Implementations should assume that all JAX arrays are converted to NumPy arrays before write is called.
- Parameters:
state (State) – The simulation jaxdem.State snapshot to be written.
system (System) – The simulation jaxdem.System configuration.
filename (Path) – Target path where the VTK file should be saved. The caller guarantees that it exists.
binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.
- class jaxdem.writers.VTKDeformableElementsWriter
Bases: VTKBaseWriter
- classmethod is_active(state: State, system: System) → bool
Check whether this writer has data to write for the given state and system.
- classmethod write(state: State, system: System, filename: Path, binary: bool) → None
Write information from a simulation snapshot to a VTK PolyData file.
This abstract method is the core interface for all concrete VTK writers. Implementations should assume that all JAX arrays are converted to NumPy arrays before write is called.
- Parameters:
state (State) – The simulation jaxdem.State snapshot to be written.
system (System) – The simulation jaxdem.System configuration.
filename (Path) – Target path where the VTK file should be saved. The caller guarantees that it exists.
binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.
- class jaxdem.writers.VTKDomainWriter
Bases: VTKBaseWriter
A VTKBaseWriter that writes the simulation domain as a VTK geometric primitive.
The domain is represented as an axis-aligned cuboid (3D) or rectangle (2D), using a vtkCubeSource. If input arrays are 2D, they are padded to 3D as required by VTK.
- classmethod write(state: State, system: System, filename: Path, binary: bool) → None
Write information from a simulation snapshot to a VTK PolyData file.
This abstract method is the core interface for all concrete VTK writers. Implementations should assume that all JAX arrays are converted to NumPy arrays before write is called.
- Parameters:
state (State) – The simulation jaxdem.State snapshot to be written.
system (System) – The simulation jaxdem.System configuration.
filename (Path) – Target path where the VTK file should be saved. The caller guarantees that it exists.
binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.
- class jaxdem.writers.VTKSpheresWriter
Bases: VTKBaseWriter
A VTKBaseWriter that writes particle centers as VTK points and attaches per-particle State fields as PointData attributes.
For each particle, its position is written as a point. Relevant per-particle fields (e.g., vel, rad, mass) are exported as arrays. Positions and 2-component vectors are padded to 3D as required by VTK.
- classmethod write(state: State, system: System, filename: Path, binary: bool) → None
Write information from a simulation snapshot to a VTK PolyData file.
This abstract method is the core interface for all concrete VTK writers. Implementations should assume that all JAX arrays are converted to NumPy arrays before write is called.
- Parameters:
state (State) – The simulation jaxdem.State snapshot to be written.
system (System) – The simulation jaxdem.System configuration.
filename (Path) – Target path where the VTK file should be saved. The caller guarantees that it exists.
binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.
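The 2D-to-3D padding mentioned for VTKSpheresWriter can be sketched in plain Python. `pad_to_3d` is a hypothetical helper, and the choice of a zero z-component is an assumption; the source only states that 2-component data is padded to 3D.

```python
def pad_to_3d(vectors: list) -> list:
    """Append a z-component of 0.0 to 2-component vectors; pass 3D through."""
    padded = []
    for v in vectors:
        if len(v) == 2:
            # Assumption: the missing z-component is filled with zero.
            padded.append((v[0], v[1], 0.0))
        elif len(v) == 3:
            padded.append(tuple(v))
        else:
            raise ValueError(f"expected 2 or 3 components, got {len(v)}")
    return padded


print(pad_to_3d([(1.0, 2.0), (3.0, 4.0, 5.0)]))
# [(1.0, 2.0, 0.0), (3.0, 4.0, 5.0)]
```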
- class jaxdem.writers.VTKWriter(directory: Path = PosixPath('frames'), save_every: int = 1, clean: bool = True, max_workers: int = 4, writers: list[str] = <factory>, binary: bool = True)
Bases: BaseAsyncWriter
High-level front end for writing simulation data to VTK files.
This class orchestrates the conversion of JAX-based jaxdem.State and jaxdem.System pytrees into VTK files, handling batches, trajectories, and dispatch to registered jaxdem.VTKBaseWriter subclasses.
How leading axes are interpreted
Let particle positions have shape (..., N, dim), where N is the number of particles and dim is 2 or 3. Define L = state.pos_c.ndim - 2, i.e., the number of leading axes before (N, dim).
- L == 0 (single snapshot): The input is one frame. It is written directly into frames/batch_00000000/ (no batching, no trajectory).
- trajectory=False (default): All leading axes are treated as batch axes (not time). If multiple batch axes are present, they are flattened into a single batch axis: (B, N, dim) with B = prod(shape[:L]). Each batch b is written as a single snapshot under its own subdirectory frames/batch_XXXXXXXX/. No trajectory is implied.
  Example: (B, N, dim) → B separate directories with one frame each.
  Example: (B1, B2, N, dim) → flatten to (B1*B2, N, dim) and treat as above.
- trajectory=True: The axis given by trajectory_axis is swapped to the front (axis 0) and interpreted as time T. Any remaining leading axes are batch axes. If more than one non-time leading axis exists, they are flattened into a single batch axis so the data becomes (T, B, N, dim) with B = prod(other leading axes).
  - If there is only time (L == 1): (T, N, dim). A single batch directory frames/batch_00000000/ contains a time series with T frames.
  - If there is time plus batching (L >= 2): (T, B, N, dim). Each batch b gets its own directory frames/batch_XXXXXXXX/ containing a time series (T frames) for that batch.
After these swaps/reshapes, dispatch is:
- (N, dim) → single snapshot
- (B, N, dim) → batches (no time)
- (T, N, dim) → single batch with a trajectory
- (T, B, N, dim) → per-batch trajectories
Concrete writers receive per-frame NumPy arrays; leaves in System are sliced/broadcast consistently with the current frame/batch.
- writers: list[str]
A list of strings specifying which registered VTKBaseWriter subclasses should be used for writing. If empty, all available subclasses will be used.
- binary: bool = True
If True, VTK files will be written in binary format. If False, files will be written in ASCII format.
- save(state: State, system: System, *, trajectory: bool = False, trajectory_axis: int = 0, batch0: int = 0) → None
Schedule writing of a jaxdem.State / jaxdem.System pair to VTK files.
This public entry point interprets leading axes (batch vs. trajectory), performs any required axis swapping and flattening, and then pushes the data to the background writer queue.
- Parameters:
state (State) – The simulation jaxdem.State object to be saved.
system (System) – The jaxdem.System object corresponding to state.
trajectory (bool, optional) – If True, interpret trajectory_axis as time.
trajectory_axis (int, optional) – The axis in state/system to treat as the trajectory axis.
batch0 (int, optional) – The starting batch index for the input data.
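The leading-axis rules that save() applies can be traced on shapes alone, without any arrays. The sketch below is not jaxdem code: `canonical_layout` and `batch_directories` are hypothetical helpers, the trajectory axis is restricted to a non-negative index into the leading axes, and numbering directories from batch0 is an assumed reading of that parameter.

```python
from math import prod


def canonical_layout(shape, trajectory=False, trajectory_axis=0):
    """Return (T, B, N, dim); T or B is None when that axis is absent."""
    *leading, n, dim = shape
    if not leading:
        return (None, None, n, dim)  # L == 0: single snapshot
    if not trajectory:
        # Every leading axis is a batch axis, flattened into one.
        return (None, prod(leading), n, dim)
    if not 0 <= trajectory_axis < len(leading):
        raise ValueError("this sketch only accepts a non-negative leading-axis index")
    t = leading[trajectory_axis]
    others = leading[:trajectory_axis] + leading[trajectory_axis + 1:]
    b = prod(others) if others else None  # remaining leading axes form the batch
    return (t, b, n, dim)


def batch_directories(layout, root="frames", batch0=0):
    """List one directory per batch, numbered from batch0 (assumed offset)."""
    _, b, _, _ = layout
    count = 1 if b is None else b
    return [f"{root}/batch_{batch0 + i:08d}" for i in range(count)]


print(canonical_layout((100, 3)))                 # (None, None, 100, 3)
print(canonical_layout((4, 6, 100, 3)))           # (None, 24, 100, 3)
print(canonical_layout((5, 7, 100, 2), True, 1))  # (7, 5, 100, 2)
print(batch_directories(canonical_layout((2, 100, 3))))
# ['frames/batch_00000000', 'frames/batch_00000001']
```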
- directory: Path
The root directory where simulation frames will be saved.
- save_every: int
Frequency of saving. A frame is pushed to the queue every save_every calls to the save() method.
- clean: bool
If True, the directory is deleted and recreated upon initialization. Basic safety checks are performed to prevent deleting the current working directory or the system root.
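The kind of safety check described for clean can be sketched with pathlib and shutil. `reset_directory` is a hypothetical helper written for illustration; jaxdem's actual checks may be stricter or differ in detail.

```python
import shutil
import tempfile
from pathlib import Path


def reset_directory(directory) -> Path:
    """Delete and recreate `directory`, refusing the cwd and the filesystem root."""
    target = Path(directory).resolve()
    # Both checks run before anything is deleted.
    if target == Path(target.anchor):
        raise ValueError(f"refusing to delete the filesystem root: {target}")
    if target == Path.cwd().resolve():
        raise ValueError(f"refusing to delete the current working directory: {target}")
    if target.exists():
        shutil.rmtree(target)
    target.mkdir(parents=True)
    return target


with tempfile.TemporaryDirectory() as tmp:
    frames = Path(tmp) / "frames"
    frames.mkdir()
    (frames / "stale.vtp").write_text("old frame")
    reset_directory(frames)
    print(list(frames.iterdir()))  # []
```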
- max_workers: int
The number of background worker threads to use for parallel I/O.
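The effect of max_workers can be pictured with the standard library's ThreadPoolExecutor: file writes are submitted to a bounded pool so that I/O overlaps. This is a generic sketch rather than BaseAsyncWriter's implementation. `write_frame`, `write_frames_async`, and the frame_XXXXXXXX.txt naming are hypothetical.

```python
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


def write_frame(path: Path, payload: str) -> Path:
    """Write one frame to disk and return its path."""
    path.write_text(payload)
    return path


def write_frames_async(directory: Path, frames: list, max_workers: int = 4) -> list:
    """Write frames on up to max_workers threads; results keep input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [
            pool.submit(write_frame, directory / f"frame_{i:08d}.txt", frame)
            for i, frame in enumerate(frames)
        ]
        # result() blocks until each write finishes and re-raises any error.
        return [future.result() for future in futures]


with tempfile.TemporaryDirectory() as tmp:
    written = write_frames_async(Path(tmp), ["t=0", "t=1", "t=2"], max_workers=2)
    print([p.name for p in written])
    # ['frame_00000000.txt', 'frame_00000001.txt', 'frame_00000002.txt']
```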
Modules
- Defines the base infrastructure for asynchronous data writing.
- Orbax checkpoint writer and a loader.
- VTK writers for deformable particles.
- VTK writer for domain geometry.
- VTK writer that exports spheres data.
- Implementation of the high-level VTKWriter frontend.