jaxdem#

JaxDEM module

class jaxdem.State(pos: Array, vel: Array, accel: Array, rad: Array, mass: Array, ID: Array, mat_id: Array, species_id: Array, fixed: Array)[source]#

Bases: object

Represents the complete simulation state for a system of N particles in 2D or 3D.

Notes

State is designed to support various data layouts:

  • Single snapshot:

    pos.shape = (N, dim) for particle properties (e.g., pos, vel, accel), and (N,) for scalar properties (e.g., rad, mass). In this case, batch_size is 1.

  • Batched states:

    pos.shape = (B, N, dim) for particle properties, and (B, N) for scalar properties. Here, B is the batch dimension (batch_size = pos.shape[0]).

  • Trajectories of a single simulation:

    pos.shape = (T, N, dim) for particle properties, and (T, N) for scalar properties. Here, T is the trajectory dimension.

  • Trajectories of batched states:

    pos.shape = (B, T_1, T_2, …, T_k, N, dim) for particle properties, and (B, T_1, T_2, …, T_k, N) for scalar properties.

    • The first dimension (i.e., pos.shape[0]) is always interpreted as the batch dimension (B).

    • All remaining leading dimensions (T_1, T_2, …, T_k) are interpreted as trajectory dimensions. They are flattened at save time if there is more than one trajectory dimension.
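
The layout rules above can be illustrated with a small shape-only helper. This is a sketch in plain Python, not part of jaxdem: the name describe_layout and the returned dictionary are hypothetical. It follows the note that the first leading axis is the batch axis; a bare (T, N, dim) trajectory cannot be told apart from (B, N, dim) by shape alone, so it is reported as a batch here.

```python
from math import prod


def describe_layout(pos_shape):
    """Split a position shape (..., N, dim) into its named parts.

    Hypothetical helper: the first leading axis is read as the batch
    axis and any further leading axes as trajectory axes.
    """
    if len(pos_shape) < 2:
        raise ValueError("pos must end with the (N, dim) axes")
    *leading, n, dim = pos_shape
    batch_size = leading[0] if leading else 1
    traj_dims = tuple(leading[1:])
    return {
        "batch_size": batch_size,
        "traj_dims": traj_dims,
        # Frame count once several trajectory axes are flattened into one.
        "n_frames": prod(traj_dims),
        "N": n,
        "dim": dim,
    }


single = describe_layout((4, 2))            # single snapshot
batched = describe_layout((10, 4, 2))       # B = 10
rollouts = describe_layout((3, 5, 6, 4, 2)) # B = 3, T_1 = 5, T_2 = 6
```

Since math.prod of an empty tuple is 1, a state without trajectory axes reports one frame.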

The class is final and cannot be subclassed.

Example

Creating a simple 2D state for 4 particles:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>> import jax
>>>
>>> positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
>>> state = jdem.State.create(pos=positions)
>>>
>>> print(f"Number of particles (N): {state.N}")
>>> print(f"Spatial dimension (dim): {state.dim}")
>>> print(f"Positions: {state.pos}")

Creating a batched state:

>>> batched_state = jax.vmap(lambda _: jdem.State.create(pos=positions))(jnp.arange(10))
>>>
>>> print(f"Batch size: {batched_state.batch_size}")  # 10
>>> print(f"Positions shape: {batched_state.pos.shape}")
pos: Array#

Array of particle positions. Shape is (…, N, dim).

vel: Array#

Array of particle velocities. Shape is (…, N, dim).

accel: Array#

Array of particle accelerations. Shape is (…, N, dim).

rad: Array#

Array of particle radii. Shape is (…, N).

mass: Array#

Array of particle masses. Shape is (…, N).

ID: Array#

Array of unique particle identifiers. Shape is (…, N).

mat_id: Array#

Array of material IDs for each particle. Shape is (…, N).

species_id: Array#

Array of species IDs for each particle. Shape is (…, N).

fixed: Array#

Boolean array indicating if a particle is fixed (immobile). Shape is (…, N).

property N: int[source]#

Number of particles in the state.

static add(state: State, pos: Array | ndarray | bool | number | bool | int | float | complex, *, vel: Array | ndarray | bool | number | bool | int | float | complex | None = None, accel: Array | ndarray | bool | number | bool | int | float | complex | None = None, rad: Array | ndarray | bool | number | bool | int | float | complex | None = None, mass: Array | ndarray | bool | number | bool | int | float | complex | None = None, ID: Array | ndarray | bool | number | bool | int | float | complex | None = None, mat_id: Array | ndarray | bool | number | bool | int | float | complex | None = None, species_id: Array | ndarray | bool | number | bool | int | float | complex | None = None, fixed: Array | ndarray | bool | number | bool | int | float | complex | None = None) State[source][source]#

Adds new particles to an existing State instance, returning a new State.

Parameters:
  • state (State) – The existing State to which particles will be added.

  • pos (jax.typing.ArrayLike) – Positions of the new particle(s). Shape (…, N_new, dim).

  • vel (jax.typing.ArrayLike or None, optional) – Velocities of the new particle(s). Defaults to zeros.

  • accel (jax.typing.ArrayLike or None, optional) – Accelerations of the new particle(s). Defaults to zeros.

  • rad (jax.typing.ArrayLike or None, optional) – Radii of the new particle(s). Defaults to ones.

  • mass (jax.typing.ArrayLike or None, optional) – Masses of the new particle(s). Defaults to ones.

  • ID (jax.typing.ArrayLike or None, optional) – IDs of the new particle(s). If None, new IDs are generated.

  • mat_id (jax.typing.ArrayLike or None, optional) – Material IDs of the new particle(s). Defaults to zeros.

  • species_id (jax.typing.ArrayLike or None, optional) – Species IDs of the new particle(s). Defaults to zeros.

  • fixed (jax.typing.ArrayLike or None, optional) – Fixed status of the new particle(s). Defaults to all False.

Returns:

A new State instance containing all particles from the original state plus the newly added particles.

Return type:

State

Raises:
  • ValueError – If the created new particle state or the merged state is invalid.

  • AssertionError – If the batch size or spatial dimension of the new particles does not match the existing state.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Initial state with 4 particles
>>> state = jdem.State.create(pos=jnp.zeros((4, 2)))
>>> print(f"Original state N: {state.N}, IDs: {state.ID}")
>>>
>>> # Add a single new particle
>>> state_with_added_particle = jdem.State.add(
...     state,
...     pos=jnp.array([[10.0, 10.0]]),
...     rad=jnp.array([0.5]),
...     mass=jnp.array([2.0]),
... )
>>> print(f"New state N: {state_with_added_particle.N}, IDs: {state_with_added_particle.ID}")
>>> print(f"New particle position: {state_with_added_particle.pos[-1]}")
>>>
>>> # Add multiple new particles
>>> state_multiple_added = jdem.State.add(
...     state,
...     pos=jnp.array([[10.0, 10.0], [11.0, 11.0], [12.0, 12.0]]),
... )
>>> print(f"State with multiple added N: {state_multiple_added.N}, IDs: {state_multiple_added.ID}")
property batch_size: int[source]#

Return the batch size of the state.

static create(pos: Array | ndarray | bool | number | bool | int | float | complex, *, vel: Array | ndarray | bool | number | bool | int | float | complex | None = None, accel: Array | ndarray | bool | number | bool | int | float | complex | None = None, rad: Array | ndarray | bool | number | bool | int | float | complex | None = None, mass: Array | ndarray | bool | number | bool | int | float | complex | None = None, ID: Array | ndarray | bool | number | bool | int | float | complex | None = None, mat_id: Array | ndarray | bool | number | bool | int | float | complex | None = None, species_id: Array | ndarray | bool | number | bool | int | float | complex | None = None, fixed: Array | ndarray | bool | number | bool | int | float | complex | None = None) State[source][source]#

Factory method to create a new State instance.

This method handles default values and ensures consistent array shapes for all state attributes.

Parameters:
  • pos (jax.typing.ArrayLike) – Initial positions of particles. Expected shape: (…, N, dim).

  • vel (jax.typing.ArrayLike or None, optional) – Initial velocities of particles. If None, defaults to zeros. Expected shape: (…, N, dim).

  • accel (jax.typing.ArrayLike or None, optional) – Initial accelerations of particles. If None, defaults to zeros. Expected shape: (…, N, dim).

  • rad (jax.typing.ArrayLike or None, optional) – Radii of particles. If None, defaults to ones. Expected shape: (…, N).

  • mass (jax.typing.ArrayLike or None, optional) – Masses of particles. If None, defaults to ones. Expected shape: (…, N).

  • ID (jax.typing.ArrayLike or None, optional) – Unique identifiers for particles. If None, defaults to jnp.arange(). Expected shape: (…, N).

  • mat_id (jax.typing.ArrayLike or None, optional) – Material IDs for particles. If None, defaults to zeros. Expected shape: (…, N).

  • species_id (jax.typing.ArrayLike or None, optional) – Species IDs for particles. If None, defaults to zeros. Expected shape: (…, N).

  • fixed (jax.typing.ArrayLike or None, optional) – Boolean array indicating fixed particles. If None, defaults to all False. Expected shape: (…, N).

Returns:

A new State instance with all attributes correctly initialized and shaped.

Return type:

State

Raises:

ValueError – If the created State is not valid.

Example

Creating a 3D state for 5 particles:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> my_pos = jnp.array([[0.,0.,0.], [1.,0.,0.], [0.,1.,0.], [0.,0.,1.], [1.,1.,1.]])
>>> my_rad = jnp.array([0.5, 0.5, 0.5, 0.5, 0.5])
>>> my_mass = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0])
>>>
>>> state_5_particles = jdem.State.create(pos=my_pos, rad=my_rad, mass=my_mass)
>>> print(f"Shape of positions: {state_5_particles.pos.shape}")
>>> print(f"Radii: {state_5_particles.rad}")
property dim: int[source]#

Spatial dimension of the simulation.

property is_valid: bool[source]#

Check if the internal representation of the State is consistent.

Verifies that:

  • The spatial dimension (dim) is either 2 or 3.

  • All position-like arrays (pos, vel, accel) have the same shape.

  • All scalar-per-particle arrays (rad, mass, ID, mat_id, species_id, fixed) have a shape consistent with pos.shape[:-1].

Raises:

AssertionError – If any shape inconsistency is found.
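
The three checks listed above can be sketched on plain shape tuples. This is an illustration only: shapes_are_consistent is a hypothetical helper that returns a bool, whereas the property documented here is described as raising AssertionError on inconsistencies.

```python
def shapes_are_consistent(vector_shapes, scalar_shapes):
    """Return True if the shapes pass the three documented checks.

    vector_shapes: dict of name -> shape for pos, vel, accel
    scalar_shapes: dict of name -> shape for rad, mass, ID, ...
    """
    pos_shape = vector_shapes["pos"]
    # 1. The spatial dimension must be 2 or 3.
    if len(pos_shape) < 2 or pos_shape[-1] not in (2, 3):
        return False
    # 2. All position-like arrays share the shape of pos.
    if any(shape != pos_shape for shape in vector_shapes.values()):
        return False
    # 3. Per-particle scalars match pos.shape[:-1].
    return all(shape == pos_shape[:-1] for shape in scalar_shapes.values())


good = shapes_are_consistent(
    {"pos": (10, 4, 2), "vel": (10, 4, 2), "accel": (10, 4, 2)},
    {"rad": (10, 4), "mass": (10, 4), "fixed": (10, 4)},
)
bad_scalar = shapes_are_consistent(
    {"pos": (4, 3), "vel": (4, 3), "accel": (4, 3)},
    {"rad": (5,)},
)
```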

static merge(state1: State, state2: State) State[source][source]#

Merges two State instances into a single new State.

This method concatenates the particles from state2 onto state1. Particle IDs in state2 are shifted to ensure uniqueness in the merged state.

Parameters:
  • state1 (State) – The first State instance. Its particles will appear first in the merged state.

  • state2 (State) – The second State instance. Its particles will be appended to the first.

Returns:

A new State instance containing all particles from both input states.

Return type:

State

Raises:
  • AssertionError – If either input state is invalid, or if there is a mismatch in spatial dimension (dim) or batch size (batch_size).

  • ValueError – If the resulting merged state is invalid.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state_a = jdem.State.create(pos=jnp.array([[0.0, 0.0], [1.0, 1.0]]), ID=jnp.array([0, 1]))
>>> state_b = jdem.State.create(pos=jnp.array([[2.0, 2.0], [3.0, 3.0]]), ID=jnp.array([0, 1]))
>>> merged_state = jdem.State.merge(state_a, state_b)
>>>
>>> print(f"Merged state N: {merged_state.N}")  # Expected: 4
>>> print(f"Merged state positions:\n{merged_state.pos}")
>>> print(f"Merged state IDs: {merged_state.ID}")  # Expected: [0, 1, 2, 3]
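
The ID shift described for merge can be pictured with plain lists. This is a minimal sketch under one assumption: the offset is taken to be max(ids1) + 1, which reproduces the expected IDs in the example above, but the rule the library actually applies may differ. The function merge_ids is hypothetical.

```python
def merge_ids(ids1, ids2):
    """Concatenate two ID lists, offsetting the second to avoid clashes.

    Assumption: the offset is max(ids1) + 1 (not confirmed by the docs).
    """
    offset = max(ids1) + 1 if ids1 else 0
    return list(ids1) + [i + offset for i in ids2]


merged = merge_ids([0, 1], [0, 1])
sparse = merge_ids([0, 5], [0, 1, 2])
```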
static stack(states: Sequence[State]) State[source][source]#

Concatenates a sequence of State snapshots into a trajectory along axis 0.

This method is useful for collecting simulation snapshots over time into a single State object where the leading dimension represents time.

Parameters:

states (Sequence[State]) – A sequence (e.g., list, tuple) of State instances to be stacked.

Returns:

A new State instance where each attribute is a JAX array with an additional leading dimension representing the stacked trajectory. For example, if input pos was (N, dim), output pos will be (T, N, dim).

Return type:

State

Raises:
  • ValueError – If the input states sequence is empty, or if the stacked State is invalid.

  • AssertionError – If any input state is invalid, or if there is a mismatch in spatial dimension (dim), batch size (batch_size), or number of particles (N) between the states in the sequence.

Notes

  • No ID shifting is performed because the leading axis represents time (or another batch dimension), not new particles.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create a sequence of 3 simple 2D snapshots
>>> snapshot1 = jdem.State.create(pos=jnp.array([[0.,0.], [1.,1.]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot2 = jdem.State.create(pos=jnp.array([[0.1,0.], [1.,1.1]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot3 = jdem.State.create(pos=jnp.array([[0.2,0.], [1.,1.2]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>>
>>> trajectory_state = jdem.State.stack([snapshot1, snapshot2, snapshot3])
>>>
>>> print(f"Trajectory positions shape: {trajectory_state.pos.shape}") # Expected: (3, 2, 2)
>>> print(f"Positions at time step 0:\n{trajectory_state.pos[0]}")
>>> print(f"Positions at time step 1:\n{trajectory_state.pos[1]}")
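
The shape bookkeeping behind stacking can be sketched without JAX. The helper below is hypothetical and works on nested lists: it checks that every snapshot has the same N and dim, then adds a leading time axis, mirroring the documented errors for empty input and mismatched snapshots.

```python
def stack_snapshots(snapshots):
    """Stack position lists of shape (N, dim) into one of shape (T, N, dim)."""
    if not snapshots:
        raise ValueError("need at least one snapshot")
    n, dim = len(snapshots[0]), len(snapshots[0][0])
    for snap in snapshots:
        # Every snapshot must agree on particle count and spatial dimension.
        assert len(snap) == n, "mismatch in number of particles"
        assert all(len(p) == dim for p in snap), "mismatch in spatial dimension"
    # Copy so the stacked result does not alias the inputs.
    return [[list(p) for p in snap] for snap in snapshots]


traj = stack_snapshots([
    [[0.0, 0.0], [1.0, 1.0]],
    [[0.1, 0.0], [1.0, 1.1]],
    [[0.2, 0.0], [1.0, 1.2]],
])
traj_shape = (len(traj), len(traj[0]), len(traj[0][0]))
```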
class jaxdem.System(integrator: ~jaxdem.integrators.Integrator, collider: ~jaxdem.colliders.Collider, domain: ~jaxdem.domains.Domain, force_model: ~jaxdem.forces.ForceModel, mat_table: ~jaxdem.materials.materialTable.MaterialTable, dt: ~jax.Array, time: ~jax.Array, dim: ~jax.Array, step_count: ~jax.Array = <factory>)[source]#

Bases: object

Encapsulates the entire simulation configuration.

Notes

  • The System object is designed to be JIT-compiled for efficient execution.

  • The System class is a frozen dataclass, meaning its attributes cannot be changed after instantiation.
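
Because a frozen dataclass cannot be modified in place, updates are expressed by building a new instance. The sketch below uses a hypothetical MiniSystem stand-in with three of the documented field names; how jaxdem itself advances time and step_count is not shown in this page and is assumed here for illustration.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class MiniSystem:
    """Hypothetical stand-in carrying three of the documented System fields."""
    dt: float
    time: float = 0.0
    step_count: int = 0


def advance(system):
    """Return a new MiniSystem one step later; the input is left untouched."""
    return replace(
        system,
        time=system.time + system.dt,
        step_count=system.step_count + 1,
    )


s0 = MiniSystem(dt=0.5)
s1 = advance(s0)
```

Attempting an assignment such as s0.dt = 0.1 raises dataclasses.FrozenInstanceError, which is why the update goes through dataclasses.replace.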

Example

Creating a basic 2D simulation system:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create a System instance
>>> sim_system = jdem.System.create(
...     dim=2,
...     dt=0.001,
...     integrator_type="euler",
...     collider_type="naive",
...     domain_type="free",
...     force_model_type="spring",
...     # You can pass keyword arguments to component constructors via '_kw' dicts
...     domain_kw=dict(box_size=jnp.array([5.0, 5.0]), anchor=jnp.array([0.0, 0.0])),
... )
>>>
>>> print(f"System integrator: {sim_system.integrator.__class__.__name__}")
>>> print(f"System force model: {sim_system.force_model.__class__.__name__}")
>>> print(f"Domain box size: {sim_system.domain.box_size}")
integrator: Integrator#

Instance of jaxdem.Integrator that advances the simulation state in time.

collider: Collider#

Instance of jaxdem.Collider that performs contact detection and computes inter-particle forces and potential energies.

domain: Domain#

Instance of jaxdem.Domain that defines the simulation boundaries, displacement rules, and boundary conditions.

force_model: ForceModel#

Instance of jaxdem.ForceModel that defines the physical laws for inter-particle interactions.

mat_table: MaterialTable#

Instance of jaxdem.MaterialTable holding material properties and pairwise interaction parameters.

dt: Array#

The global simulation time step \(\Delta t\).

time: Array#

Elapsed simulation time.

dim: Array#

Spatial dimension of the system.

step_count: Array#

Number of integration steps that have been performed.

static create(dim: int, *, dt: float = 0.005, time: float = 0.0, integrator_type: str = 'euler', collider_type: str = 'naive', domain_type: str = 'free', force_model_type: str = 'spring', mat_table: MaterialTable | None = None, integrator_kw: Dict[str, Any] | None = None, collider_kw: Dict[str, Any] | None = None, domain_kw: Dict[str, Any] | None = None, force_model_kw: Dict[str, Any] | None = None) System[source][source]#

Factory method to create a System instance with specified components.

Parameters:
  • dim (int) – The spatial dimension of the simulation (2 or 3).

  • dt (float, optional) – The global simulation time step.

  • integrator_type (str, optional) – The registered type string for the jaxdem.Integrator to use.

  • collider_type (str, optional) – The registered type string for the jaxdem.Collider to use.

  • domain_type (str, optional) – The registered type string for the jaxdem.Domain to use.

  • force_model_type (str, optional) – The registered type string for the jaxdem.ForceModel to use.

  • mat_table (MaterialTable or None, optional) – An optional pre-configured jaxdem.MaterialTable. If None, a default jaxdem.MaterialTable will be created with one generic elastic material and “harmonic” jaxdem.MaterialMatchmaker.

  • integrator_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Integrator type.

  • collider_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Collider type.

  • domain_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Domain type.

  • force_model_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected ForceModel type.

Returns:

A fully configured System instance ready for simulation.

Return type:

System

Raises:
  • KeyError – If a specified *_type is not registered in its respective factory, or if the mat_table is missing properties required by the force_model.

  • TypeError – If constructor keyword arguments are invalid for any component.

  • AssertionError – If the domain_kw ‘box_size’ or ‘anchor’ shapes do not match the dim.

Example

Creating a 3D system with reflective boundaries and a custom dt:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> system_reflect = jdem.System.create(
...     dim=3,
...     dt=0.0005,
...     domain_type="reflect",
...     domain_kw=dict(box_size=jnp.array([20.0, 20.0, 20.0]), anchor=jnp.array([-10.0, -10.0, -10.0])),
...     force_model_type="spring",
... )
>>> print(f"System dt: {system_reflect.dt}")
>>> print(f"Domain type: {system_reflect.domain.__class__.__name__}")

Creating a system with a pre-defined MaterialTable:

>>> custom_mat_kw = dict(young=2.0e5, poisson=0.25)
>>> custom_material = jdem.Material.create("custom_mat", **custom_mat_kw)
>>> custom_mat_table = jdem.MaterialTable.from_materials(
...     [custom_material], matcher=jdem.MaterialMatchmaker.create("linear")
... )
>>>
>>> system_custom_mat = jdem.System.create(
...     dim=2,
...     mat_table=custom_mat_table,
...     force_model_type="spring"
... )
static step(state: State, system: System, *, n: int = 1, batched: bool = False) Tuple['State', 'System'][source][source]#

Advances the simulation state by n time steps.

This method provides a convenient way to run multiple integration steps. For a single step (n=1), it directly calls the integrator’s step method. For multiple steps (n > 1), it uses an optimized internal loop based on jax.lax.scan to maintain JIT-compilation efficiency.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The current system configuration.

  • n (int, optional) – The number of integration steps to perform. Defaults to 1. This argument must be static.

Returns:

A tuple containing the final State and System after n integration steps.

Return type:

Tuple[State, System]

Raises:

ValueError – If n is non-positive. (Implicit by jax.lax.scan length check).

Example

>>> import jaxdem as jdem
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(dim=2, dt=0.01)
>>>
>>> # Advance by 1 step
>>> state_after_1_step, system_after_1_step = jdem.System.step(state, system)
>>> print("Position after 1 step:", state_after_1_step.pos[0])
>>>
>>> # Advance by 10 steps
>>> state_after_10_steps, system_after_10_steps = jdem.System.step(state, system, n=10)
>>> print("Position after 10 steps:", state_after_10_steps.pos[0])
static trajectory_rollout(state: State, system: System, *, n: int, stride: int = 1, batched: bool = False) Tuple['State', 'System', Tuple['State', 'System']][source][source]#

Rolls the system forward for a specified number of frames, collecting a trajectory.

This method performs n * stride total simulation steps, but it saves the State and System every stride steps, returning a trajectory of n snapshots. This is highly efficient for data collection within JAX as it leverages jax.lax.scan.

Parameters:
  • state (State) – The initial state of the simulation.

  • system (System) – The initial system configuration.

  • n (int) – The number of frames (snapshots) to collect in the trajectory. This argument must be static.

  • stride (int, optional) – The number of integration steps to advance between each collected frame. Defaults to 1, meaning every step is collected. This argument must be static.

Returns:

A tuple containing:

  • final_state: The State object at the end of the rollout.

  • final_system: The System object at the end of the rollout.

  • trajectory: A tuple of State and System objects, where each leaf array has an additional leading axis of size n representing the trajectory. The State and System objects within trajectory are structured as if created by jaxdem.State.stack() and a similar jaxdem.System stack.

Return type:

Tuple[State, System, Tuple[State, System]]

Raises:

ValueError – If n or stride are non-positive. (Implicit by jax.lax.scan length check.)

Example

>>> import jaxdem as jdem
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(dim=2, dt=0.01)
>>>
>>> # Roll out for 10 frames, saving every 5 steps
>>> final_state, final_system, traj = jdem.System.trajectory_rollout(
...     state, system, n=10, stride=5
... )
>>>
>>> print(f"Total simulation steps performed: {10 * 5}")
>>> print(f"Trajectory length (number of frames): {traj[0].pos.shape[0]}")  # traj[0] is the state part of the trajectory
>>> print(f"First frame position:\n{traj[0].pos[0]}")
>>> print(f"Last frame position:\n{traj[0].pos[-1]}")
>>> print(f"Final state position (should match last frame):\n{final_state.pos}")
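
The relationship between n, stride, and the collected frames can be sketched with an ordinary nested loop. This is a plain-Python illustration, not the jax.lax.scan implementation: the rollout function is hypothetical and uses a 1D free particle with an explicit Euler update as a stand-in for the integrator.

```python
def rollout(x, v, dt, n, stride=1):
    """Advance a 1D free particle n * stride steps, saving every stride steps."""
    if n <= 0 or stride <= 0:
        raise ValueError("n and stride must be positive")
    frames = []
    for _ in range(n):           # one iteration per saved frame
        for _ in range(stride):  # integration steps between frames
            x = x + v * dt       # explicit Euler step (stand-in integrator)
        frames.append(x)
    return x, frames


final_x, frames = rollout(0.0, 1.0, 0.5, n=3, stride=2)
```

Six steps are taken in total, three frames are stored, and the last frame coincides with the final position, matching the note in the example above.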

class jaxdem.VTKWriter(writers: ~typing.List[str] = <factory>, directory: ~pathlib.Path = PosixPath('frames'), binary: bool = True, clean: bool = True, save_every: int = 1, max_queue_size: int = 512, max_workers: int | None = None, _counter: int = 0, _writer_classes: ~typing.List = <factory>, _manifest: ~typing.Dict = <factory>)[source]#

Bases: object

High-level front end for writing simulation data to VTK files.

This class orchestrates the conversion of JAX-based jaxdem.State and jaxdem.System pytrees into VTK files, handling batches, trajectories, and dispatch to registered jaxdem.VTKBaseWriter subclasses.

How leading axes are interpreted#

Let particle positions have shape (..., N, dim), where N is the number of particles and dim is 2 or 3. Define L = state.pos.ndim - 2, i.e., the number of leading axes before (N, dim).

  • L == 0 — single snapshot

    The input is one frame. It is written directly into frames/batch_00000000/ (no batching, no trajectory).

  • trajectory=False (default)

    All leading axes are treated as batch axes (not time). If multiple batch axes are present, they are flattened into a single batch axis: (B, N, dim) with B = prod(shape[:L]). Each batch b is written as a single snapshot under its own subdirectory frames/batch_XXXXXXXX/. No trajectory is implied.

    • Example: (B, N, dim) → B separate directories with one frame each.

    • Example: (B1, B2, N, dim) → flatten to (B1*B2, N, dim) and treat as above.

  • trajectory=True

    The axis given by trajectory_axis is swapped to the front (axis 0) and interpreted as time T. Any remaining leading axes are batch axes. If more than one non-time leading axis exists, they are flattened into a single batch axis so the data becomes (T, B, N, dim) with B = prod(other leading axes).

    • If there is only time (L == 1): (T, N, dim). A single batch directory frames/batch_00000000/ contains a time series with T frames.

    • If there is time plus batching (L >= 2): (T, B, N, dim). Each batch b gets its own directory frames/batch_XXXXXXXX/ containing a time series (T frames) for that batch.

After these swaps/reshapes, dispatch is:

  • (N, dim) → single snapshot

  • (B, N, dim) → batches (no time)

  • (T, N, dim) → single batch with a trajectory

  • (T, B, N, dim) → per-batch trajectories

Concrete writers receive per-frame NumPy arrays; leaves in System are sliced/broadcast consistently with the current frame/batch.
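
The axis interpretation described above can be summarized as a shape-only function. This is a hedged sketch: plan_layout is a hypothetical helper, it assumes trajectory_axis is a non-negative index into the leading axes, and it only computes the resulting (T, B) sizes rather than moving any data.

```python
from math import prod


def plan_layout(pos_shape, trajectory=False, trajectory_axis=0):
    """Return (T, B) implied by a position shape (..., N, dim).

    T is None when no time axis is implied; B is None when there is
    no batch axis.
    """
    leading = list(pos_shape[:-2])
    if not leading:
        return None, None             # (N, dim): single snapshot
    if not trajectory:
        return None, prod(leading)    # all leading axes flattened into batch
    t = leading.pop(trajectory_axis)  # chosen axis is treated as time
    if not leading:
        return t, None                # (T, N, dim): one batch, T frames
    return t, prod(leading)           # (T, B, N, dim): per-batch trajectories


snapshot = plan_layout((4, 2))
batches = plan_layout((3, 5, 4, 2))
series = plan_layout((7, 4, 2), trajectory=True)
both = plan_layout((3, 7, 5, 4, 2), trajectory=True, trajectory_axis=1)
```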

writers: List[str]#

A list of strings specifying which registered VTKBaseWriter subclasses should be used for writing. If None, all available VTKBaseWriter subclasses will be used.

directory: Path#

The base directory where output VTK files will be saved. Subdirectories might be created within this path for batched outputs. Defaults to “frames”.

binary: bool#

If True, VTK files will be written in binary format. If False, files will be written in ASCII format. Defaults to True.

clean: bool#

If True, the directory will be completely emptied before any files are written. Defaults to True. This is useful for starting a fresh set of output frames.

save_every: int#

How often to write; writes on every save_every-th call to save().

max_queue_size: int#

The maximum number of scheduled writes allowed. 0 means unbounded.

max_workers: int | None#

Maximum number of worker threads for the internal thread pool.

block_until_ready()[source][source]#

Wait until all scheduled writer tasks complete.

This will wait for all pending futures, propagate exceptions (if any), and clear the pending set.

close()[source][source]#

Flush all pending tasks and shut down the internal thread pool. Safe to call multiple times.

save(state: State, system: System, *, trajectory: bool = False, trajectory_axis: int = 0, batch0: int = 0)[source][source]#

Schedule writing of a jaxdem.State / jaxdem.System pair to VTK files.

This public entry point interprets leading axes (batch vs. trajectory), performs any required axis swapping and flattening, and then writes the resulting frames using the registered writers. It also creates per-batch ParaView .pvd collections referencing the generated files.

Parameters:
  • state (State) – The simulation jaxdem.State object to be saved. Its array leaves must end with (N, dim).

  • system (System) – The jaxdem.System object corresponding to state. Leading axes must be compatible (or broadcastable) with those of state.

  • trajectory (bool, optional) – If True, interpret trajectory_axis as time and write a trajectory; if False, interpret the leading axis as batch.

  • trajectory_axis (int, optional) – The axis in state/system to treat as the trajectory (time) axis when trajectory=True. This axis is swapped to the front prior to writing.

  • batch0 (int) – Initial batch index from which batch numbering starts.

class jaxdem.VTKBaseWriter[source]#

Bases: Factory, ABC

Abstract base class for writers that output simulation data.

Concrete subclasses implement the write method to specify how a given snapshot (jaxdem.State, jaxdem.System pair) is converted into a specific file format.

Example

To define a custom VTK writer, inherit from VTKBaseWriter and implement its abstract methods:

>>> from dataclasses import dataclass
>>> from jaxdem import VTKBaseWriter
>>>
>>> @VTKBaseWriter.register("my_custom_vtk_writer")
... @dataclass(slots=True, frozen=True)
... class MyCustomVTKWriter(VTKBaseWriter):
...     ...
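
The register decorator used above follows a common name-to-class registry pattern. The sketch below shows that pattern in isolation; MiniFactory and EchoWriter are hypothetical and are not jaxdem's Factory, whose internals are not documented on this page.

```python
class MiniFactory:
    """Hypothetical name -> class registry (not jaxdem's Factory)."""
    _registry = {}

    @classmethod
    def register(cls, key):
        """Return a class decorator that records the class under key."""
        def decorator(sub):
            if key in cls._registry:
                raise ValueError(f"{key!r} is already registered")
            cls._registry[key] = sub
            return sub
        return decorator

    @classmethod
    def create(cls, key, **kw):
        """Instantiate the class registered under key with keyword arguments."""
        try:
            sub = cls._registry[key]
        except KeyError:
            raise KeyError(f"unknown type {key!r}") from None
        return sub(**kw)


@MiniFactory.register("echo")
class EchoWriter(MiniFactory):
    def __init__(self, prefix=""):
        self.prefix = prefix


writer = MiniFactory.create("echo", prefix="frame_")
```

In this sketch an unknown key raises KeyError and an unsupported keyword argument raises TypeError from the constructor call, which lines up with the error types listed for System.create.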
abstractmethod classmethod write(state: State, system: System, filename: Path, binary: bool)[source][source]#

Write information from a simulation snapshot to a VTK PolyData file.

This abstract method is the core interface for all concrete VTK writers. Implementations should assume that all jax arrays are converted to numpy arrays before write is called.

Parameters:
  • state (State) – The simulation jaxdem.State snapshot to be written.

  • system (System) – The simulation jaxdem.System configuration.

  • filename (Path) – Target path where the VTK file should be saved. The caller guarantees that it exists.

  • binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.

class jaxdem.CheckpointWriter(directory: Path | str = PosixPath('checkpoints'), max_to_keep: int | None = None, save_every: int = 1)[source]#

Bases: object

Thin wrapper around Orbax checkpoint saving.

directory: Path | str#

The base directory where checkpoints will be saved.

max_to_keep: int | None#

Keep the last max_to_keep checkpoints. If None, everything is saved.

save_every: int#

How often to write; writes on every save_every-th call to save().

checkpointer: CheckpointManager#

Orbax checkpoint manager for saving the checkpoints.

block_until_ready()[source][source]#

Wait for the checkpointer to finish.

close() None[source][source]#

Wait for the checkpointer to finish and close it.

save(state: State, system: System) None[source][source]#

Save a checkpoint for the provided state/system at a given step.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The current system configuration.

class jaxdem.CheckpointLoader(directory: Path = PosixPath('checkpoints'))[source]#

Bases: object

Thin wrapper around Orbax checkpoint restoring for jaxdem.State and jaxdem.System.

directory: Path#

The base directory from which checkpoints are loaded.

checkpointer: CheckpointManager#

Orbax checkpoint manager for restoring the checkpoints.

block_until_ready()[source][source]#
close() None[source][source]#
load(step: int | None = None) Tuple[State, System][source][source]#

Restore a checkpoint.

Parameters:

step (Optional[int]) –

  • If None, load the latest checkpoint.

  • Otherwise, load the specified step.

Returns:

A tuple containing the restored State and System.

Return type:

Tuple[State, System]

class jaxdem.CheckpointModelWriter(directory: Path | str = PosixPath('checkpoints'), max_to_keep: int | None = None, save_every: int = 1, clean: bool = True)[source]#

Bases: object

Thin wrapper around Orbax checkpoint saving for jaxdem.rl.models.Model.

directory: Path | str#

The base directory where checkpoints will be saved.

max_to_keep: int | None#

Keep the last max_to_keep checkpoints. If None, everything is saved.

save_every: int#

How often to write; writes on every save_every-th call to save().

checkpointer: CheckpointManager#

Orbax checkpoint manager for saving the checkpoints.

clean: bool#

Whether to clean the directory.

block_until_ready() None[source][source]#
close() None[source][source]#
save(model: Model, step: int) None[source][source]#

Save the model at a given step. Stores model_state and JSON metadata, and assumes model.metadata contains JSON-serializable fields. The model_type is added to the metadata.

class jaxdem.CheckpointModelLoader(directory: Path = PosixPath('checkpoints'))[source]#

Bases: object

Thin wrapper around Orbax checkpoint restoring for jaxdem.rl.models.Model.

directory: Path#

The base directory from which checkpoints are loaded.

checkpointer: CheckpointManager#

Orbax checkpoint manager for restoring the checkpoints.

block_until_ready() None[source][source]#
close() None[source][source]#
latest_step() int | None[source][source]#
load(step: int | None = None)[source][source]#

Load a model from a given step (or the latest if None).

class jaxdem.Material[source]#

Bases: Factory

Abstract base class for defining materials.

Concrete subclasses of Material should define scalar or vector fields (e.g., young, poisson, mu) that represent specific physical properties of a material. These fields are then collected and managed by the MaterialTable.

Notes

  • Each field defined in a concrete Material subclass will become a named property in the MaterialTable.props dictionary.

Example

To define a custom material, inherit from Material

>>> from dataclasses import dataclass
>>> import jax
>>> from jaxdem import Material
>>>
>>> @Material.register("my_custom_material")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True, frozen=True)
... class MyCustomMaterial(Material):
...     ...
class jaxdem.MaterialTable(props: Dict[str, jax.Array], pair: Dict[str, jax.Array], matcher: MaterialMatchmaker)[source]#

Bases: object

A container for material properties, organized as Structures of Arrays (SoA) and pre-computed effective pair properties.

This class centralizes material data, allowing efficient access to scalar properties for individual materials and pre-calculated effective properties for material-pair interactions.

Notes

  • Scalar properties can be accessed directly using dot notation (e.g., material_table.young).

  • Effective pair properties can also be accessed directly using dot notation (e.g., material_table.young_eff).

Example

Creating a MaterialTable from multiple material types:

>>> import jax.numpy as jnp
>>> import jaxdem as jdem
>>>
>>> # Define different material instances
>>> mat1 = jdem.Material.create("elastic", young=1.0e4, poisson=0.3)
>>> mat2 = jdem.Material.create("elasticfrict", young=2.0e4, poisson=0.4, mu=0.5)
>>>
>>> # Create a MaterialTable using a linear matcher
>>> matcher_instance = jdem.MaterialMatchmaker.create("linear")
>>> mat_table = jdem.MaterialTable.from_materials(
>>>     [mat1, mat2],
>>>     matcher=matcher_instance
>>> )
props: Dict[str, jax.Array]#

A dictionary mapping scalar material property names (e.g., “young”, “poisson”, “mu”) to JAX arrays. Each array has shape (M,), where M is the total number of distinct material types present in the table.

pair: Dict[str, jax.Array]#

A dictionary mapping effective pair property names (e.g., “young_eff”, “mu_eff”) to JAX arrays. Each array has shape (M, M), representing the effective property for interactions between any two material types (M_i, M_j).

matcher: MaterialMatchmaker#

The jaxdem.MaterialMatchmaker instance that was used to compute the effective pair properties stored in the pair dictionary.

static from_materials(mats: Sequence[Material], *, matcher: MaterialMatchmaker, fill: float = 0.0) MaterialTable[source]#

Constructs a MaterialTable from a sequence of Material instances.

Parameters:
  • mats (Sequence[Material]) – A sequence of concrete Material instances. Each instance represents a distinct material type in the simulation. The order in this sequence defines their material IDs (0 to len(mats)-1).

  • matcher (MaterialMatchmaker) – The jaxdem.MaterialMatchmaker instance to be used for computing effective pair properties (e.g., harmonic mean, arithmetic mean).

  • fill (float, optional) – A fill value used for material properties that are not defined in a specific Material subclass (e.g., if an Elastic material is provided when an ElasticFriction is expected, mu would be filled with this value). Defaults to 0.0.

Returns:

A new MaterialTable instance containing the scalar properties and pre-computed effective pair properties for all provided materials.

Return type:

MaterialTable

Raises:

TypeError – If mats is not a sequence of Material instances.
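
As a rough illustration of the layout described above, the sketch below builds a props table and a pair table from plain dictionaries. It is not the library's implementation: build_tables and arithmetic_mean are hypothetical names, Python lists stand in for JAX arrays, and the arithmetic mean is only an assumed mixing rule.

```python
from typing import Callable, Dict, List, Sequence, Tuple

def arithmetic_mean(a: float, b: float) -> float:
    """Assumed mixing rule, used for illustration only."""
    return 0.5 * (a + b)

def build_tables(
    mats: Sequence[Dict[str, float]],
    rule: Callable[[float, float], float],
    fill: float = 0.0,
) -> Tuple[Dict[str, List[float]], Dict[str, List[List[float]]]]:
    # Collect every property name that appears in at least one material.
    names = sorted({name for mat in mats for name in mat})
    # props[name][m] is the value for material ID m; missing entries use `fill`.
    props = {name: [mat.get(name, fill) for mat in mats] for name in names}
    # pair[name + "_eff"][i][j] mixes the values of materials i and j.
    pair = {
        f"{name}_eff": [[rule(a, b) for b in column] for a in column]
        for name, column in props.items()
    }
    return props, pair

# Hypothetical materials: the first one does not define "mu".
mats = [
    {"young": 1.0e4, "poisson": 0.3},
    {"young": 2.0e4, "poisson": 0.4, "mu": 0.5},
]
props, pair = build_tables(mats, arithmetic_mean)
print(props["mu"])              # [0.0, 0.5]
print(pair["young_eff"][0][1])  # 15000.0
```

The position of each material in the input sequence is its material ID, so props["young"][1] and pair["young_eff"][1][0] both refer to the second material.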

class jaxdem.MaterialMatchmaker[source]#

Bases: Factory, ABC

Abstract base class for defining how to combine (mix) material properties.

Example

To define a custom matchmaker, inherit from MaterialMatchmaker and implement its abstract methods:

>>> @MaterialMatchmaker.register("myCustomMatchmaker")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True, frozen=True)
>>> class MyCustomMatchmaker(MaterialMatchmaker):
        ...
abstractmethod static get_effective_property(prop1: Array, prop2: Array) Array[source]#

Abstract method to compute the effective property value from two individual material properties.

Concrete implementations define the specific mixing rule.

Parameters:
  • prop1 (jax.Array) – The property value from the first material. Can be a scalar or an array.

  • prop2 (jax.Array) – The property value from the second material. Can be a scalar or an array.

Returns:

A JAX array representing the effective property, computed from prop1 and prop2 according to the matchmaker’s specific rule.

Return type:

jax.Array
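
A mixing rule of this kind might look like the following sketch, which uses a harmonic mean. HarmonicRule is a hypothetical standalone class, not a registered jaxdem matchmaker, and the input values are made up for illustration.

```python
class HarmonicRule:
    """Hypothetical mixing rule that mirrors the static-method interface."""

    @staticmethod
    def get_effective_property(prop1, prop2):
        # Harmonic mean of the two values. It is undefined when
        # prop1 + prop2 == 0, so a real rule would need to guard that case.
        return 2.0 * prop1 * prop2 / (prop1 + prop2)

young_eff = HarmonicRule.get_effective_property(1.0e4, 2.0e4)
print(young_eff)
```

Because the body uses only elementwise arithmetic, the same expression also works on arrays of matching or broadcastable shapes.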

class jaxdem.ForceModel(required_material_properties: Tuple[str, ...] = (), laws: Tuple[ForceModel, ...] = ())[source]#

Bases: Factory, ABC

Abstract base class for defining inter-particle force laws and their corresponding potential energies.

Concrete subclasses implement specific force and energy models, such as linear springs, Hertzian contacts, etc.

Notes

  • The force() and energy() methods should correctly handle the case where i and j refer to the same particle (i.e., i == j). There is no guarantee that self-interaction calls will not occur.

Example

To define a custom force model, inherit from ForceModel and implement its abstract methods:

>>> @ForceModel.register("myCustomForce")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True, frozen=True)
>>> class MyCustomForce(ForceModel):
        ...
required_material_properties: Tuple[str, ...]#

A static tuple of strings specifying the material properties required by this force model.

These properties (e.g., ‘young_eff’, ‘restitution’, …) must be present in the System.mat_table for the model to function correctly. This is used for validation.

laws: Tuple['ForceModel', ...]#

A static tuple of other ForceModel instances that compose this force model.

This allows for creating composite force models (e.g., a total force being the sum of a spring force and a damping force).

abstractmethod static energy(i: int, j: int, state: State, system: System) jax.Array[source]#

Compute the potential energy of the interaction between particle \(i\) and particle \(j\).

Parameters:
  • i (int) – Index of the first particle.

  • j (int) – Index of the second particle.

  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

Scalar JAX array representing the potential energy of the interaction between particles \(i\) and \(j\).

Return type:

jax.Array

abstractmethod static force(i: int, j: int, state: State, system: System) jax.Array[source]#

Compute the force vector acting on particle \(i\) due to particle \(j\).

Parameters:
  • i (int) – Index of the first particle (on which the force is acting).

  • j (int) – Index of the second particle (which is exerting the force).

  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

Force vector acting on particle \(i\) due to particle \(j\). Shape (dim,).

Return type:

jax.Array
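
To make the self-interaction requirement concrete, here is a minimal sketch of a linear spring force between two spheres. It is an illustration only: spring_force and the stiffness k are hypothetical, and the function takes positions and radii directly instead of the (i, j, state, system) signature.

```python
import jax.numpy as jnp

def spring_force(ri, rj, rad_i, rad_j, k):
    """Force on particle i due to particle j for an assumed linear spring law."""
    rij = ri - rj
    dist2 = jnp.dot(rij, rij)
    # Replace a zero distance (i == j) with 1 so the division below is safe.
    safe_dist = jnp.sqrt(jnp.where(dist2 > 0.0, dist2, 1.0))
    overlap = rad_i + rad_j - safe_dist
    # No force for self-interaction or for particles that do not touch.
    magnitude = jnp.where((dist2 > 0.0) & (overlap > 0.0), k * overlap, 0.0)
    return magnitude * rij / safe_dist

ri = jnp.array([0.0, 0.0])
rj = jnp.array([1.5, 0.0])
print(spring_force(ri, rj, 1.0, 1.0, 100.0))  # pushes i away from j
print(spring_force(ri, ri, 1.0, 1.0, 100.0))  # self-interaction gives zero
```

Using jnp.where instead of a Python if keeps the function traceable, and guarding the square root avoids a division by zero when both indices refer to the same particle.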

class jaxdem.Integrator[source]#

Bases: Factory, ABC

Abstract base class for defining the interface for time-stepping.

Example

To define a custom integrator, inherit from Integrator and implement its abstract methods:

>>> @Integrator.register("myCustomIntegrator")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True, frozen=True)
>>> class MyCustomIntegrator(Integrator):
        ...
abstractmethod static initialize(state: State, system: System) Tuple['State', 'System'][source]#

Some integration methods (for example, LeapFrog) require an initialization step. This method defines the interface for that initialization.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

A tuple containing the updated State and System after the initialization.

Return type:

Tuple[State, System]

Example

>>> state, system = system.integrator.initialize(state, system)
abstractmethod static step(state: State, system: System) Tuple['State', 'System'][source]#

Advance the simulation state by one time step using a specific numerical integration method.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

A tuple containing the updated State and System after one time step of integration.

Return type:

Tuple[State, System]

Notes

  • This method performs the following updates:
    1. Applies boundary conditions using jaxdem.Domain.shift().

    2. Computes forces and accelerations using jaxdem.Collider.compute_force().

    3. Updates velocities based on current acceleration.

    4. Updates positions based on the newly updated velocities.

  • Particles with state.fixed set to True will have their velocities and positions unaffected by the integration step.
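
Updates 3 and 4 and the rule for fixed particles can be sketched as follows. This is a simplified stand-in, not the library's integrator: euler_step is a hypothetical name, the time step dt is passed explicitly, and boundary conditions and force computation are left out.

```python
import jax.numpy as jnp

def euler_step(pos, vel, accel, fixed, dt):
    """Update velocities first, then positions from the new velocities."""
    # Shape (N, 1) so the mask broadcasts over the spatial dimension.
    free = jnp.logical_not(fixed)[:, None]
    new_vel = jnp.where(free, vel + dt * accel, vel)
    new_pos = jnp.where(free, pos + dt * new_vel, pos)
    return new_pos, new_vel

pos = jnp.zeros((2, 2))
vel = jnp.array([[1.0, 0.0], [1.0, 0.0]])
accel = jnp.array([[0.0, -10.0], [0.0, -10.0]])
fixed = jnp.array([False, True])  # the second particle is fixed

new_pos, new_vel = euler_step(pos, vel, accel, fixed, dt=0.1)
print(new_pos)
print(new_vel)
```

Masking with jnp.where leaves fixed particles untouched without any Python-level branching.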

Example

>>> state, system = system.integrator.step(state, system)
class jaxdem.Collider[source]#

Bases: Factory, ABC

The base interface for defining how contact detection and force computations are performed in a simulation.

Concrete subclasses of Collider implement the specific algorithms for calculating the interactions.

Notes

Self-interaction (i.e., calling the force/energy computation for i=j) is allowed, and the underlying force_model is responsible for correctly handling or ignoring this case.

Example

To define a custom collider, inherit from Collider, register it and implement its abstract methods:

>>> @Collider.register("CustomCollider")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True, frozen=True)
>>> class CustomCollider(Collider):
        ...

Then, instantiate it:

>>> jaxdem.Collider.create("CustomCollider", **custom_collider_kw)
abstractmethod static compute_force(state: State, system: System) Tuple['State', 'System'][source]#

Abstract method to compute the total force acting on each particle in the simulation.

Implementations should calculate inter-particle forces based on the current state and system configuration, then update the accel attribute of the state object with the resulting total acceleration for each particle.

TODO: Define how to reset the force and how to add external forces.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object (with computed accelerations) and the System object.

Return type:

Tuple[State, System]
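
One simple way to realize this contract is an all-pairs sum, sketched below. It is an assumption-laden illustration rather than any of the library's colliders: all_pairs_accel and linear_attraction are hypothetical, and the pair law is a toy function of positions only.

```python
import jax
import jax.numpy as jnp

def all_pairs_accel(pos, mass, pair_force):
    """Sum pair forces over every j for each i, then divide by mass."""
    def force_on(ri):
        # Includes j == i, so pair_force must return zero for self-interaction.
        return jax.vmap(lambda rj: pair_force(ri, rj))(pos).sum(axis=0)

    total_force = jax.vmap(force_on)(pos)  # shape (N, dim)
    return total_force / mass[:, None]

def linear_attraction(ri, rj):
    # Toy law: pulls i toward j and vanishes when ri == rj.
    return -(ri - rj)

pos = jnp.array([[0.0, 0.0], [1.0, 0.0]])
mass = jnp.array([1.0, 2.0])
accel = all_pairs_accel(pos, mass, linear_attraction)
print(accel)
```

The nested jax.vmap evaluates all N x N pairs, which is easy to reason about but scales quadratically with the number of particles.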

abstractmethod static compute_potential_energy(state: State, system: System) jax.Array[source]#

Abstract method to compute the potential energy of each particle in the system.

Implementations should sum, for each particle, all potential-energy contributions acting on it, based on the current state and system configuration.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A JAX array containing the total potential energy of each particle.

Return type:

jax.Array

Example

>>> potential_energy = system.collider.compute_potential_energy(state, system)
>>> print(f"Potential energy per particle: {potential_energy}")
>>> print(potential_energy.shape)  # (N, 1)
class jaxdem.Domain(box_size: Array, anchor: Array)[source]#

Bases: Factory, ABC

The base interface for defining the simulation domain and the effect of its boundary conditions.

The Domain class defines how:
  • Relative displacement vectors between particles are calculated.

  • Particles’ positions are “shifted” or constrained to remain within the defined simulation boundaries based on the boundary condition type.

Example

To define a custom domain, inherit from Domain and implement its abstract methods:

>>> @Domain.register("my_custom_domain")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True, frozen=True)
>>> class MyCustomDomain(Domain):
        ...
box_size: jax.Array#

Length of the simulation domain along each dimension.

anchor: jax.Array#

Anchor position (minimum coordinate) of the simulation domain.

classmethod Create(dim: int, box_size: Array | None = None, anchor: Array | None = None) Self[source][source]#

Default factory method for the Domain class.

This method constructs a new Domain instance with a box-shaped domain of the given dimensionality. If box_size or anchor are not provided, they are initialized to default values.

Parameters:
  • dim (int) – The dimensionality of the domain (e.g., 2, 3).

  • box_size (jax.Array, optional) – The size of the domain along each dimension. If not provided, defaults to an array of ones with shape (dim,).

  • anchor (jax.Array, optional) – The anchor (origin) of the domain. If not provided, defaults to an array of zeros with shape (dim,).

Returns:

A new instance of the Domain subclass with the specified or default configuration.

Return type:

Domain

Raises:

AssertionError – If box_size and anchor do not have the same shape.

abstractmethod static displacement(ri: jax.Array, rj: jax.Array, system: System) jax.Array[source]#

Computes the displacement vector between two particles \(r_i\) and \(r_j\), considering the domain’s boundary conditions.

Parameters:
  • ri (jax.Array) – Position vector of the first particle \(r_i\). Shape (dim,).

  • rj (jax.Array) – Position vector of the second particle \(r_j\). Shape (dim,).

  • system (System) – The configuration of the simulation, containing the domain instance.

Returns:

The displacement vector \(r_{ij} = r_i - r_j\), adjusted for boundary conditions. Shape (dim,).

Return type:

jax.Array

Example

>>> rij = system.domain.displacement(ri, rj, system)
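
For a periodic box, the boundary adjustment is commonly done with the minimum-image convention. The sketch below assumes that convention and an axis-aligned box; periodic_displacement is a hypothetical helper that takes box_size directly instead of a System.

```python
import jax.numpy as jnp

def periodic_displacement(ri, rj, box_size):
    """Shortest displacement r_i - r_j under periodic wrapping."""
    rij = ri - rj
    # Subtract the whole number of box lengths closest to each component.
    return rij - box_size * jnp.round(rij / box_size)

box_size = jnp.array([10.0, 10.0])
ri = jnp.array([9.5, 5.0])
rj = jnp.array([0.5, 5.0])
print(periodic_displacement(ri, rj, box_size))  # x component is -1, not 9
```

In a domain without periodic boundaries, the displacement would simply be ri - rj.
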
periodic: ClassVar[bool] = False#

Whether the domain enforces periodic boundary conditions.

abstractmethod static shift(state: State, system: System) Tuple['State', 'System'][source]#

Applies boundary conditions to the particles' state.

This method updates the state based on the domain’s rules, ensuring particles remain within the simulation box or handle interactions at boundaries appropriately (e.g., reflection, wrapping).

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object adjusted by the boundary conditions and the System object.

Return type:

Tuple[State, System]

Example

>>> state, system = system.domain.shift(state, system)
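
As one example of the wrapping case, the sketch below maps positions back into a periodic box defined by anchor and box_size. wrap_positions is a hypothetical helper operating on a bare position array; a reflecting or free domain would need different logic.

```python
import jax.numpy as jnp

def wrap_positions(pos, box_size, anchor):
    """Map positions into [anchor, anchor + box_size) along every axis."""
    return anchor + jnp.mod(pos - anchor, box_size)

box_size = jnp.array([10.0, 10.0])
anchor = jnp.array([0.0, 0.0])
pos = jnp.array([[10.5, -0.25], [3.0, 4.0]])
print(wrap_positions(pos, box_size, anchor))
```

jnp.mod returns a result with the sign of the divisor, so coordinates below the anchor wrap to the far side of the box.
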
class jaxdem.Factory[source]#

Bases: ABC

Base factory class for pluggable components. This abstract base class provides a mechanism for registering and creating subclasses based on a string key.

Notes

Each concrete subclass gets its own private registry. Keys are strings and are case-insensitive.

Example

Use Factory as a base class for a specific component type (e.g., Foo):

>>> class Foo(Factory["Foo"], ABC):
>>>   ...

Register a concrete subclass of Foo:

>>> @Foo.register("bar")
>>> class Bar(Foo):
>>>     ...

To instantiate the subclass instance:

>>> Foo.create("bar", **bar_kw)
classmethod create(key: str, /, **kw: Any) RootT[source]#

Creates and returns an instance of a registered subclass.

This method looks up the subclass associated with the given key in the factory’s registry and then calls its constructor with the provided arguments. If the subclass defines a Create method (capitalized), that method will be called instead of the constructor. This allows subclasses to validate or preprocess arguments before instantiation.

Parameters:
  • key (str) – The registration key of the subclass to be created.

  • **kw (Any) – Arbitrary keyword arguments to be passed directly to the constructor of the registered subclass.

Returns:

An instance of the registered subclass.

Return type:

T

Raises:
  • KeyError – If the provided key is not found in the factory’s registry.

  • TypeError – If the provided **kw arguments do not match the signature of the registered subclass’s constructor.

Example

Given Foo factory and Bar registered:

>>> bar_instance = Foo.create("bar", value=42)
>>> print(bar_instance)
Bar(value=42)
classmethod register(key: str | None = None) Callable[[Type[SubT]], Type[SubT]][source]#

Registers a subclass with the factory’s registry.

This method returns a decorator that can be applied to a class to register it under a specific key.

Parameters:

key (str or None, optional) – The string key under which to register the subclass. If None, the lowercase name of the subclass itself will be used as the key.

Returns:

A decorator function that takes a class and registers it, returning the class unchanged.

Return type:

Callable[[Type[T]], Type[T]]

Raises:

ValueError – If the provided key (or the default class name) is already registered in the factory’s registry.

Example

Register a class named “MyComponent” under the key “mycomp”:

>>> @MyFactory.register("mycomp")
>>> class MyComponent:
>>>     ...

Register a class named “DefaultComponent” using its own name as the key:

>>> @MyFactory.register()
>>> class DefaultComponent:
>>>     ...
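
The registration and creation behavior described for Factory can be pictured with the small standalone sketch below. Registry is a hypothetical class written for illustration; it imitates the documented behavior (case-insensitive keys, ValueError on duplicates, KeyError on unknown keys, preference for a Create method) but is not the library's code.

```python
from dataclasses import dataclass

class Registry:
    """Hypothetical stand-in for a string-keyed class registry."""

    def __init__(self):
        self._classes = {}

    def register(self, key=None):
        def decorator(cls):
            # Fall back to the lowercase class name when no key is given.
            name = (key or cls.__name__).lower()
            if name in self._classes:
                raise ValueError(f"{name!r} is already registered")
            self._classes[name] = cls
            return cls  # the class itself is returned unchanged
        return decorator

    def create(self, key, /, **kw):
        cls = self._classes[key.lower()]  # raises KeyError for unknown keys
        # Prefer a capitalized Create method when the class defines one.
        factory = getattr(cls, "Create", cls)
        return factory(**kw)

components = Registry()

@components.register("bar")
@dataclass
class Bar:
    value: int

bar_instance = components.create("BAR", value=42)  # key lookup ignores case
print(bar_instance.value)  # 42
```

Returning the class unchanged from the decorator means registration has no effect on how the class behaves when used directly.
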
class jaxdem.ForceRouter(required_material_properties: Tuple[str, ...] = (), laws: Tuple[ForceModel, ...] = (), table: Tuple[Tuple[ForceModel, ...], ...] = ())[source]#

Bases: ForceModel

Static species-to-force lookup table.

table: Tuple[Tuple['ForceModel', ...], ...]#
static energy(i, j, state, system)[source]#
static force(i, j, state, system)[source]#
static from_dict(S: int, mapping: dict[Tuple[int, int], ForceModel])[source]#
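
A species-pair lookup table of size S x S could be assembled from such a mapping roughly as follows. How from_dict treats missing or mirrored pairs is not stated here, so the symmetric fill and the None default below are assumptions, and table_from_dict is a hypothetical name; strings stand in for force-model instances.

```python
def table_from_dict(S, mapping, default=None):
    """Build an S x S nested tuple indexed by (species_i, species_j)."""
    table = [[default] * S for _ in range(S)]
    for (a, b), law in mapping.items():
        table[a][b] = law
        table[b][a] = law  # assumption: the pair (a, b) also covers (b, a)
    return tuple(tuple(row) for row in table)

table = table_from_dict(2, {(0, 1): "spring", (1, 1): "hertz"})
print(table[1][0])  # spring
print(table[0][0])  # None
```
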
class jaxdem.LawCombiner(required_material_properties: Tuple[str, ...] = (), laws: Tuple[ForceModel, ...] = ())[source]#

Bases: ForceModel

Sum a tuple of elementary force laws.

static energy(i, j, state, system)[source]#
static force(i, j, state, system)[source]#
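
Summing elementary laws amounts to evaluating each one on the same inputs and adding the results, as in this sketch. It is a one-dimensional toy, not the library's combiner: combine_laws, spring, damping, and their parameters k and c are hypothetical.

```python
def combine_laws(laws):
    """Return a function that adds up the contributions of all laws."""
    def total_force(overlap, separation_speed):
        return sum(law(overlap, separation_speed) for law in laws)
    return total_force

def spring(overlap, separation_speed, k=100.0):
    # Repulsion proportional to the overlap.
    return k * overlap

def damping(overlap, separation_speed, c=2.0):
    # Opposes the relative motion; separation_speed > 0 means moving apart.
    return -c * separation_speed

total = combine_laws((spring, damping))
print(total(0.5, 5.0))  # 40.0
```
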
classmethod registry_name() str[source]#
property type_name: str[source]#

Modules

colliders

Collision-detection interfaces and implementations.

domains

Simulation domains and boundary-condition implementations.

factory

The factory defines and instantiates specific simulation components.

forces

Force-law interfaces and concrete implementations.

integrators

Time-integration interfaces and implementations.

material_matchmakers

Material mix rules and implementations.

materials

Interface for defining materials and the MaterialTable.

rl

JaxDEM reinforcement learning (RL) module.

state

Defines the simulation State.

system

Defines the simulation configuration and the tooling for driving the simulation.

utils

Utility functions used to set up simulations and analyze the output.

writers

Interface for defining data writers.