jaxdem.writer

Interface for defining data writers.

Classes

DomainWriter()

A VTKBaseWriter implementation that writes the simulation domain as a VTK geometric primitive.

SpheresWriter()

A VTKBaseWriter implementation that writes particle centers as VTK points and attaches per-particle State fields as PointData attributes.

VTKBaseWriter()

Abstract base class for writers that output simulation data.

VTKWriter([writers, directory, binary, ...])

High-level front-end for writing simulation data to VTK files.

class jaxdem.writer.VTKBaseWriter

Bases: Factory[VTKBaseWriter], ABC

Abstract base class for writers that output simulation data.

Concrete subclasses implement the write method to specify how a given snapshot (State, System pair) is converted into a specific file format.


abstractmethod classmethod write(state: State, system: System, counter: int, directory: Path | str, binary: bool) → int

Writes a simulation snapshot to a VTK PolyData file.

This abstract method is the core interface for all concrete VTK writers. Implementations should convert the provided JAX-based state and system data into VTK data structures and write them to a file.

Parameters:
  • state (State) – The simulation State snapshot to be written.

  • system (System) – The simulation System configuration.

  • counter (int) – A global, monotonically increasing integer identifier to be embedded in the file name (e.g., spheres_00000042.vtp). This ensures unique file names.

  • directory (pathlib.Path or str) – The target directory where the VTK file should be saved.

  • binary (bool) – If True, the VTK file should be written in binary mode. If False, it should be written in ASCII mode (human-readable).

Returns:

The next counter value, counter + 1.

Return type:

int

Raises:

NotImplementedError – This is an abstract method and must be implemented by subclasses.
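The contract above (take a snapshot and a counter, write one file whose name embeds the counter, return counter + 1) can be illustrated with a small stand-in. This is a sketch only: WriterSketch and CountWriter are hypothetical names, the state is a plain dict standing in for State, and a text file replaces the VTK PolyData output, so no jaxdem or VTK API is assumed.

```python
import abc
from pathlib import Path
from typing import Any, Union


class WriterSketch(abc.ABC):
    """Hypothetical stand-in for the VTKBaseWriter.write contract (not jaxdem code)."""

    @classmethod
    @abc.abstractmethod
    def write(cls, state: Any, system: Any, counter: int,
              directory: Union[Path, str], binary: bool) -> int:
        raise NotImplementedError


class CountWriter(WriterSketch):
    """Toy writer: stores the particle count as text instead of building VTK data."""

    @classmethod
    def write(cls, state: Any, system: Any, counter: int,
              directory: Union[Path, str], binary: bool) -> int:
        directory = Path(directory)
        directory.mkdir(parents=True, exist_ok=True)
        # Embed the counter, zero-padded to 8 digits, as in spheres_00000042.vtp
        path = directory / f"count_{counter:08d}.txt"
        payload = f"{len(state['pos'])}\n"  # state is a dict stand-in with a 'pos' list
        if binary:
            path.write_bytes(payload.encode("ascii"))  # "binary" mode: raw bytes
        else:
            path.write_text(payload)  # "ASCII" mode: plain text
        return counter + 1  # hand back the next counter value
```

Because write is a classmethod, a caller such as VTKWriter can dispatch to it without instantiating the writer, and the returned value can be fed straight into the next call.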

classmethod create(key: str, /, **kw: Any) → T

Creates and returns an instance of a registered subclass.

This method looks up the subclass associated with the given key in the factory’s registry and then calls its constructor with the provided arguments.

Parameters:
  • key (str) – The registration key of the subclass to be created.

  • **kw (Any) – Arbitrary keyword arguments to be passed directly to the constructor of the registered subclass.

Returns:

An instance of the registered subclass.

Return type:

T

Raises:
  • KeyError – If the provided key is not found in the factory’s registry.

  • TypeError – If the provided **kw arguments do not match the signature of the registered subclass’s constructor.

Example

Given a factory Foo with a subclass Bar registered under the key “bar”:

>>> bar_instance = Foo.create("bar", value=42)
>>> print(bar_instance)
Bar(value=42)

classmethod register(key: str | None = None) → Callable[[Type[T]], Type[T]]

Registers a subclass with the factory’s registry.

This method returns a decorator that can be applied to a class to register it under a specific key.

Parameters:

key (str or None, optional) – The string key under which to register the subclass. If None, the lowercase name of the subclass itself will be used as the key.

Returns:

A decorator function that takes a class and registers it, returning the class unchanged.

Return type:

Callable[[Type[T]], Type[T]]

Raises:

ValueError – If the provided key (or the default class name) is already registered in the factory’s registry.

Example

Register a class named “MyComponent” under the key “mycomp”:

>>> @MyFactory.register("mycomp")
... class MyComponent:
...     ...

Register a class named “DefaultComponent” under its lowercased name, “defaultcomponent”:

>>> @MyFactory.register()
... class DefaultComponent:
...     ...
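The registry behaviour documented for create and register (lowercased class name as the default key, ValueError on a duplicate key, KeyError on an unknown key, TypeError from the constructor on bad keyword arguments) can be sketched in a few lines. This is a hypothetical re-implementation for illustration, not the actual jaxdem Factory; the names Factory, Foo, Bar, and DefaultComponent are assumptions taken from the examples above.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Type, TypeVar

T = TypeVar("T")


class Factory:
    """Minimal registry sketch of the documented create/register behaviour (not jaxdem's code)."""

    _registry: Dict[str, type] = {}

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if Factory in cls.__bases__:
            cls._registry = {}  # each factory root (e.g. Foo) gets its own registry

    @classmethod
    def register(cls, key: Optional[str] = None) -> Callable[[Type[T]], Type[T]]:
        def decorator(sub: Type[T]) -> Type[T]:
            name = key if key is not None else sub.__name__.lower()
            if name in cls._registry:
                raise ValueError(f"{name!r} is already registered")
            cls._registry[name] = sub
            return sub  # the class itself is returned unchanged
        return decorator

    @classmethod
    def create(cls, key: str, /, **kw: Any) -> Any:
        # Unknown key -> KeyError; kwargs the constructor rejects -> TypeError
        return cls._registry[key](**kw)


class Foo(Factory):
    pass


@Foo.register("bar")
@dataclass
class Bar(Foo):
    value: int = 0


@Foo.register()  # key defaults to "defaultcomponent"
class DefaultComponent(Foo):
    pass


print(Foo.create("bar", value=42))  # Bar(value=42)
```

Giving each direct Factory subclass its own dictionary keeps, for example, writer keys separate from keys registered under an unrelated factory.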

class jaxdem.writer.VTKWriter(writers: List[str] | None = None, directory: str | Path = 'frames', binary: bool = True, clean: bool = True, _counter: int = 0)

Bases: object

High-level front-end for writing simulation data to VTK files.

This class orchestrates the process of converting JAX-based State and System PyTrees into VTK files, handling batching, trajectories, and dispatching to concrete VTKBaseWriter subclasses.

How leading axes are interpreted

Let state.pos.shape == (…, N, dim), where N is the number of particles and dim is the spatial dimension (2 or 3), and let L = state.pos.ndim - 2 be the number of remaining leading axes.

  1. If L == 0: The input represents a single snapshot. All writers directly process it.

  2. If L >= 1 and trajectory is False (the default of save()):
     - Axis 0 is treated as a batch dimension.
     - Axes 1 through L-1 are treated as trajectory dimensions within each batch.
     Each slice along axis 0 (state.pos[b, …]) is processed recursively as a separate batch, with its remaining leading axes treated as a trajectory. A separate subdirectory (e.g., batch_00000000/) is created for each batch.

  3. If L >= 1 and trajectory is True:
     - All leading axes (axes 0 through L-1) are treated as trajectory dimensions.
     This suits cases such as a “trajectory of trajectories” (e.g., from Monte Carlo runs) or data whose first leading axis is explicitly time.

Inside each batch directory (or the main directory for non-batched trajectories), every trajectory step becomes one or more VTK files per concrete writer (e.g., spheres_00000042.vtp, domain_00000042.vtp).
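The three cases above can be made concrete by enumerating the files a single save() call would produce. The sketch below is hypothetical (plan_frames is not a jaxdem function) and rests on assumptions drawn from the examples in this section: writer keys "spheres" and "domain", an 8-digit zero-padded counter, the .vtp extension, batch_XXXXXXXX/ subdirectories, one counter value per snapshot shared by all writers, a counter that keeps increasing across batches, and row-major traversal of the trajectory axes.

```python
import math
from typing import List, Sequence, Tuple


def plan_frames(lead_shape: Sequence[int], trajectory: bool,
                writers: Sequence[str] = ("spheres", "domain"),
                counter: int = 0) -> Tuple[List[str], int]:
    """Hypothetical helper: list the relative paths save() would produce, plus the new counter.

    lead_shape is state.pos.shape[:-2], i.e. the L leading axes.
    """
    if len(lead_shape) == 0:
        groups = [("", 1)]  # L == 0: a single snapshot in the base directory
    elif trajectory:
        # every leading axis is a trajectory axis: flatten them into one frame sequence
        groups = [("", math.prod(lead_shape))]
    else:
        # axis 0 is a batch axis; the remaining axes form each batch's trajectory
        n_steps = math.prod(lead_shape[1:])  # the product of an empty sequence is 1
        groups = [(f"batch_{b:08d}/", n_steps) for b in range(lead_shape[0])]

    paths: List[str] = []
    for prefix, n_frames in groups:
        for _ in range(n_frames):
            for w in writers:
                paths.append(f"{prefix}{w}_{counter:08d}.vtp")
            counter += 1  # one counter value per snapshot, shared by all writers
    return paths, counter
```

For example, plan_frames((2, 3), False) yields two batch directories with three frames each and returns a counter of 6, while plan_frames((2, 3), True) places all six frames in the base directory.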

Requirements on system

The System object may share leading axes with state or be broadcastable against them (e.g., a single scalar dt shared by all particles, batches, and frames). During recursive processing, every array leaf of system whose first-axis length matches the size of the current leading axis of state (lead) is sliced together with the corresponding state slice. As a result, each individual writer receives a matching per-snapshot State and System pair.
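The length-matching rule can be sketched with a flat dictionary standing in for the System PyTree. slice_system is a hypothetical helper, and plain Python lists and floats stand in for array leaves and scalars; it is not the jaxdem implementation.

```python
from typing import Any, Dict


def slice_system(system: Dict[str, Any], lead: int, i: int) -> Dict[str, Any]:
    """Hypothetical helper: select snapshot i from a system whose leaves may carry the leading axis.

    Leaves are plain sequences or scalars standing in for arrays.
    """
    out: Dict[str, Any] = {}
    for name, leaf in system.items():
        if isinstance(leaf, (list, tuple)) and len(leaf) == lead:
            out[name] = leaf[i]  # leaf shares the leading axis: slice it together with state
        else:
            out[name] = leaf  # broadcastable leaf (e.g. a scalar dt): pass it through unchanged
    return out


system = {"dt": 0.01, "box_size": [[1.0, 1.0], [2.0, 2.0]]}
print(slice_system(system, lead=2, i=1))  # {'dt': 0.01, 'box_size': [2.0, 2.0]}
```

A rule based only on length can misfire in this sketch: an unbatched per-dimension leaf (say, a 2D box_size of length 2) would be sliced by mistake whenever the current leading axis also has size 2, so shapes are worth checking when a batch size coincides with dim.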

Notes

  • All I/O operations (file writing) are executed in a single ThreadPoolExecutor managed by the VTKWriter instance, allowing for concurrent file writes.

  • A global counter (_counter) is incremented before a snapshot is submitted to the thread pool. This guarantees unique file names for all output files, even when threads finish out of order due to background execution.

  • VTKWriter itself is not a JAX PyTree (it’s a standard Python dataclass) and therefore never appears inside jax.jit or jax.vmap transforms; it operates purely on the Python side.
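The counter guarantee in the notes above comes from reserving each file index on the calling thread before the job is handed to the pool, so the order in which workers finish cannot change any file name. The class below is a hypothetical illustration of that pattern using the standard library's ThreadPoolExecutor, not VTKWriter's actual code; the sleep stands in for file I/O.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List


class CounterSketch:
    """Hypothetical illustration (not VTKWriter's code): reserve file indices before submitting work."""

    def __init__(self, max_workers: int = 4) -> None:
        self._counter = 0
        self._pool = ThreadPoolExecutor(max_workers=max_workers)  # one pool per instance
        self._lock = threading.Lock()
        self.written: List[str] = []

    def _write(self, name: str, delay: float) -> None:
        time.sleep(delay)  # simulate I/O of varying duration
        with self._lock:
            self.written.append(name)

    def save(self, n_snapshots: int) -> int:
        futures = []
        for k in range(n_snapshots):
            name = f"spheres_{self._counter:08d}.vtp"
            self._counter += 1  # bumped on the calling thread, before the job is submitted
            # later snapshots sleep less, so they may well finish first
            futures.append(self._pool.submit(self._write, name, 0.005 * (n_snapshots - k)))
        for f in futures:
            f.result()  # block until every write has finished (re-raising any error)
        return self._counter

    def close(self) -> None:
        self._pool.shutdown(wait=True)
```

Since save() waits on every future before returning, the value it returns always matches the number of files already on disk, mirroring the blocking behaviour described for VTKWriter.save.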

writers: List[str] | None

A list of strings specifying which registered VTKBaseWriter subclasses should be used for writing. If None, all available VTKBaseWriter subclasses will be used.

directory: str | Path

The base directory where output VTK files are saved. Batched outputs are written to subdirectories of this path. Defaults to “frames”.

binary: bool

If True, VTK files will be written in binary format. If False, files will be written in ASCII format. Defaults to True.

clean: bool

If True, the directory will be completely emptied before any files are written. Defaults to True. This is useful for starting a fresh set of output frames.

save(state: State, system: System, *, trajectory: bool = False) → int

Writes state and system to VTK files.

This is the main public method to trigger saving data. It handles the interpretation of leading axes (batch vs. trajectory) and dispatches the write jobs to a background thread pool. The method blocks until all writing operations are completed and files are on disk.

Parameters:
  • state (State) – The simulation State object to be saved.

  • system (System) – The simulation System object corresponding to state. Its leading dimensions should match those of state or be broadcastable against them.

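  • trajectory (bool, optional) – Controls how the leading axes of state are interpreted. If False (default), axis 0 is treated as a batch axis and any remaining leading axes as trajectory axes, with one subdirectory per batch. If True, all leading axes are treated as trajectory axes and the frames are written to the main directory without batch subdirectories.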

Returns:

The new value of the global counter after all snapshots (including all batches and trajectory steps) have been written. This counter represents the total number of frames written so far by this writer instance.

Return type:

int

class jaxdem.writer.SpheresWriter

Bases: VTKBaseWriter

A VTKBaseWriter implementation that writes particle centers as VTK points and attaches per-particle State fields as PointData attributes.

For each particle, its position is treated as a point. Other relevant per-particle fields from the State object (like vel, rad, mass, etc.) are added as attributes to these points in the VTK file.

Notes

  • Particle positions are padded to 3D if they are originally 2D, as required by VTK.

  • Only 1D scalar fields (like rad, mass) and 2D/3D vector fields (like vel, accel) are included as PointData. Higher-rank fields or non-array fields are ignored.

  • Boolean arrays (e.g., fixed) are converted to int8 before being passed to VTK.
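The three notes above can be sketched without VTK: pad positions to three components, keep only per-particle scalars and 2- or 3-component vectors, and turn booleans into small integers. pad_to_3d and point_data are hypothetical helpers; padding with z = 0.0, restricting vectors to exactly 2 or 3 components, and using Python ints in place of an int8 cast are assumptions of this sketch, and nested lists stand in for arrays.

```python
from typing import Any, Dict, List, Sequence, Tuple


def pad_to_3d(points: Sequence[Sequence[float]]) -> List[Tuple[float, ...]]:
    """Append z = 0.0 to 2D points; 3D points pass through unchanged."""
    return [tuple(p) + (0.0,) * (3 - len(p)) for p in points]


def point_data(fields: Dict[str, Any], n: int) -> Dict[str, list]:
    """Keep per-particle scalars and 2D/3D vectors; cast booleans to 0/1 (stand-in for int8)."""
    out: Dict[str, list] = {}
    for name, values in fields.items():
        if not isinstance(values, (list, tuple)) or n == 0 or len(values) != n:
            continue  # not a per-particle array: skip it
        first = values[0]
        if isinstance(first, (bool, int, float)):
            # one scalar per particle; booleans become 0/1 like an int8 cast
            out[name] = [int(v) if isinstance(v, bool) else v for v in values]
        elif (isinstance(first, (list, tuple)) and len(first) in (2, 3)
              and all(isinstance(c, (int, float)) for c in first)):
            out[name] = [tuple(v) for v in values]  # one 2D/3D vector per particle
        # anything else (higher rank, non-numeric) is ignored
    return out
```

With real arrays the same decisions would be made from ndim and shape rather than by inspecting the first element.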

The classmethods create and register are inherited unchanged from VTKBaseWriter and are documented there.

classmethod write(state: State, system: System, counter: int, directory: Path | str, binary: bool) → int

Writes particle data from a single snapshot to a VTK PolyData (.vtp) file.

The file will contain a set of points representing particle centers, and each particle’s relevant properties from the State will be saved as point attributes.

Parameters:
  • state (State) – The simulation state snapshot (NumPy-converted, not JAX arrays).

  • system (System) – The simulation system configuration (NumPy-converted). Not used by SpheresWriter; it is accepted only to match the base-class signature.

  • counter (int) – The unique integer identifier for this snapshot.

  • directory (pathlib.Path or str) – The target directory for the output file.

  • binary (bool) – If True, writes in binary mode; False for ASCII.

Returns:

The incremented counter (counter + 1).

Return type:

int

class jaxdem.writer.DomainWriter

Bases: VTKBaseWriter

A VTKBaseWriter implementation that writes the simulation domain as a VTK geometric primitive.

The domain is represented as an axis-aligned cuboid (for 3D simulations) or a rectangle (for 2D simulations). Its position and extent are determined by system.domain.anchor (the lower corner) and system.domain.box_size.

Notes

  • The domain’s dimensions from system.domain.box_size are automatically padded to 3D if the simulation is 2D, as required by VTK’s vtkCubeSource.

  • The center of the VTK cube/rectangle is set to anchor + 0.5 * box_size.
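The geometry in the notes above reduces to two small computations: pad anchor and box_size to three components, then take center = anchor + 0.5 * box_size. cube_params is a hypothetical helper, and padding the missing z component with 0.0 (which gives a flat rectangle in 3D) is an assumption. With VTK's Python bindings, the results would presumably be passed to vtkCubeSource through SetCenter and SetXLength, SetYLength, SetZLength.

```python
from typing import Sequence, Tuple

Vec3 = Tuple[float, ...]


def _pad3(v: Sequence[float]) -> Vec3:
    """Pad a 2D vector to 3D with a zero z component (assumed padding value)."""
    return tuple(float(x) for x in v) + (0.0,) * (3 - len(v))


def cube_params(anchor: Sequence[float], box_size: Sequence[float]) -> Tuple[Vec3, Vec3]:
    """Return (center, edge lengths) of the axis-aligned domain box."""
    a, lengths = _pad3(anchor), _pad3(box_size)
    center = tuple(ai + 0.5 * li for ai, li in zip(a, lengths))  # anchor + 0.5 * box_size
    return center, lengths


print(cube_params([0.0, 0.0], [2.0, 4.0]))  # ((1.0, 2.0, 0.0), (2.0, 4.0, 0.0))
```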

The classmethods create and register are inherited unchanged from VTKBaseWriter and are documented there.

classmethod write(state: State, system: System, counter: int, directory: Path | str, binary: bool) → int

Writes the simulation domain geometry to a VTK PolyData (.vtp) file.

The domain is represented as a vtkCubeSource, automatically adjusting for 2D or 3D simulation dimensions.

Parameters:
  • state (State) – The simulation state snapshot (NumPy-converted). Not used by DomainWriter; it is accepted only to match the base-class signature.

  • system (System) – The simulation system configuration (NumPy-converted), providing domain.anchor and domain.box_size.

  • counter (int) – The unique integer identifier for this snapshot.

  • directory (pathlib.Path or str) – The target directory for the output file.

  • binary (bool) – If True, writes in binary mode; False for ASCII.

Returns:

The incremented counter (counter + 1).

Return type:

int