jaxdem

JaxDEM module

final class jaxdem.State(pos_c: Array, pos_p: Array, vel: Array, force: Array, q: Quaternion, angVel: Array, torque: Array, rad: Array, volume: Array, mass: Array, inertia: Array, clump_ID: Array, deformable_ID: Array, unique_ID: Array, mat_id: Array, species_id: Array, fixed: Array)

Bases: object

Represents the complete simulation state for a system of N particles in 2D or 3D.

Notes

State is designed to support various data layouts:

  • Single snapshot:

    pos.shape = (N, dim) for particle properties (e.g., pos, vel, force), and (N,) for scalar properties (e.g., rad, mass). In this case, batch_size is 1.

  • Batched states:

    pos.shape = (B, N, dim) for particle properties, and (B, N) for scalar properties. Here, B is the batch dimension (batch_size = pos.shape[0]).

  • Trajectories of a single simulation:

    pos.shape = (T, N, dim) for particle properties, and (T, N) for scalar properties. Here, T is the trajectory dimension.

  • Trajectories of batched states:

    pos.shape = (B, T_1, T_2, …, T_k, N, dim) for particle properties, and (B, T_1, T_2, …, T_k, N) for scalar properties.

    • The first dimension (i.e., pos.shape[0]) is always interpreted as the batch dimension (`B`).

    • All remaining leading dimensions (T_1, T_2, …, T_k) are interpreted as trajectory dimensions, and they are flattened into a single trajectory axis at save time if there is more than one trajectory dimension.
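
Flattening several trajectory axes at save time comes down to a reshape that keeps the batch axis and the trailing particle axes. The following is a minimal sketch in plain jax.numpy. The array and its sizes are illustrative only and are not produced by jaxdem:

```python
import jax.numpy as jnp

# Illustrative layout (B, T_1, T_2, N, dim) with B=2, T_1=3, T_2=4, N=5, dim=3.
pos = jnp.zeros((2, 3, 4, 5, 3))

# Keep axis 0 (batch) and the last two axes (N, dim); merge T_1 and T_2 into one axis.
flat = pos.reshape(pos.shape[0], -1, *pos.shape[-2:])
print(flat.shape)  # (2, 12, 5, 3)
```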

The class is final and cannot be subclassed.

Example

Creating a simple 2D state for 4 particles:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>> import jax
>>>
>>> positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
>>> state = jdem.State.create(pos=positions)
>>>
>>> print(f"Number of particles (N): {state.N}")
>>> print(f"Spatial dimension (dim): {state.dim}")
>>> print(f"Positions: {state.pos}")

Creating a batched state:

>>> batched_state = jax.vmap(lambda _: jdem.State.create(pos=positions))(jnp.arange(10))
>>>
>>> print(f"Batch size: {batched_state.batch_size}")  # 10
>>> print(f"Positions shape: {batched_state.pos.shape}")

pos_c: Array

Array of particle center of mass positions. Shape is (…, N, dim).

pos_p: Array

Offset of each sphere from the center of mass, expressed in the principal (body) reference frame, so that pos = pos_c + R(q) @ pos_p (pos_p equals pos - pos_c only when q is the identity). This field should remain constant. Shape is (…, N, dim).

vel: Array

Array of particle center of mass velocities. Shape is (…, N, dim).

force: Array

Array of particle forces. Shape is (…, N, dim).

q: Quaternion

Quaternion representing the orientation of the particle.

angVel: Array

Array of particle angular velocities. Shape is (…, N, 1 | 3) depending on 2D or 3D simulations.

torque: Array

Array of particle torques. Shape is (…, N, 1 | 3) depending on 2D or 3D simulations.

rad: Array

Array of particle radii. Shape is (…, N).

volume: Array

Array of particle volumes (or areas if 2D). Shape is (…, N).

mass: Array

Array of particle masses. Shape is (…, N).

inertia: Array

Principal moments of inertia, expressed in the principal axis frame. Shape is (…, N, 1 | 3) depending on 2D or 3D simulations.

clump_ID: Array

Array of clump identifiers. Bodies with the same clump_ID are treated as part of the same rigid body. Shape is (…, N). IDs need to be between 0 and N.

deformable_ID: Array

Array of deformable particle identifiers. Spheres (nodes) with the same deformable_ID are treated as part of the same deformable particle for collision masking purposes. Shape is (…, N). IDs need to be between 0 and N.

unique_ID: Array

Array of unique particle identifiers. No ID should be repeated. Shape is (…, N). IDs need to be between 0 and N.

mat_id: Array

Array of material IDs for each particle. Shape is (…, N).

species_id: Array

Array of species IDs for each particle. Shape is (…, N).

fixed: Array

Boolean array indicating if a particle is fixed (immobile). Shape is (…, N).

property N: int

Number of particles in the state.

property dim: int

Spatial dimension of the simulation.

property shape: Tuple[int, ...]

Shape of the position array pos_c, e.g. (N, dim) or (B, N, dim).

property batch_size: int

Batch size of the state (1 for a single snapshot).

property pos: Array

Lab-frame position of each sphere in the state, computed as pos = pos_c + R(q) @ pos_p, where pos_c is the center of mass, pos_p is the offset from the center of mass in the principal reference frame, and R(q) is the rotation matrix of q that maps pos_p to the lab frame.
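
For 3D particles, the relation above can be sketched using the standard formula for rotating a vector by a unit quaternion ordered (w, x, y, z). The helper rotate below is hypothetical and uses plain jax.numpy. jaxdem's Quaternion class may provide its own rotation routine, which is not shown here. The 2D case is not covered by this sketch:

```python
import jax.numpy as jnp

def rotate(q, v):
    """Hypothetical helper: rotate vectors v (..., 3) by unit quaternions q (..., 4) = (w, x, y, z)."""
    w = q[..., :1]  # scalar part, shape (..., 1) so it broadcasts against v
    u = q[..., 1:]  # vector part, shape (..., 3)
    t = 2.0 * jnp.cross(u, v)
    # v' = v + 2w (u x v) + 2 u x (u x v), written with t = 2 (u x v)
    return v + w * t + jnp.cross(u, t)

# Two spheres of one clump sharing the center of mass pos_c.
pos_c = jnp.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
pos_p = jnp.array([[0.5, 0.0, 0.0], [-0.5, 0.0, 0.0]])  # body-frame offsets
s = 0.5 ** 0.5
q = jnp.array([[s, 0.0, 0.0, s], [s, 0.0, 0.0, s]])  # 90 degrees about z

pos = pos_c + rotate(q, pos_p)
print(pos)  # approximately [[1.0, 0.5, 0.0], [1.0, -0.5, 0.0]]
```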

property is_valid: bool

Check if the internal representation of the State is consistent.

Verifies that:

  • The spatial dimension (dim) is either 2 or 3.

  • All position-like arrays (pos_c, pos_p, vel, force) have the same shape.

  • All angular-like arrays (angVel, torque, inertia) have the same shape.

  • All scalar-per-particle arrays (rad, mass, clump_ID, deformable_ID, mat_id, species_id, fixed) have a shape consistent with pos.shape[:-1].

Raises:

AssertionError – If any shape inconsistency is found.
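
The checks listed above can be sketched as plain shape comparisons. check_shapes is a hypothetical helper that operates on dictionaries of arrays. It is not part of jaxdem and mirrors only the rules stated here:

```python
import jax.numpy as jnp

def check_shapes(pos_c, vectors, angulars, scalars):
    """Hypothetical mirror of the documented State.is_valid rules."""
    dim = pos_c.shape[-1]
    assert dim in (2, 3), f"dim must be 2 or 3, got {dim}"
    # Position-like arrays share the full (..., N, dim) shape of pos_c.
    for name, arr in vectors.items():
        assert arr.shape == pos_c.shape, f"{name}: {arr.shape} != {pos_c.shape}"
    # Angular-like arrays all share one common shape.
    ang_shapes = {arr.shape for arr in angulars.values()}
    assert len(ang_shapes) <= 1, f"angular arrays disagree: {ang_shapes}"
    # Scalar-per-particle arrays match pos.shape[:-1], i.e. (..., N).
    for name, arr in scalars.items():
        assert arr.shape == pos_c.shape[:-1], f"{name}: {arr.shape} != {pos_c.shape[:-1]}"
    return True

N, dim = 4, 2
pos_c = jnp.zeros((N, dim))
ok = check_shapes(
    pos_c,
    vectors={"pos_p": jnp.zeros((N, dim)), "vel": jnp.zeros((N, dim)), "force": jnp.zeros((N, dim))},
    angulars={"angVel": jnp.zeros((N, 1)), "torque": jnp.zeros((N, 1)), "inertia": jnp.ones((N, 1))},
    scalars={"rad": jnp.ones(N), "mass": jnp.ones(N), "fixed": jnp.zeros(N, dtype=bool)},
)
print(ok)  # True
```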

static create(pos: ArrayLike, *, pos_p: ArrayLike | None = None, vel: ArrayLike | None = None, force: ArrayLike | None = None, q: Quaternion | None | ArrayLike | None = None, angVel: ArrayLike | None = None, torque: ArrayLike | None = None, rad: ArrayLike | None = None, volume: ArrayLike | None = None, mass: ArrayLike | None = None, inertia: ArrayLike | None = None, clump_ID: ArrayLike | None = None, deformable_ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None, mat_table: MaterialTable | None = None) → State

Factory method to create a new State instance.

This method handles default values and ensures consistent array shapes for all state attributes.

Parameters:
  • pos (jax.typing.ArrayLike) – Array of particle center of mass positions, equivalent to state.pos_c. Expected shape: (…, N, dim).

  • pos_p (jax.typing.ArrayLike or None, optional) – Offset of each sphere from the center of mass, expressed in the principal (body) reference frame, so that pos = pos_c + R(q) @ pos_p. This field should remain constant. Expected shape: (…, N, dim).

  • vel (jax.typing.ArrayLike or None, optional) – Initial velocities of particles. If None, defaults to zeros. Expected shape: (…, N, dim).

  • force (jax.typing.ArrayLike or None, optional) – Initial forces on particles. If None, defaults to zeros. Expected shape: (…, N, dim).

  • q (Quaternion or array-like, optional) – Initial particle orientations. If None, defaults to identity quaternions. Accepts Quaternion objects or arrays of shape (…, N, 4) with components ordered as (w, x, y, z).

  • angVel (jax.typing.ArrayLike or None, optional) – Initial angular velocities of particles. If None, defaults to zeros. Expected shape: (…, N, 1) in 2D or (…, N, 3) in 3D.

  • torque (jax.typing.ArrayLike or None, optional) – Initial torques on particles. If None, defaults to zeros. Expected shape: (…, N, 1) in 2D or (…, N, 3) in 3D.

  • rad (jax.typing.ArrayLike or None, optional) – Radii of particles. If None, defaults to ones. Expected shape: (…, N).

  • volume (jax.typing.ArrayLike or None, optional) – Volume of particles (or area in 2D). If None, defaults to hypersphere volumes of the radii. Expected shape: (…, N).

  • mass (jax.typing.ArrayLike or None, optional) – Masses of particles. If None, defaults to ones. Ignored when mat_table is provided. Expected shape: (…, N).

  • inertia (jax.typing.ArrayLike or None, optional) – Moments of inertia in the principal axes frame. If None, defaults to solid disks (2D) or spheres (3D). Expected shape: (…, N, 1) in 2D or (…, N, 3) in 3D.

  • clump_ID (jax.typing.ArrayLike or None, optional) – Clump identifiers. If None, defaults to jnp.arange(N), so that each sphere forms its own clump. Expected shape: (…, N).

  • deformable_ID (jax.typing.ArrayLike or None, optional) – Deformable particle identifiers. If None, defaults to jnp.arange(N), so that each sphere is its own deformable particle. Expected shape: (…, N).

  • mat_id (jax.typing.ArrayLike or None, optional) – Material IDs for particles. If None, defaults to zeros. Expected shape: (…, N).

  • species_id (jax.typing.ArrayLike or None, optional) – Species IDs for particles. If None, defaults to zeros. Expected shape: (…, N).

  • fixed (jax.typing.ArrayLike or None, optional) – Boolean array indicating fixed particles. If None, defaults to all False. Expected shape: (…, N).

  • mat_table (MaterialTable or None, optional) – Optional material table providing per-material densities. When provided, the mass argument is ignored and particle masses are computed from density and particle volume.

Returns:

A new State instance with all attributes correctly initialized and shaped.

Return type:

State

Raises:

ValueError – If the created State is not valid.

Example

Creating a 3D state for 5 particles:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> my_pos = jnp.array([[0.,0.,0.], [1.,0.,0.], [0.,1.,0.], [0.,0.,1.], [1.,1.,1.]])
>>> my_rad = jnp.array([0.5, 0.5, 0.5, 0.5, 0.5])
>>> my_mass = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0])
>>>
>>> state_5_particles = jdem.State.create(pos=my_pos, rad=my_rad, mass=my_mass)
>>> print(f"Shape of positions: {state_5_particles.pos.shape}")
>>> print(f"Radii: {state_5_particles.rad}")
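
The default volume and inertia described in the parameter list are assumed to follow the textbook formulas for uniform solid disks (2D) and solid spheres (3D), and mass = density × volume when a mat_table is given. The sketch below uses hypothetical helper names and an illustrative density. It is not jaxdem's implementation, and details such as dtype handling may differ:

```python
import math

import jax.numpy as jnp

def default_volume(rad, dim):
    # Hypersphere volume: disk area in 2D, ball volume in 3D.
    return math.pi * rad**2 if dim == 2 else (4.0 / 3.0) * math.pi * rad**3

def default_inertia(mass, rad, dim):
    # Solid disk about its axis: m r^2 / 2, shape (N, 1).
    # Solid sphere about any principal axis: 2 m r^2 / 5, shape (N, 3).
    if dim == 2:
        return (0.5 * mass * rad**2)[:, None]
    return jnp.repeat((0.4 * mass * rad**2)[:, None], 3, axis=1)

rad = jnp.array([0.5, 1.0])
density = 2.0  # hypothetical per-material density from a MaterialTable
vol3 = default_volume(rad, 3)
mass3 = density * vol3  # mass from density and volume
inertia3 = default_inertia(mass3, rad, 3)
print(vol3.shape, inertia3.shape)  # (2,) (2, 3)
```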

static merge(state1: State, state2: State | Sequence[State]) → State

Merges multiple State instances into a single new State.

This method concatenates the particles from the provided state(s) onto state1. Particle clump_IDs, deformable_IDs, and unique_IDs are shifted to ensure uniqueness across the merged system.

Parameters:
  • state1 (State) – The first State instance. Its particles will appear first in the merged state.

  • state2 (State or Sequence[State]) – The second State or a list/tuple of State instances to append.

Returns:

A new State instance containing all particles from both input states.

Return type:

State

Raises:
  • AssertionError – If either input state is invalid, or if there is a mismatch in spatial dimension (dim) or batch size (batch_size).

  • ValueError – If the resulting merged state is invalid.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state_a = jdem.State.create(pos=jnp.array([[0.0, 0.0], [1.0, 1.0]]), clump_ID=jnp.array([0, 1]))
>>> state_b = jdem.State.create(pos=jnp.array([[2.0, 2.0], [3.0, 3.0]]), clump_ID=jnp.array([0, 1]))
>>> merged_many = jdem.State.merge(state_a, [state_b, state_b, state_b])  # Expected N: 8
>>> merged_state = jdem.State.merge(state_a, state_b)
>>>
>>> print(f"Merged state N: {merged_state.N}")  # Expected: 4
>>> print(f"Merged state positions:\n{merged_state.pos}")
>>> print(f"Merged state clump_IDs: {merged_state.clump_ID}")  # Expected: [0, 1, 2, 3]
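
You can think of the ID shift that keeps merged clumps distinct as an offset followed by a concatenation. merge_ids is a hypothetical helper. Its offset rule (the number of particles already present) is an assumption that happens to reproduce the expected clump_IDs in the example above. jaxdem may compute the offset differently, for instance from the largest existing ID:

```python
from functools import reduce

import jax.numpy as jnp

def merge_ids(ids_a, ids_b):
    # Assumption: offset the second set by the particle count of the first.
    offset = ids_a.shape[-1]
    return jnp.concatenate([ids_a, ids_b + offset], axis=-1)

pair = merge_ids(jnp.array([0, 1]), jnp.array([0, 1]))
print(pair)  # [0 1 2 3]

# Merging a sequence applies the same rule from left to right.
many = reduce(merge_ids, [jnp.array([0, 1])] * 4)
print(many)  # [0 1 2 3 4 5 6 7]
```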

static add(state: State, pos: ArrayLike, *, pos_p: ArrayLike | None = None, vel: ArrayLike | None = None, force: ArrayLike | None = None, q: Quaternion | None | ArrayLike | None = None, angVel: ArrayLike | None = None, torque: ArrayLike | None = None, rad: ArrayLike | None = None, volume: ArrayLike | None = None, mass: ArrayLike | None = None, inertia: ArrayLike | None = None, clump_ID: ArrayLike | None = None, deformable_ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None, mat_table: MaterialTable | None = None) → State

Adds new particles to an existing State instance, returning a new State.

Parameters:
  • state (State) – The existing State to which particles will be added.

  • pos (jax.typing.ArrayLike) – Array of particle center of mass positions, equivalent to state.pos_c. Expected shape: (…, N, dim).

  • pos_p (jax.typing.ArrayLike or None, optional) – Offset of each new sphere from the center of mass, expressed in the principal (body) reference frame, so that pos = pos_c + R(q) @ pos_p. This field should remain constant. Expected shape: (…, N, dim).

  • vel (jax.typing.ArrayLike or None, optional) – Velocities of the new particle(s). Defaults to zeros.

  • force (jax.typing.ArrayLike or None, optional) – Forces of the new particle(s). Defaults to zeros.

  • q (Quaternion or array-like, optional) – Initial orientations of the new particle(s). Defaults to identity quaternions.

  • angVel (jax.typing.ArrayLike or None, optional) – Angular velocities of the new particle(s). Defaults to zeros.

  • torque (jax.typing.ArrayLike or None, optional) – Torques of the new particle(s). Defaults to zeros.

  • rad (jax.typing.ArrayLike or None, optional) – Radii of the new particle(s). Defaults to ones.

  • volume (jax.typing.ArrayLike or None, optional) – Volume of the new particle(s) (or area in 2D). Defaults to hypersphere volumes of the radii.

  • mass (jax.typing.ArrayLike or None, optional) – Masses of the new particle(s). Defaults to ones. Ignored when a mat_table is provided.

  • inertia (jax.typing.ArrayLike or None, optional) – Moments of inertia of the new particle(s). Defaults to solid disks (2D) or spheres (3D).

  • clump_ID (jax.typing.ArrayLike or None, optional) – clump_IDs of the new clump(s). If None, new IDs are generated.

  • deformable_ID (jax.typing.ArrayLike or None, optional) – Unique identifiers for deformable particles. If None, defaults to jnp.arange(). Expected shape: (…, N).

  • mat_id (jax.typing.ArrayLike or None, optional) – Material IDs of the new particle(s). Defaults to zeros.

  • species_id (jax.typing.ArrayLike or None, optional) – Species IDs of the new particle(s). Defaults to zeros.

  • fixed (jax.typing.ArrayLike or None, optional) – Fixed status of the new particle(s). Defaults to all False.

  • mat_table (MaterialTable or None, optional) – Optional material table providing per-material densities. When provided, masses are computed from density and particle volume.

Returns:

A new State instance containing all particles from the original state plus the newly added particles.

Return type:

State

Raises:
  • ValueError – If the created new particle state or the merged state is invalid.

  • AssertionError – If batch size or dimension mismatch between existing state and new particles.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Initial state with 4 particles
>>> state = jdem.State.create(pos=jnp.zeros((4, 2)))
>>> print(f"Original state N: {state.N}, clump_IDs: {state.clump_ID}")
>>>
>>> # Add a single new particle
>>> state_with_added_particle = jdem.State.add(
...     state,
...     pos=jnp.array([[10.0, 10.0]]),
...     rad=jnp.array([0.5]),
...     mass=jnp.array([2.0]),
... )
>>> print(f"New state N: {state_with_added_particle.N}, clump_IDs: {state_with_added_particle.clump_ID}")
>>> print(f"New particle position: {state_with_added_particle.pos[-1]}")
>>>
>>> # Add multiple new particles
>>> state_multiple_added = jdem.State.add(
...     state,
...     pos=jnp.array([[10.0, 10.0], [11.0, 11.0], [12.0, 12.0]]),
... )
>>> print(f"State with multiple added N: {state_multiple_added.N}, clump_IDs: {state_multiple_added.clump_ID}")

static stack(states: Sequence[State]) → State

Stacks a sequence of State snapshots along a new leading axis (axis 0) to form a trajectory or batch.

This method is useful for collecting simulation snapshots over time into a single State object where the leading dimension represents time or when preparing a batched state.

Parameters:

states (Sequence[State]) – A sequence (e.g., list, tuple) of State instances to be stacked.

Returns:

A new State instance where each attribute is a JAX array with an additional leading dimension representing the stacked trajectory. For example, if input pos was (N, dim), output pos will be (T, N, dim).

Return type:

State

Raises:
  • ValueError – If the input states sequence is empty, or if the stacked State is invalid.

  • AssertionError – If any input state is invalid, or if there is a mismatch in spatial dimension (dim), batch size (batch_size), or number of particles (N) between the states in the sequence.

Notes

  • No clump_ID shifting is performed because the leading axis represents time (or another batch dimension), not new particles.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create a sequence of 3 simple 2D snapshots
>>> snapshot1 = jdem.State.create(pos=jnp.array([[0.,0.], [1.,1.]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot2 = jdem.State.create(pos=jnp.array([[0.1,0.], [1.,1.1]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot3 = jdem.State.create(pos=jnp.array([[0.2,0.], [1.,1.2]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>>
>>> trajectory_state = jdem.State.stack([snapshot1, snapshot2, snapshot3])
>>>
>>> print(f"Trajectory positions shape: {trajectory_state.pos.shape}") # Expected: (3, 2, 2)
>>> print(f"Positions at time step 0:\n{trajectory_state.pos[0]}")
>>> print(f"Positions at time step 1:\n{trajectory_state.pos[1]}")

static unstack(state: State) → list[State]

Split a stacked/batched State along the leading axis into a Python list.

This is the inverse of State.stack():

  • If stacked = State.stack([s0, s1, …]), then State.unstack(stacked) returns [s0, s1, …].

Notes

  • The split is performed along axis 0 (the leading axis).

  • A single snapshot State (e.g. pos.shape == (N, dim)) cannot be unstacked with this method, because axis 0 would refer to particles, not snapshots.
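
Assume State behaves as a JAX pytree whose array leaves all carry the stacked axis first. Under that assumption, stack and unstack reduce to a leaf-wise jnp.stack and a leaf-wise index. The sketch below uses a hypothetical MiniState NamedTuple in place of jaxdem.State and is not the library's implementation:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class MiniState(NamedTuple):
    """Hypothetical two-field stand-in for jaxdem.State (NamedTuples are pytrees)."""
    pos: jax.Array
    rad: jax.Array

def stack(states):
    # Stack matching leaves of all snapshots along a new leading axis 0.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *states)

def unstack(stacked):
    # Read the leading-axis length from any leaf, then slice every leaf at index i.
    n = jax.tree_util.tree_leaves(stacked)[0].shape[0]
    return [jax.tree_util.tree_map(lambda x: x[i], stacked) for i in range(n)]

snapshots = [MiniState(pos=jnp.full((2, 2), float(t)), rad=jnp.ones(2)) for t in range(3)]
trajectory = stack(snapshots)
print(trajectory.pos.shape)  # (3, 2, 2)
restored = unstack(trajectory)
print(len(restored))  # 3
```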

static add_clump(state: State, pos: ArrayLike, *, pos_p: ArrayLike | None = None, vel: ArrayLike | None = None, force: ArrayLike | None = None, q: Quaternion | None | ArrayLike = None, angVel: ArrayLike | None = None, torque: ArrayLike | None = None, rad: ArrayLike | None = None, volume: ArrayLike | None = None, mass: ArrayLike | None = None, inertia: ArrayLike | None = None, deformable_ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None) → State

Adds a new clump consisting of multiple spheres to an existing State. Rigid-body properties (velocity, mass, material, etc.) are broadcast to all spheres in the new clump. The only per-sphere properties that vary within a rigid body are pos_c, pos_p, and rad.

TODO: broadcast the quaternion q to all spheres of the clump (not yet implemented).
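
You can picture the broadcasting of rigid-body properties with jnp.broadcast_to. The following sketch illustrates the idea only, for a hypothetical three-sphere 2D clump with made-up values. It does not reproduce how add_clump processes each argument:

```python
import jax.numpy as jnp

n_spheres = 3

# Per-sphere properties: they may differ between the spheres of one clump.
pos_p = jnp.array([[-1.0, 0.0], [0.0, 0.0], [1.0, 0.0]])  # body-frame offsets
rad = jnp.array([0.5, 0.7, 0.5])

# Rigid-body properties: given once for the whole clump...
vel = jnp.array([0.1, 0.0])
mass = jnp.array(3.0)

# ...and broadcast so that every sphere carries the same value.
vel_per_sphere = jnp.broadcast_to(vel, (n_spheres, vel.shape[-1]))
mass_per_sphere = jnp.broadcast_to(mass, (n_spheres,))
print(pos_p.shape, rad.shape, vel_per_sphere.shape, mass_per_sphere.shape)
# (3, 2) (3,) (3, 2) (3,)
```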

Parameters:
  • state (State) – The existing State to which particles will be added.

  • pos (jax.typing.ArrayLike) – Array of particle center of mass positions, equivalent to state.pos_c. Expected shape: (…, N, dim).

  • pos_p (jax.typing.ArrayLike or None, optional) – Offset of each sphere of the clump from the center of mass, expressed in the principal (body) reference frame, so that pos = pos_c + R(q) @ pos_p. This field should remain constant. Expected shape: (…, N, dim).

  • vel (jax.typing.ArrayLike or None, optional) – Velocities of the new particle(s). Defaults to zeros.

  • force (jax.typing.ArrayLike or None, optional) – Forces of the new particle(s). Defaults to zeros.

  • q (Quaternion or array-like, optional) – Initial orientations of the new particle(s). Defaults to identity quaternions.

  • angVel (jax.typing.ArrayLike or None, optional) – Angular velocities of the new particle(s). Defaults to zeros.

  • torque (jax.typing.ArrayLike or None, optional) – Torques of the new particle(s). Defaults to zeros.

  • rad (jax.typing.ArrayLike or None, optional) – Radii of the new particle(s). Defaults to ones.

  • volume (jax.typing.ArrayLike or None, optional) – Volume of the new particle(s) (or area in 2D). Defaults to hypersphere volumes of the radii.

  • mass (jax.typing.ArrayLike or None, optional) – Masses of the new particle(s). Defaults to ones.

  • inertia (jax.typing.ArrayLike or None, optional) – Moments of inertia of the new particle(s). Defaults to solid disks (2D) or spheres (3D).

  • deformable_ID (jax.typing.ArrayLike or None, optional) – Unique identifiers for deformable particles. If None, defaults to jnp.arange(). Expected shape: (…, N).

  • mat_id (jax.typing.ArrayLike or None, optional) – Material IDs of the new particle(s). Defaults to zeros.

  • species_id (jax.typing.ArrayLike or None, optional) – Species IDs of the new particle(s). Defaults to zeros.

  • fixed (jax.typing.ArrayLike or None, optional) – Fixed status of the new particle(s). Defaults to all False.

Returns:

A new State instance containing all particles from the original state plus the newly added particles.

Return type:

State

final class jaxdem.System(linear_integrator: LinearIntegrator | LinearMinimizer, rotation_integrator: RotationIntegrator | RotationMinimizer, collider: Collider, domain: Domain, force_manager: ForceManager, force_model: ForceModel, mat_table: MaterialTable, dt: jax.Array, time: jax.Array, dim: jax.Array, step_count: jax.Array, key: jax.Array)

Bases: object

Encapsulates the entire simulation configuration.

Notes

  • The System object is designed to be JIT-compiled for efficient execution.

  • The System dataclass is compatible with jax.jit(); its array fields (dt, time, dim, step_count, key) should remain JAX arrays for best performance.

Example

Creating a basic 2D simulation system:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # A 4-particle 2D state; only its shape is needed by System.create
>>> state = jdem.State.create(pos=jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]))
>>>
>>> # Create a System instance
>>> sim_system = jdem.System.create(
...     state_shape=state.shape,
...     dt=0.001,
...     linear_integrator_type="euler",
...     rotation_integrator_type="spiral",
...     collider_type="naive",
...     domain_type="free",
...     force_model_type="spring",
...     # You can pass keyword arguments to component constructors via '_kw' dicts
...     domain_kw=dict(box_size=jnp.array([5.0, 5.0]), anchor=jnp.array([0.0, 0.0])),
... )
>>>
>>> print(f"System integrator: {sim_system.linear_integrator.__class__.__name__}")
>>> print(f"System force model: {sim_system.force_model.__class__.__name__}")
>>> print(f"Domain box size: {sim_system.domain.box_size}")

linear_integrator: LinearIntegrator | LinearMinimizer

Instance of jaxdem.LinearIntegrator (or LinearMinimizer) that updates the translational degrees of freedom at each step.

rotation_integrator: RotationIntegrator | RotationMinimizer

Instance of jaxdem.RotationIntegrator (or RotationMinimizer) that updates the rotational degrees of freedom at each step.

collider: Collider

Instance of jaxdem.Collider that performs contact detection and computes inter-particle forces and potential energies.

domain: Domain

Instance of jaxdem.Domain that defines the simulation boundaries, displacement rules, and boundary conditions.

force_manager: ForceManager

Instance of jaxdem.ForceManager that handles per-particle forces, such as external forces, and resets forces.

force_model: ForceModel

Instance of jaxdem.ForceModel that defines the physical laws for inter-particle interactions.

mat_table: MaterialTable

Instance of jaxdem.MaterialTable holding material properties and pairwise interaction parameters.

dt: jax.Array

The global simulation time step \(\Delta t\).

time: jax.Array

Elapsed simulation time.

dim: jax.Array

Spatial dimension of the system.

step_count: jax.Array

Number of integration steps that have been performed.

key: jax.Array

PRNG key supporting stochastic functionality. Always update it with jax.random.split so that fresh random numbers are generated.
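
The split-before-use pattern with the standard jax.random API is shown below. This is a generic JAX sketch. It does not show how to write the updated key back into System.key, since this section does not cover how jaxdem updates System fields:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Split first: keep `key` as the carried state and consume `subkey` once.
key, subkey = jax.random.split(key)
noise_a = jax.random.normal(subkey, (4, 2))

# Reusing the old subkey would repeat noise_a; splitting again gives fresh numbers.
key, subkey = jax.random.split(key)
noise_b = jax.random.normal(subkey, (4, 2))
```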

static create(state_shape: Tuple[int, ...], *, dt: float = 0.005, time: float = 0.0, linear_integrator_type: str = 'verlet', rotation_integrator_type: str = 'verletspiral', collider_type: str = 'naive', domain_type: str = 'free', force_model_type: str = 'spring', force_manager_kw: Dict[str, Any] | None = None, mat_table: MaterialTable | None = None, linear_integrator_kw: Dict[str, Any] | None = None, rotation_integrator_kw: Dict[str, Any] | None = None, collider_kw: Dict[str, Any] | None = None, domain_kw: Dict[str, Any] | None = None, force_model_kw: Dict[str, Any] | None = None, seed: int = 0, key: Array | None = None) → System

Factory method to create a System instance with specified components.

Parameters:
  • state_shape (Tuple) – Shape of the state tensors handled by the simulation. The penultimate dimension corresponds to the number of particles N and the last dimension corresponds to the spatial dimension dim.

  • dt (float, optional) – The global simulation time step. Defaults to 0.005.

  • time (float, optional) – Initial simulation time. Defaults to 0.0.

  • linear_integrator_type (str, optional) – The registered type string for the jaxdem.integrators.LinearIntegrator used to evolve translational degrees of freedom.

  • rotation_integrator_type (str, optional) – The registered type string for the jaxdem.integrators.RotationIntegrator used to evolve angular degrees of freedom.

  • collider_type (str, optional) – The registered type string for the jaxdem.Collider to use.

  • domain_type (str, optional) – The registered type string for the jaxdem.Domain to use.

  • force_model_type (str, optional) – The registered type string for the jaxdem.ForceModel to use.

  • force_manager_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of ForceManager.

  • mat_table (MaterialTable or None, optional) – An optional pre-configured jaxdem.MaterialTable. If None, a default jaxdem.MaterialTable is created with one generic elastic material and the “harmonic” jaxdem.MaterialMatchmaker.

  • linear_integrator_kw (Dict[str, Any] or None, optional) – Keyword arguments forwarded to the constructor of the selected LinearIntegrator type.

  • rotation_integrator_kw (Dict[str, Any] or None, optional) – Keyword arguments forwarded to the constructor of the selected RotationIntegrator type.

  • collider_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Collider type.

  • domain_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Domain type.

  • force_model_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected ForceModel type.

  • seed (int, optional) – Integer seed used for random number generation. Defaults to 0.

  • key (jax.Array, optional) – Key used for JAX random number generation. Defaults to None, in which case the key is derived from seed; when a key is provided, it takes precedence over seed.

Returns:

A fully configured System instance ready for simulation.

Return type:

System

Raises:
  • KeyError – If a specified *_type is not registered in its respective factory, or if the mat_table is missing properties required by the force_model.

  • TypeError – If constructor keyword arguments are invalid for any component.

  • AssertionError – If the domain_kw ‘box_size’ or ‘anchor’ shapes do not match the dim.

Example

Creating a 3D system with reflective boundaries and a custom dt:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> N = 8  # number of particles
>>> system_reflect = jdem.System.create(
...     state_shape=(N, 3),
...     dt=0.0005,
...     domain_type="reflect",
...     domain_kw=dict(box_size=jnp.array([20.0, 20.0, 20.0]), anchor=jnp.array([-10.0, -10.0, -10.0])),
...     force_model_type="spring",
... )
>>> print(f"System dt: {system_reflect.dt}")
>>> print(f"Domain type: {system_reflect.domain.__class__.__name__}")

Creating a system with a pre-defined MaterialTable:

>>> custom_mat_kw = dict(young=2.0e5, poisson=0.25)
>>> custom_material = jdem.Material.create("custom_mat", **custom_mat_kw)
>>> custom_mat_table = jdem.MaterialTable.from_materials(
...     [custom_material], matcher=jdem.MaterialMatchmaker.create("linear")
... )
>>>
>>> N = 4  # number of particles
>>> system_custom_mat = jdem.System.create(
...     state_shape=(N, 2),
...     mat_table=custom_mat_table,
...     force_model_type="spring",
... )

static trajectory_rollout(state: State, system: System, *, n: Optional[int] = None, stride: int = 1, strides: Optional[jax.Array] = None, save_fn: Callable[[State, System], Any] = <function _save_state_system>, unroll: int = 2) → Tuple[State, System, Any]

Roll the system forward while collecting saved outputs at each frame.

The rollout always stores one output per frame via save_fn(state, system). The output of save_fn must be a pytree. Frame spacing can be either:

  • constant (stride), or

  • variable (a strides jax.Array).

Parameters:
  • state (State) – Initial state.

  • system (System) – Initial system configuration.

  • n (int, optional) – Number of saved frames. Required when strides is None. Ignored when strides is provided.

  • stride (int, optional) – Constant number of integration steps between consecutive saves. Used only when strides is None. Defaults to 1.

  • strides (jax.Array, optional) – Integer 1D array of per-frame integration strides. When provided, this overrides stride, and n is inferred from len(strides).

  • save_fn (Callable[[State, System], Any], optional) – Function called after each saved frame. Its return pytree is stacked along axis 0 across frames. Defaults to returning (state, system).

  • unroll (int, optional) – Unroll factor passed to the outer jax.lax.scan. Defaults to 2.

Returns:

(final_state, final_system, trajectory_like) where trajectory_like is the stacked output of save_fn.

Return type:

Tuple[State, System, Any]

Raises:

ValueError – If n is missing while strides is None, or if strides is not 1D.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(state_shape=state.shape, dt=0.01)
>>>
>>> # Constant stride: n is required
>>> final_state, final_system, traj = jdem.System.trajectory_rollout(
...     state, system, n=10, stride=5
... )
>>>
>>> # Variable strides: n inferred from len(strides)
>>> deltas = jnp.array([1, 2, 4, 8])
>>> final_state, final_system, traj = jdem.System.trajectory_rollout(
...     state, system, strides=deltas
... )
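
You can express both the constant and the variable stride modes as an outer jax.lax.scan over frames with an inner loop over integration steps. The sketch below is a hypothetical illustration on a scalar toy system (explicit Euler for dx/dt = -x). It is not jaxdem's implementation, and rollout and toy_step are made-up names:

```python
import jax
import jax.numpy as jnp

def toy_step(x, dt):
    # Hypothetical integrator: one explicit Euler step of dx/dt = -x.
    return x - dt * x

def rollout(x0, strides, dt=0.1, unroll=2):
    def frame(x, stride):
        # Advance `stride` integration steps, then emit one saved frame.
        x = jax.lax.fori_loop(0, stride, lambda _, y: toy_step(y, dt), x)
        return x, x
    # scan returns (final carry, stacked per-frame outputs).
    return jax.lax.scan(frame, x0, strides, unroll=unroll)

final, frames = rollout(jnp.array(1.0), jnp.array([1, 2, 4]))
print(frames)  # approximately [0.9, 0.729, 0.4782969]
```

A constant stride corresponds to passing jnp.full((n,), stride) as strides. Because the inner loop bound is traced here, jax.lax.fori_loop lowers to a while loop. That supports forward evaluation but not reverse-mode differentiation.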

static step(state: State, system: System, *, n: int | jax.Array = 1) → Tuple[State, System]

Advance the simulation by n integration steps.

Parameters:
  • state (State) – Current state.

  • system (System) – Current system configuration.

  • n (int or jax.Array, optional) – Number of integration steps. May be a Python int or a scalar JAX array. Defaults to 1.

Returns:

(final_state, final_system) after n steps.

Return type:

Tuple[State, System]

Example

>>> # Advance by 10 steps
>>> state_after_10_steps, system_after_10_steps = jdem.System.step(state, system, n=10)

static stack(systems: Sequence[System]) → System

Stacks a sequence of System snapshots along a new leading axis (axis 0) to form a trajectory or batch.

This method is useful for collecting simulation snapshots over time into a single System object where the leading dimension represents time or when preparing a batched system.

Parameters:

systems (Sequence[System]) – A sequence (e.g., list, tuple) of System instances to be stacked.

Returns:

A new System instance where each array attribute gains an additional leading dimension representing the stacked trajectory. For example, if each input dt has shape (), the stacked dt has shape (T,).

Return type:

System

static unstack(system: System) → list[System]

Split a stacked/batched System along the leading axis into a Python list.

This is the inverse of System.stack():

  • If stacked = System.stack([sys0, sys1, …]), then System.unstack(stacked) returns [sys0, sys1, …].

Notes

  • The split is performed along axis 0 (the leading axis).

  • A single snapshot System cannot be unstacked with this method.

class jaxdem.VTKWriter(writers: list[str] = <factory>, directory: Path = PosixPath('frames'), binary: bool = True, clean: bool = True, save_every: int = 1, max_queue_size: int = 512, max_workers: int | None = None, _counter: int = 0, _writer_classes: list[type[VTKBaseWriter]] = <factory>, _manifest: dict[str, Any] = <factory>)

Bases: object

High-level front end for writing simulation data to VTK files.

This class orchestrates the conversion of JAX-based jaxdem.State and jaxdem.System pytrees into VTK files, handling batches, trajectories, and dispatch to registered jaxdem.VTKBaseWriter subclasses.

How leading axes are interpreted

Let particle positions have shape (..., N, dim), where N is the number of particles and dim is 2 or 3. Define L = state.pos.ndim - 2, i.e., the number of leading axes before (N, dim).

  • L == 0 — single snapshot

    The input is one frame. It is written directly into frames/batch_00000000/ (no batching, no trajectory).

  • trajectory=False (default)

    All leading axes are treated as batch axes (not time). If multiple batch axes are present, they are flattened into a single batch axis: (B, N, dim) with B = prod(shape[:L]). Each batch b is written as a single snapshot under its own subdirectory frames/batch_XXXXXXXX/. No trajectory is implied.

    • Example: (B, N, dim) → B separate directories with one frame each.

    • Example: (B1, B2, N, dim) → flatten to (B1*B2, N, dim) and treat as above.

  • trajectory=True

    The axis given by trajectory_axis is swapped to the front (axis 0) and interpreted as time T. Any remaining leading axes are batch axes. If more than one non-time leading axis exists, they are flattened into a single batch axis so the data becomes (T, B, N, dim) with B = prod(other leading axes).

    • If there is only time (L == 1): (T, N, dim). A single batch directory frames/batch_00000000/ contains a time series with T frames.

    • If there is time plus batching (L >= 2): (T, B, N, dim). Each batch b gets its own directory frames/batch_XXXXXXXX/ containing a time series (T frames) for that batch.

After these swaps/reshapes, dispatch is:

  • (N, dim) → single snapshot

  • (B, N, dim) → batches (no time)

  • (T, N, dim) → single batch with a trajectory

  • (T, B, N, dim) → per-batch trajectories

Concrete writers receive per-frame NumPy arrays; leaves in System are sliced/broadcast consistently with the current frame/batch.
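The leading-axis rules above can be sketched for a position array alone. normalize_layout and its kind labels are hypothetical names chosen for illustration, not part of jaxdem's API. The sketch only reshapes positions; the real writer also has to slice the other State and System leaves consistently.

```python
import numpy as np


def normalize_layout(pos: np.ndarray, trajectory: bool = False, trajectory_axis: int = 0):
    """Return (kind, array) for pos of shape (..., N, dim), following the rules above."""
    L = pos.ndim - 2  # number of leading axes before (N, dim)
    if L == 0:
        return "snapshot", pos
    if not trajectory:
        # Every leading axis is a batch axis: flatten them into one.
        return "batches", pos.reshape((-1,) + pos.shape[-2:])
    # Move the time axis to the front; the remaining leading axes are batch axes.
    pos = np.moveaxis(pos, trajectory_axis, 0)
    if L == 1:
        return "trajectory", pos
    # (T, B1, B2, ..., N, dim) -> (T, B, N, dim)
    return "batched_trajectories", pos.reshape((pos.shape[0], -1) + pos.shape[-2:])


kind, arr = normalize_layout(np.zeros((2, 3, 4, 10, 2)), trajectory=True, trajectory_axis=2)
print(kind, arr.shape)  # batched_trajectories (4, 6, 10, 2)
```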

writers: list[str]

A list of strings specifying which registered VTKBaseWriter subclasses should be used for writing. If None, all available VTKBaseWriter subclasses will be used.

directory: Path

The base directory where output VTK files will be saved. Subdirectories might be created within this path for batched outputs. Defaults to “frames”.

binary: bool

If True, VTK files will be written in binary format. If False, files will be written in ASCII format. Defaults to True.

clean: bool

If True, the directory will be completely emptied before any files are written. Defaults to True. This is useful for starting a fresh set of output frames.

save_every: int

How often to write; writes on every save_every-th call to save().

max_queue_size: int

The maximum number of scheduled writes allowed. 0 means unbounded.

max_workers: int | None

Maximum number of worker threads for the internal thread pool.

close() → None

Flush all pending tasks and shut down the internal thread pool. Safe to call multiple times.

block_until_ready() → None

Wait until all scheduled writer tasks complete.

This will wait for all pending futures, propagate exceptions (if any), and clear the pending set.
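The scheduling behavior described by save_every, max_queue_size, block_until_ready(), and close() can be sketched with the standard library's concurrent.futures. ThrottledWriter is a hypothetical class, not jaxdem's code. Two points are assumptions: that the first call writes and then every save_every-th call after it, and that a full queue is handled by simply draining all pending writes.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, List, Optional


class ThrottledWriter:
    """Hypothetical sketch of save_every / max_queue_size / block_until_ready()."""

    def __init__(self, write_fn: Callable[[int], None], save_every: int = 1,
                 max_queue_size: int = 512, max_workers: Optional[int] = None):
        self.write_fn = write_fn
        self.save_every = save_every
        self.max_queue_size = max_queue_size  # 0 means unbounded
        self._pool = ThreadPoolExecutor(max_workers=max_workers)
        self._pending: List[Future] = []
        self._counter = 0
        self._closed = False

    def save(self, frame: int) -> None:
        # Assumption: the first call writes, then every save_every-th call after it.
        write = self._counter % self.save_every == 0
        self._counter += 1
        if not write:
            return
        if self.max_queue_size and len(self._pending) >= self.max_queue_size:
            self.block_until_ready()  # simple back-pressure: drain before scheduling more
        self._pending.append(self._pool.submit(self.write_fn, frame))

    def block_until_ready(self) -> None:
        pending, self._pending = self._pending, []
        for fut in pending:
            fut.result()  # re-raises any exception raised in a worker thread

    def close(self) -> None:
        if self._closed:
            return  # safe to call multiple times
        self.block_until_ready()
        self._pool.shutdown(wait=True)
        self._closed = True


written = []  # list.append is atomic under the CPython GIL
w = ThrottledWriter(written.append, save_every=3)
for f in range(10):
    w.save(f)
w.close()
w.close()
print(sorted(written))  # [0, 3, 6, 9]
```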

save(state: State, system: System, *, trajectory: bool = False, trajectory_axis: int = 0, batch0: int = 0) → None

Schedule writing of a jaxdem.State / jaxdem.System pair to VTK files.

This public entry point interprets leading axes (batch vs. trajectory), performs any required axis swapping and flattening, and then writes the resulting frames using the registered writers. It also creates per-batch ParaView .pvd collections referencing the generated files.

Parameters:
  • state (State) – The simulation jaxdem.State object to be saved. Its position array must have shape (..., N, dim).

  • system (System) – The jaxdem.System object corresponding to state. Leading axes must be compatible (or broadcastable) with those of state.

  • trajectory (bool, optional) – If True, interpret trajectory_axis as time and write a trajectory; if False, interpret all leading axes as batch axes.

  • trajectory_axis (int, optional) – The axis in state/system to treat as the trajectory (time) axis when trajectory=True. This axis is swapped to the front prior to writing.

  • batch0 (int, optional) – Index from which batch numbering starts. Defaults to 0.

class jaxdem.VTKBaseWriter

Bases: Factory, ABC

Abstract base class for writers that output simulation data.

Concrete subclasses implement the write method to specify how a given snapshot (jaxdem.State, jaxdem.System pair) is converted into a specific file format.

Example

To define a custom VTK writer, inherit from VTKBaseWriter and implement its abstract methods:

>>> @VTKBaseWriter.register("my_custom_vtk_writer")
... @dataclass(slots=True)
... class MyCustomVTKWriter(VTKBaseWriter):
...     ...

abstractmethod classmethod write(state: State, system: System, filename: Path, binary: bool) → None

Write information from a simulation snapshot to a VTK PolyData file.

This abstract method is the core interface for all concrete VTK writers. Implementations should assume that all jax arrays are converted to numpy arrays before write is called.

Parameters:
  • state (State) – The simulation jaxdem.State snapshot to be written.

  • system (System) – The simulation jaxdem.System configuration.

  • filename (Path) – Target path where the VTK file should be saved. The caller guarantees that its parent directory exists.

  • binary (bool) – If True, the VTK file is written in binary mode; if False, it is written in ASCII (human-readable) mode.
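As an illustration of what a concrete write might produce, here is a standalone sketch that writes particles as a legacy-format ASCII VTK PolyData file. write_points_vtk is a hypothetical helper, not a registered jaxdem writer. It does not use the vtk Python package, ignores the binary flag, and takes plain NumPy arrays instead of a State/System pair.

```python
import tempfile
from pathlib import Path

import numpy as np


def write_points_vtk(pos: np.ndarray, rad: np.ndarray, filename: Path) -> None:
    """Write particles as a legacy ASCII VTK PolyData file (hypothetical helper)."""
    pos = np.asarray(pos, dtype=float)
    n, dim = pos.shape
    if dim == 2:  # VTK points are always 3D, so pad 2D data with z = 0
        pos = np.hstack([pos, np.zeros((n, 1))])
    lines = [
        "# vtk DataFile Version 3.0",
        "particles",
        "ASCII",
        "DATASET POLYDATA",
        f"POINTS {n} double",
        *(" ".join(f"{c:.17g}" for c in p) for p in pos),
        f"VERTICES {n} {2 * n}",  # one vertex cell per particle: "1 <index>"
        *(f"1 {i}" for i in range(n)),
        f"POINT_DATA {n}",
        "SCALARS radius double 1",
        "LOOKUP_TABLE default",
        *(f"{r:.17g}" for r in np.asarray(rad, dtype=float)),
    ]
    Path(filename).write_text("\n".join(lines) + "\n")


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "frame_00000000.vtk"
    write_points_vtk(np.array([[0.0, 0.0], [1.0, 0.0]]), np.array([0.5, 0.5]), out)
    text = out.read_text()
print(text.splitlines()[4])  # POINTS 2 double
```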

class jaxdem.CheckpointWriter(directory: Path | str = PosixPath('checkpoints'), max_to_keep: int | None = None, save_every: int = 1)

Bases: object

Thin wrapper around Orbax checkpoint saving.

directory: Path | str = PosixPath('checkpoints')

The base directory where checkpoints will be saved.

max_to_keep: int | None = None

Keep only the last max_to_keep checkpoints. If None, all checkpoints are kept.

save_every: int = 1

How often to write; writes on every save_every-th call to save().

checkpointer: CheckpointManager

Orbax checkpoint manager for saving the checkpoints.

save(state: State, system: System) → None

Save a checkpoint for the provided state/system at a given step.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The current system configuration.

block_until_ready() → None

Wait for the checkpointer to finish.

close() → None

Wait for the checkpointer to finish and close it.

class jaxdem.CheckpointLoader(directory: Path = PosixPath('checkpoints'))

Bases: object

Thin wrapper around Orbax checkpoint restoring for jaxdem.State and jaxdem.System.

directory: Path = PosixPath('checkpoints')

The base directory from which checkpoints are loaded.

checkpointer: CheckpointManager

Orbax checkpoint manager used to restore the checkpoints.

load(step: int | None = None) → Tuple[State, System]

Restore a checkpoint.

Parameters:

step (Optional[int]) –

  • If None, load the latest checkpoint.

  • Otherwise, load the specified step.

Returns:

A tuple containing the restored State and System.

Return type:

Tuple[State, System]

block_until_ready() → None

close() → None

class jaxdem.CheckpointModelWriter(directory: Path | str = PosixPath('checkpoints'), max_to_keep: int | None = None, save_every: int = 1, clean: bool = True)

Bases: object

Thin wrapper around Orbax checkpoint saving for jaxdem.rl.models.Model.

directory: Path | str = PosixPath('checkpoints')

The base directory where checkpoints will be saved.

max_to_keep: int | None = None

Keep only the last max_to_keep checkpoints. If None, all checkpoints are kept.

save_every: int = 1

How often to write; writes on every save_every-th call to save().

checkpointer: CheckpointManager

Orbax checkpoint manager for saving the checkpoints.

clean: bool = True

Whether to clean the directory.

save(model: Model, step: int) → None

Save model at the given step, storing its model_state and JSON metadata. Assumes model.metadata contains only JSON-serializable fields; a model_type entry is added to the metadata.

block_until_ready() → None

close() → None

class jaxdem.CheckpointModelLoader(directory: Path = PosixPath('checkpoints'))

Bases: object

Thin wrapper around Orbax checkpoint restoring for jaxdem.rl.models.Model.

directory: Path = PosixPath('checkpoints')

The base directory from which checkpoints are loaded.

checkpointer: CheckpointManager

Orbax checkpoint manager used to restore the checkpoints.

load(step: int | None = None) → Model

Load a model from a given step (or the latest if None).

latest_step() → int | None

block_until_ready() → None

close() → None

class jaxdem.Material(density: float)

Bases: Factory

Abstract base class for defining materials.

Concrete subclasses of Material should define scalar or vector fields (e.g., young, poisson, mu) that represent specific physical properties of a material. These fields are then collected and managed by the MaterialTable.

Notes

  • Each field defined in a concrete Material subclass will become a named property in the MaterialTable.props dictionary.

Example

To define a custom material, inherit from Material and register it:

>>> @Material.register("my_custom_material")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class MyCustomMaterial(Material):
...     ...

density: float

Mass density of the material.

class jaxdem.MaterialTable(props: Dict[str, Array], pair: Dict[str, Array], matcher: MaterialMatchmaker)

Bases: object

A container for material properties, organized as a Structure of Arrays (SoA), together with pre-computed effective pair properties.

This class centralizes material data, allowing efficient access to scalar properties for individual materials and pre-calculated effective properties for material-pair interactions.

Notes

  • Scalar properties can be accessed directly using dot notation (e.g., material_table.young).

  • Effective pair properties can also be accessed directly using dot notation (e.g., material_table.young_eff).

Example

Creating a MaterialTable from multiple material types:

>>> import jax.numpy as jnp
>>> import jaxdem as jdem
>>>
>>> # Define different material instances
>>> mat1 = jdem.Material.create("elastic", density=2500.0, young=1.0e4, poisson=0.3)
>>> mat2 = jdem.Material.create("elasticfrict", density=7800.0, young=2.0e4, poisson=0.4, mu=0.5)
>>>
>>> # Create a MaterialTable using a linear matcher
>>> matcher_instance = jdem.MaterialMatchmaker.create("linear")
>>> mat_table = jdem.MaterialTable.from_materials(
...     [mat1, mat2],
...     matcher=matcher_instance
... )

props: Dict[str, Array]

A dictionary mapping scalar material property names (e.g., “young”, “poisson”, “mu”) to JAX arrays. Each array has shape (M,), where M is the total number of distinct material types present in the table.

pair: Dict[str, Array]

A dictionary mapping effective pair property names (e.g., “young_eff”, “mu_eff”) to JAX arrays. Each array has shape (M, M), where entry (i, j) is the effective property for interactions between material types i and j.

matcher: MaterialMatchmaker

The jaxdem.MaterialMatchmaker instance that was used to compute the effective pair properties stored in the pair dictionary.

static from_materials(mats: Sequence[Material], *, matcher: MaterialMatchmaker | None = None, fill: float = 0.0) → MaterialTable

Constructs a MaterialTable from a sequence of Material instances.

Parameters:
  • mats (Sequence[Material]) – A sequence of concrete Material instances. Each instance represents a distinct material type in the simulation. The order in this sequence defines their material IDs (0 to len(mats)-1).

  • matcher (MaterialMatchmaker) – The jaxdem.MaterialMatchmaker instance to be used for computing effective pair properties (e.g., harmonic mean, arithmetic mean).

  • fill (float, optional) – A fill value used for material properties that are not defined in a specific Material subclass (e.g., if an Elastic material is provided when an ElasticFriction is expected, mu would be filled with this value). Defaults to 0.0.

Returns:

A new MaterialTable instance containing the scalar properties and pre-computed effective pair properties for all provided materials.

Return type:

MaterialTable

Raises:

TypeError – If mats is not a sequence of Material instances.
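The construction described above, with per-material arrays of shape (M,) and mixed pair arrays of shape (M, M), can be sketched with NumPy. The Elastic and ElasticFriction dataclasses, build_tables, and the "_eff" suffix applied to every property are hypothetical choices for illustration; which properties jaxdem actually mixes, and how, is not specified here.

```python
from dataclasses import dataclass, fields
from typing import Callable, Dict, Sequence

import numpy as np


@dataclass
class Elastic:  # hypothetical material type
    density: float
    young: float
    poisson: float


@dataclass
class ElasticFriction(Elastic):  # hypothetical material type with friction
    mu: float


def build_tables(mats: Sequence[object],
                 mix: Callable[[np.ndarray, np.ndarray], np.ndarray],
                 fill: float = 0.0):
    """Sketch: SoA arrays of shape (M,) and mixed pair arrays of shape (M, M)."""
    names = sorted({f.name for m in mats for f in fields(m)})
    # Material ID = position in `mats`; fields a material lacks are set to `fill`.
    props: Dict[str, np.ndarray] = {
        name: np.array([getattr(m, name, fill) for m in mats], dtype=float) for name in names
    }
    # Broadcasting (M, 1) against (1, M) gives pair[name][i, j] = mix(p[i], p[j]).
    pair = {f"{name}_eff": mix(p[:, None], p[None, :]) for name, p in props.items()}
    return props, pair


def arithmetic(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return 0.5 * (a + b)


props, pair = build_tables(
    [Elastic(2500.0, 1.0e4, 0.3), ElasticFriction(7800.0, 2.0e4, 0.4, 0.5)], arithmetic
)
print(props["mu"])  # mu is filled with 0.0 for the Elastic material
```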

class jaxdem.MaterialMatchmaker

Bases: Factory, ABC

Abstract base class for defining how to combine (mix) material properties.

Example

To define a custom matchmaker, inherit from MaterialMatchmaker and implement its abstract methods:

>>> @MaterialMatchmaker.register("myCustomMatchmaker")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class MyCustomMatchmaker(MaterialMatchmaker):
...     ...

abstractmethod static get_effective_property(prop1: Array, prop2: Array) → Array

Abstract method to compute the effective property value from two individual material properties.

Concrete implementations define the specific mixing rule.

Parameters:
  • prop1 (jax.Array) – The property value from the first material. Can be a scalar or an array.

  • prop2 (jax.Array) – The property value from the second material. Can be a scalar or an array.

Returns:

A JAX array representing the effective property, computed from prop1 and prop2 according to the matchmaker’s specific rule.

Return type:

jax.Array
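Two common mixing rules are sketched below with NumPy. Both are symmetric, so they yield a symmetric (M, M) pair table. Which rule the registered "linear" matchmaker uses is not stated in this section; reading it as the arithmetic mean is an assumption. The function names are hypothetical. The harmonic mean uses a safe denominator so that the unused branch of np.where never divides by zero; the same pattern is needed in JAX to keep gradients finite.

```python
import numpy as np


def arithmetic_mean(prop1, prop2):
    # Linear mixing: (p1 + p2) / 2
    return 0.5 * (np.asarray(prop1, dtype=float) + np.asarray(prop2, dtype=float))


def harmonic_mean(prop1, prop2):
    # Harmonic mixing: 2 p1 p2 / (p1 + p2), defined as 0 when p1 + p2 == 0.
    p1 = np.asarray(prop1, dtype=float)
    p2 = np.asarray(prop2, dtype=float)
    denom = p1 + p2
    safe = np.where(denom == 0.0, 1.0, denom)  # never divide by zero
    return np.where(denom == 0.0, 0.0, 2.0 * p1 * p2 / safe)


print(arithmetic_mean(1.0e4, 2.0e4))  # 15000.0
```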

class jaxdem.ForceModel(laws: Tuple[ForceModel, ...] = ())

Bases: Factory, ABC

Abstract base class for defining inter-particle force laws and their corresponding potential energies.

Concrete subclasses implement specific force and energy models, such as linear springs, Hertzian contacts, etc.

Notes

  • The force() and energy() methods must correctly handle the case where i and j refer to the same particle (i.e., i == j), because self-interaction calls may occur.

Example

To define a custom force model, inherit from ForceModel and implement its abstract methods:

>>> @ForceModel.register("myCustomForce")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class MyCustomForce(ForceModel):
...     ...

laws: Tuple[ForceModel, ...]

A static tuple of other ForceModel instances that compose this force model.

This allows for creating composite force models (e.g., a total force being the sum of a spring force and a damping force).

abstractmethod static force(i: int, j: int, pos: jax.Array, state: State, system: System) → Tuple[jax.Array, jax.Array]

Compute the force and torque vectors acting on particle \(i\) due to particle \(j\).

Parameters:
  • i (int) – Index of the first particle (on which the interaction acts).

  • j (int) – Index of the second particle (which is exerting the interaction).

  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

A tuple (force, torque) where force has shape (dim,) and torque has shape (1,) in 2D or (3,) in 3D.

Return type:

Tuple[jax.Array, jax.Array]

abstractmethod static energy(i: int, j: int, pos: jax.Array, state: State, system: System) → jax.Array

Compute the potential energy of the interaction between particle \(i\) and particle \(j\).

Parameters:
  • i (int) – Index of the first particle.

  • j (int) – Index of the second particle.

  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

Scalar JAX array representing the potential energy of the interaction between particles \(i\) and \(j\).

Return type:

jax.Array

property required_material_properties: Tuple[str, ...]

A static tuple of strings specifying the material properties required by this force model.

These properties (e.g., ‘young_eff’, ‘restitution’, …) must be present in the System.mat_table for the model to function correctly. This is used for validation.
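A linear repulsive spring is a simple example of a consistent force/energy pair. This is a NumPy sketch with hypothetical names. It passes the radii rad and a stiffness k directly instead of reading them from state and system, and it uses a plain difference r_i - r_j with no periodic wrapping. Self-interaction (i == j) and coincident particles return zero, as the Notes above require. The loop at the end checks that the force equals the negative gradient of the energy.

```python
import numpy as np


def spring_force(i, j, pos, rad, k):
    """Linear repulsive spring on particle i due to j; returns (force, torque) in 2D."""
    rij = pos[i] - pos[j]  # displacement r_i - r_j
    dist = np.linalg.norm(rij)
    overlap = rad[i] + rad[j] - dist
    # Zero for self-interaction, coincident particles, and non-overlapping pairs.
    if i == j or dist == 0.0 or overlap <= 0.0:
        return np.zeros_like(rij), np.zeros(1)
    force = k * overlap * rij / dist  # pushes i away from j
    return force, np.zeros(1)         # central force: no torque, shape (1,) in 2D


def spring_energy(i, j, pos, rad, k):
    """U = k * overlap**2 / 2 for overlapping pairs, else 0."""
    if i == j:
        return 0.0
    overlap = rad[i] + rad[j] - np.linalg.norm(pos[i] - pos[j])
    return 0.5 * k * max(overlap, 0.0) ** 2


pos = np.array([[0.0, 0.0], [0.8, 0.1]])
rad = np.array([0.5, 0.5])
f, tau = spring_force(0, 1, pos, rad, k=10.0)

# Consistency check: F_i should equal -dU/dr_i (central finite differences).
h = 1e-6
numeric = np.zeros(2)
for a in range(2):
    dp = np.zeros_like(pos)
    dp[0, a] = h
    numeric[a] = -(spring_energy(0, 1, pos + dp, rad, 10.0)
                   - spring_energy(0, 1, pos - dp, rad, 10.0)) / (2 * h)
```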

class jaxdem.Integrator

Bases: Factory, ABC

Abstract base class for defining the interface for time-stepping.

Example

To define a custom integrator, inherit from Integrator and implement its abstract methods:

>>> @Integrator.register("myCustomIntegrator")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class MyCustomIntegrator(Integrator):
...     ...

static step_before_force(state: State, system: System) → Tuple[State, System]

Advance the simulation state before the force evaluation.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

A tuple containing the updated State and System after one time step of integration.

Return type:

Tuple[State, System]

Note

  • This method donates state and system

static step_after_force(state: State, system: System) → Tuple[State, System]

Advance the simulation state after the force computation by one time step.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

A tuple containing the updated State and System after one time step of integration.

Return type:

Tuple[State, System]

Note

  • This method donates state and system
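Velocity Verlet is one common scheme that splits naturally into a pre-force and a post-force half. The sketch below assumes that split and is not taken from jaxdem's integrators. The Toy state, the helper names, and the harmonic-trap force standing in for the collider are all hypothetical, and buffer donation is ignored. The final energy stays close to the initial one, as expected for a symplectic integrator.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class Toy:  # hypothetical minimal state
    pos: np.ndarray
    vel: np.ndarray
    force: np.ndarray
    mass: float
    dt: float


def step_before_force(s: Toy) -> Toy:
    # Half kick with the old force, then a full drift.
    vel = s.vel + 0.5 * s.dt * s.force / s.mass
    return replace(s, vel=vel, pos=s.pos + s.dt * vel)


def step_after_force(s: Toy) -> Toy:
    # Second half kick with the force evaluated at the new positions.
    return replace(s, vel=s.vel + 0.5 * s.dt * s.force / s.mass)


def compute_force(s: Toy, k: float = 1.0) -> Toy:
    # Harmonic trap F = -k x stands in for the collider.
    return replace(s, force=-k * s.pos)


s = compute_force(Toy(pos=np.array([1.0]), vel=np.array([0.0]), force=np.zeros(1), mass=1.0, dt=0.01))
energy0 = 0.5 * s.vel @ s.vel + 0.5 * s.pos @ s.pos
for _ in range(1000):
    s = step_before_force(s)
    s = compute_force(s)
    s = step_after_force(s)
energy = 0.5 * s.vel @ s.vel + 0.5 * s.pos @ s.pos
```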

static initialize(state: State, system: System) → Tuple[State, System]

Some integration methods, for example LeapFrog, require an initialization step. This method defines the interface for that initialization.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

A tuple containing the updated State and System after the initialization.

Return type:

Tuple[State, System]

Note

  • This method donates state and system

Example

>>> state, system = system.integrator.initialize(state, system)

class jaxdem.LinearIntegrator

Bases: Integrator

Namespace for translational (linear) integrators.

Purpose

Groups integrators that update linear state (e.g., position and velocity). Concrete methods (e.g., DirectEuler) should subclass this to register via the Factory and to signal they operate on linear kinematics.

class jaxdem.RotationIntegrator

Bases: Integrator

Namespace for rotational (angular) integrators.

Purpose

Groups integrators that update angular state (e.g., orientation, angular velocity). Concrete methods (e.g., DirectEulerRotation) should subclass this to register via the Factory and to signal they operate on rotational kinematics.

class jaxdem.Minimizer

Bases: Integrator, ABC

Abstract base class for energy minimizers.

Notes

  • Minimizer subclasses the generic Integrator interface, so it can be plugged in anywhere an Integrator is expected (e.g., as System.linear_integrator).

  • The default implementations of step_before_force, step_after_force, and initialize are inherited from Integrator and act as no-ops.

  • Concrete minimizers should typically override step_after_force to update the state based on the current forces in an energy-decreasing way.
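A gradient-descent update is the simplest energy-decreasing step_after_force. The sketch below writes that update as a plain NumPy function; gradient_descent_step and learning_rate are hypothetical names, and a toy quadratic energy replaces the collider. The loop is stable only for learning_rate < 2 / max(k).

```python
import numpy as np


def gradient_descent_step(pos: np.ndarray, force: np.ndarray, learning_rate: float) -> np.ndarray:
    # Move along the force, i.e. down the energy gradient (F = -dU/dx).
    return pos + learning_rate * force


# Toy quadratic energy U(x) = 0.5 * sum(k * x**2), so F = -k * x.
k = np.array([1.0, 4.0])
pos = np.array([1.0, -2.0])
for _ in range(200):
    force = -k * pos  # stands in for the collider's force evaluation
    pos = gradient_descent_step(pos, force, learning_rate=0.1)
energy = 0.5 * np.sum(k * pos**2)
```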

class jaxdem.LinearMinimizer

Bases: Minimizer

Namespace for translation/linear-state minimizers.

Concrete minimizers (e.g., GradientDescent) should subclass this to signal that they operate on linear kinematics.

class jaxdem.RotationMinimizer

Bases: Minimizer

Namespace for rotational-state minimizers.

Concrete minimizers that relax orientations / angular DOFs should subclass this.

class jaxdem.Collider

Bases: Factory, ABC

The base interface for defining how contact detection and force computations are performed in a simulation.

Concrete subclasses of Collider implement the specific algorithms for calculating the interactions.

Notes

Self-interaction (i.e., calling the force/energy computation for i=j) is allowed, and the underlying force_model is responsible for correctly handling or ignoring this case.

Example

To define a custom collider, inherit from Collider, register it and implement its abstract methods:

>>> @Collider.register("CustomCollider")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class CustomCollider(Collider):
...     ...

Then, instantiate it:

>>> jaxdem.Collider.create("CustomCollider", **custom_collider_kw)

static compute_force(state: State, system: System) → Tuple[State, System]

Abstract method to compute the total force acting on each particle in the simulation.

Implementations should calculate inter-particle forces and torques based on the current state and system configuration, then update the force and torque attributes of the state object with the resulting total force and torque for each particle.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object (with computed forces) and the System object.

Return type:

Tuple[State, System]

Note

  • This method donates state and system
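The simplest collider loops over all pairs. The sketch below is an O(N²) NumPy illustration with hypothetical names and a built-in linear spring law. It includes the self-pair and lets the pair law ignore it, matching the Notes above. A JAX version would typically vmap over j instead of looping. Summing the forces over all particles gives zero, which checks Newton's third law.

```python
import numpy as np


def pair_force(ri, rj, radi, radj, k=10.0):
    # Linear repulsive spring; zero when not overlapping or when ri == rj (self pair).
    rij = ri - rj
    dist = np.linalg.norm(rij)
    overlap = radi + radj - dist
    if dist == 0.0 or overlap <= 0.0:
        return np.zeros_like(rij)
    return k * overlap * rij / dist


def compute_force_all_pairs(pos, rad):
    """O(N^2) sketch: total force on each particle; self-pairs are visited and ignored."""
    n = pos.shape[0]
    force = np.zeros_like(pos)
    for i in range(n):
        for j in range(n):
            force[i] += pair_force(pos[i], pos[j], rad[i], rad[j])
    return force


pos = np.array([[0.0, 0.0], [0.9, 0.0], [0.4, 0.7]])
rad = np.full(3, 0.5)
force = compute_force_all_pairs(pos, rad)
```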

static compute_potential_energy(state: State, system: System) → jax.Array

Abstract method to compute the potential energy of the system, resolved per particle.

Implementations should calculate, for each particle, the sum of all potential energy contributions present in the system, based on the current state and system configuration.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A JAX array of shape (N,) holding the potential energy of each particle.

Return type:

jax.Array

Example

>>> potential_energy = system.collider.compute_potential_energy(state, system)
>>> print(f"Summed potential energy: {float(potential_energy.sum()):.4f}")
>>> print(potential_energy.shape)  # (N,)

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → Tuple[State, System, jax.Array, jax.Array]

Build a neighbor list for the current collider.

This is primarily used by neighbor-list-based algorithms and diagnostics. Implementations should match the cell-list semantics:

  • Returns a neighbor list of shape (N, max_neighbors) padded with -1.

  • Neighbor indices must refer to the returned (possibly sorted) state.

  • Also returns an overflow boolean flag (True if any particle exceeded max_neighbors neighbors within the cutoff).
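The padded output format described above can be illustrated with a brute-force O(N²) search. This is a NumPy sketch, not a cell list. It assumes self-pairs are excluded, does not reorder particles, and ignores periodic boundaries; brute_force_neighbor_list is a hypothetical name.

```python
import numpy as np


def brute_force_neighbor_list(pos: np.ndarray, cutoff: float, max_neighbors: int):
    """Return (neighbors, overflow); neighbors has shape (N, max_neighbors), padded with -1."""
    n = pos.shape[0]
    d = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)  # (N, N) distances
    within = (d < cutoff) & ~np.eye(n, dtype=bool)                   # exclude self-pairs
    neighbors = np.full((n, max_neighbors), -1, dtype=int)
    for i in range(n):
        idx = np.flatnonzero(within[i])[:max_neighbors]  # keep at most max_neighbors
        neighbors[i, : idx.size] = idx
    overflow = bool((within.sum(axis=1) > max_neighbors).any())
    return neighbors, overflow


pos = np.array([[0.0, 0.0], [0.5, 0.0], [0.0, 0.5], [3.0, 3.0]])
nl, overflow = brute_force_neighbor_list(pos, cutoff=1.0, max_neighbors=2)
print(nl.tolist())  # [[1, 2], [0, 2], [0, 1], [-1, -1]]
```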

class jaxdem.Domain(box_size: Array, anchor: Array)

Bases: Factory, ABC

The base interface for defining the simulation domain and the effect of its boundary conditions.

The Domain class defines how:
  • Relative displacement vectors between particles are calculated.

  • Particles’ positions are “shifted” or constrained to remain within the defined simulation boundaries based on the boundary condition type.

Example

To define a custom domain, inherit from Domain and implement its abstract methods:

>>> @Domain.register("my_custom_domain")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class MyCustomDomain(Domain):
...     ...

box_size: Array

Length of the simulation domain along each dimension.

anchor: Array

Anchor position (minimum coordinate) of the simulation domain.

property periodic: bool

Whether the domain enforces periodic boundary conditions.

classmethod Create(dim: int, box_size: Array | None = None, anchor: Array | None = None) → Self

Default factory method for the Domain class.

This method constructs a new Domain instance with a box-shaped domain of the given dimensionality. If box_size or anchor are not provided, they are initialized to default values.

Parameters:
  • dim (int) – The dimensionality of the domain (e.g., 2, 3).

  • box_size (jax.Array, optional) – The size of the domain along each dimension. If not provided, defaults to an array of ones with shape (dim,).

  • anchor (jax.Array, optional) – The anchor (origin) of the domain. If not provided, defaults to an array of zeros with shape (dim,).

Returns:

A new instance of the Domain subclass with the specified or default configuration.

Return type:

Domain

Raises:

AssertionError – If box_size and anchor do not have the same shape.

static displacement(ri: jax.Array, rj: jax.Array, system: System) → jax.Array

Computes the displacement vector between two particles \(r_i\) and \(r_j\), considering the domain’s boundary conditions.

Parameters:
  • ri (jax.Array) – Position vector of the first particle \(r_i\). Shape (dim,).

  • rj (jax.Array) – Position vector of the second particle \(r_j\). Shape (dim,).

  • system (System) – The configuration of the simulation, containing the domain instance.

Returns:

The displacement vector \(r_{ij} = r_i - r_j\), adjusted for boundary conditions. Shape (dim,).

Return type:

jax.Array
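For a periodic box, a common way to adjust the displacement is the minimum-image convention, r_ij = (r_i - r_j) - L · round((r_i - r_j) / L). The NumPy sketch below assumes that convention is what a periodic domain applies; periodic_displacement is a hypothetical name, not the library's code. The anchor drops out because only differences are involved.

```python
import numpy as np


def periodic_displacement(ri, rj, box_size):
    """Minimum-image displacement r_ij = r_i - r_j in a periodic box (sketch)."""
    rij = np.asarray(ri, dtype=float) - np.asarray(rj, dtype=float)
    # Subtract whole box lengths so each component lies in [-L/2, L/2].
    return rij - box_size * np.round(rij / box_size)


box = np.array([10.0, 10.0])
rij = periodic_displacement(np.array([9.5, 1.0]), np.array([0.5, 2.0]), box)
print(rij)  # [-1. -1.]
```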

Example

>>> rij = system.domain.displacement(ri, rj, system)

static apply(state: State, system: System) → Tuple[State, System]

Applies boundary conditions during the simulation step.

This method updates the state based on the domain’s rules, ensuring particles handle interactions at boundaries appropriately (e.g., reflection).

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object adjusted by the boundary conditions and the System object.

Return type:

Tuple[State, System]

Note

  • Periodic boundary conditions do not require wrapping coordinates during time stepping,

but reflective boundaries require changing positions and velocities. To wrap positions for periodic boundaries so they are displayed correctly when saving, and other algorithms use the shift method. - This method donates state and system

Example

>>> state, system = system.domain.apply(state, system)

static shift(state: State, system: System) → Tuple[State, System]

This method updates the state based on the domain’s rules, ensuring particles remain within the simulation box or handle interactions at boundaries appropriately (e.g., reflection, wrapping).

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object adjusted by the boundary conditions and the System object.

Return type:

Tuple[State, System]
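For a periodic domain, shift would map positions back into the box. The NumPy sketch below assumes the half-open interval [anchor, anchor + box_size) along each dimension; wrap_positions is a hypothetical name. One floating-point caveat: np.mod of a tiny negative number can round up to exactly box_size, which production code may need to guard against.

```python
import numpy as np


def wrap_positions(pos, anchor, box_size):
    """Map positions into [anchor, anchor + box_size) along every dimension (sketch)."""
    return anchor + np.mod(pos - anchor, box_size)


anchor = np.array([-1.0, 0.0])
box = np.array([4.0, 2.0])
wrapped = wrap_positions(np.array([[3.5, -0.5], [0.0, 1.0]]), anchor, box)
print(wrapped)  # [[-0.5  1.5] [ 0.   1. ]]
```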

Example

>>> state, system = system.domain.shift(state, system)

class jaxdem.Factory

Bases: ABC

Base factory class for pluggable components. This abstract base class provides a mechanism for registering and creating subclasses based on a string key.

Notes

Each concrete subclass gets its own private registry. Keys are strings and not case sensitive.

Example

Use Factory as a base class for a specific component type (e.g., Foo):

>>> class Foo(Factory["Foo"], ABC):
...     ...

Register a concrete subclass of Foo:

>>> @Foo.register("bar")
... class Bar(Foo):
...     ...

To instantiate the subclass instance:

>>> Foo.create("bar", **bar_kw)

classmethod registry_name() → str

Returns the key under which this class is registered.

property type_name: str

Returns the key under which this instance’s class is registered.

classmethod register(key: str | None = None) → Callable[[Type[SubT]], Type[SubT]]

Registers a subclass with the factory’s registry.

This method returns a decorator that can be applied to a class to register it under a specific key.

Parameters:

key (str or None, optional) – The string key under which to register the subclass. If None, the lowercase name of the subclass itself will be used as the key.

Returns:

A decorator function that takes a class and registers it, returning the class unchanged.

Return type:

Callable[[Type[T]], Type[T]]

Raises:

ValueError – If the provided key (or the default class name) is already registered in the factory’s registry.

Example

Register a class named “MyComponent” under the key “mycomp”:

>>> @MyFactory.register("mycomp")
... class MyComponent(MyFactory):
...     ...

Register a class named “DefaultComponent” using its own name, lowercased (“defaultcomponent”), as the key:

>>> @MyFactory.register()
... class DefaultComponent(MyFactory):
...     ...

classmethod create(key: str, /, **kw: Any) → RootT

Creates and returns an instance of a registered subclass.

This method looks up the subclass associated with the given key in the factory’s registry and then calls its constructor with the provided arguments. If the subclass defines a Create method (capitalized), that method will be called instead of the constructor. This allows subclasses to validate or preprocess arguments before instantiation.

Parameters:
  • key (str) – The registration key of the subclass to be created.

  • **kw (Any) – Arbitrary keyword arguments to be passed directly to the constructor of the registered subclass.

Returns:

An instance of the registered subclass.

Return type:

T

Raises:
  • KeyError – If the provided key is not found in the factory’s registry.

  • TypeError – If the provided **kw arguments do not match the signature of the registered subclass’s constructor.
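The registration and creation mechanism described above can be sketched with the standard library only. The sketch covers a per-family registry, case-insensitive keys, a default key taken from the lowercased class name, duplicate-key errors, and an optional Create hook. MiniFactory is a hypothetical name; jaxdem's Factory is generic (Factory["Foo"]) and may differ in details.

```python
from abc import ABC
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, Optional, Type, TypeVar

T = TypeVar("T")


class MiniFactory(ABC):
    """Hypothetical registry base: one case-insensitive registry per component family."""

    _registry: ClassVar[Dict[str, type]] = {}

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if MiniFactory in cls.__bases__:
            cls._registry = {}  # a direct subclass (e.g. Foo) starts its own registry

    @classmethod
    def register(cls, key: Optional[str] = None) -> Callable[[Type[T]], Type[T]]:
        def decorator(sub: Type[T]) -> Type[T]:
            name = (key or sub.__name__).lower()  # keys are stored lowercase
            if name in cls._registry:
                raise ValueError(f"{name!r} is already registered")
            cls._registry[name] = sub
            return sub  # the class itself is returned unchanged
        return decorator

    @classmethod
    def create(cls, key: str, /, **kw: Any) -> Any:
        sub = cls._registry[key.lower()]  # raises KeyError for unknown keys
        ctor = getattr(sub, "Create", None)  # optional validating constructor
        return ctor(**kw) if callable(ctor) else sub(**kw)


class Foo(MiniFactory):
    pass


@Foo.register("bar")
@dataclass
class Bar(Foo):
    value: int


@Foo.register()  # registered as "baz"
@dataclass
class Baz(Foo):
    value: int

    @classmethod
    def Create(cls, value: int) -> "Baz":
        return cls(value=2 * value)  # preprocess arguments before construction


bar = Foo.create("BAR", value=42)
baz = Foo.create("baz", value=21)
print(bar)  # Bar(value=42)
```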

Example

Given Foo factory and Bar registered:

>>> bar_instance = Foo.create("bar", value=42)
>>> print(bar_instance)
Bar(value=42)

class jaxdem.ForceRouter(laws: Tuple[ForceModel, ...] = (), table: Tuple[Tuple[ForceModel, ...], ...] = ())

Bases: ForceModel

Static species-to-force lookup table.

table: Tuple[Tuple[ForceModel, ...], ...]

static from_dict(S: int, mapping: dict[Tuple[int, int], ForceModel]) → ForceRouter

static force(i: int, j: int, pos: jax.Array, state: State, system: System) → jax.Array

static energy(i: int, j: int, pos: jax.Array, state: State, system: System) → jax.Array

class jaxdem.LawCombiner(laws: Tuple[ForceModel, ...] = ())

Bases: ForceModel

Sum a tuple of elementary force laws.
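Summing elementary laws amounts to adding their (force, torque) outputs. The sketch below uses a simplified hypothetical signature (i, j, pos) instead of (i, j, pos, state, system), and its example laws are made up for illustration. Energies would be summed the same way. Because laws is a static tuple, a Python loop like this one would be unrolled at trace time under JAX.

```python
from typing import Callable, Sequence, Tuple

import numpy as np

Law = Callable[[int, int, np.ndarray], Tuple[np.ndarray, np.ndarray]]


def combine_laws(laws: Sequence[Law]) -> Law:
    """Return a law whose (force, torque) is the sum over the elementary laws."""
    def combined(i: int, j: int, pos: np.ndarray):
        dim = pos.shape[-1]
        force = np.zeros(dim)
        torque = np.zeros(1 if dim == 2 else 3)  # torque shape: (1,) in 2D, (3,) in 3D
        for law in laws:
            f, t = law(i, j, pos)
            force, torque = force + f, torque + t
        return force, torque
    return combined


def spring(i, j, pos):
    return -1.0 * (pos[i] - pos[j]), np.zeros(1)  # hypothetical attractive spring, k = 1


def constant_push(i, j, pos):
    return np.array([0.5, 0.0]), np.zeros(1)  # hypothetical constant force


total = combine_laws([spring, constant_push])
f, t = total(0, 1, np.array([[1.0, 0.0], [0.0, 0.0]]))
print(f)  # [-0.5  0. ]
```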

static force(i: int, j: int, pos: jax.Array, state: State, system: System) → Tuple[jax.Array, jax.Array]

static energy(i: int, j: int, pos: jax.Array, state: State, system: System) → jax.Array

class jaxdem.ForceManager(gravity: jax.Array, external_force: jax.Array, external_force_com: jax.Array, external_torque: jax.Array, is_com_force: Tuple[bool, ...] = (), force_functions: Tuple[ForceFunction, ...] = (), energy_functions: Tuple[EnergyFunction | None, ...] = ())

Bases: object

Manage custom force contributions external to the collider. It also accumulates forces in the state after collider application, accounting for rigid bodies.

gravity: jax.Array

Constant acceleration applied to all particles. Shape (dim,).

external_force: jax.Array

Accumulated external force applied to all particles (at particle position). This buffer is cleared when apply() is invoked.

external_force_com: jax.Array

Accumulated external force applied to Center of Mass (does not induce torque). This buffer is cleared when apply() is invoked.

external_torque: jax.Array

Accumulated external torque applied to all particles. This buffer is cleared when apply() is invoked.

is_com_force: Tuple[bool, ...]

Tuple of booleans, one per entry in force_functions (length n_forces). If True, the force is applied to the center of mass (no induced torque). If False, the force is applied to the constituent particle (induces torque via the lever arm).

force_functions: Tuple[ForceFunction, ...]#

Tuple of callables with signature (pos, state, system) returning per-particle force and torque arrays.

energy_functions: Tuple[EnergyFunction | None, ...]#

Tuple of callables (or None) with signature (pos, state, system) returning per-particle potential energy arrays. Corresponds to force_functions.

static create(state_shape: Tuple[int, ...], *, gravity: jax.Array | None = None, force_functions: Sequence[ForceFunction | Tuple[ForceFunction, bool] | Tuple[ForceFunction, EnergyFunction] | Tuple[ForceFunction, EnergyFunction, bool]] = ()) → ForceManager[source]#

Create a ForceManager for a state with the given shape.

Parameters:
  • state_shape – Shape of the state position array, typically (..., dim).

  • gravity – Optional initial gravitational acceleration. Defaults to zeros of shape (dim,).

  • force_functions

    Sequence of callables or tuples. Supported formats:

    • ForceFunc: Applied at particle, no potential energy.

    • (ForceFunc, bool): Boolean specifies if it is a COM force.

    • (ForceFunc, EnergyFunc): Includes potential energy function.

    • (ForceFunc, EnergyFunc, bool): Includes energy and COM specifier.

    Signature of ForceFunc: (pos, state, system) -> (Force, Torque)

    Signature of EnergyFunc: (pos, state, system) -> Energy

    Supported formats for force_functions items:

    • func -> (func, None, False)

    • (func,) -> (func, None, False)

    • (func, bool) -> (func, None, bool)

    • (func, energy) -> (func, energy, False)

    • (func, energy, bool) -> (func, energy, bool)

    • (func, None, bool) -> (func, None, bool)
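The normalization rules listed for force_functions can be expressed as a small function that expands every accepted item into a (force_fn, energy_fn, is_com) triple. This is a sketch written from the documented rules, not jaxdem's own code; normalize_force_item is a hypothetical name.

```python
from typing import Callable, Optional, Tuple

Triple = Tuple[Callable, Optional[Callable], bool]


def normalize_force_item(item) -> Triple:
    """Hypothetical helper: map one force_functions entry onto (force_fn, energy_fn, is_com)."""
    if callable(item):                   # func
        return item, None, False
    item = tuple(item)
    if len(item) == 1:                   # (func,)
        return item[0], None, False
    if len(item) == 2:
        func, second = item
        if isinstance(second, bool):     # (func, bool)
            return func, None, second
        return func, second, False       # (func, energy) or (func, None)
    if len(item) == 3:                   # (func, energy, bool) or (func, None, bool)
        func, energy, is_com = item
        return func, energy, bool(is_com)
    raise ValueError(f"unsupported force_functions entry with {len(item)} elements")
```

Checking for bool before treating the second element as an energy function is what disambiguates (func, bool) from (func, energy).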

static add_force(state: State, system: System, force: jax.Array, *, is_com: bool = False) → System[source]#

Accumulate an external force on all particles; it is applied on the next apply() call.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

  • force (jax.Array) – External force to be added to all particles in the current order.

  • is_com (bool, optional) – If True, force is applied to Center of Mass (no induced torque). If False (default), force is applied to Particle Position (induces torque).
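The accumulate-then-clear behaviour of the external_force and external_force_com buffers can be sketched with an immutable NamedTuple, since JAX code updates state functionally rather than in place. ExternalBuffers, add_force and clear below are hypothetical stand-ins, not the jaxdem classes; the (N, dim) buffer shape is an assumption.

```python
from typing import NamedTuple

import jax.numpy as jnp


class ExternalBuffers(NamedTuple):
    """Hypothetical stand-in for the ForceManager buffers (not the jaxdem class)."""
    external_force: jnp.ndarray      # (N, dim), applied at particle positions
    external_force_com: jnp.ndarray  # (N, dim), applied at the centre of mass


def add_force(buf: ExternalBuffers, force: jnp.ndarray, *, is_com: bool = False) -> ExternalBuffers:
    # Functional update: return a new buffer instead of mutating in place.
    if is_com:
        return buf._replace(external_force_com=buf.external_force_com + force)
    return buf._replace(external_force=buf.external_force + force)


def clear(buf: ExternalBuffers) -> ExternalBuffers:
    # Mirrors the documented behaviour that the buffers are cleared by apply().
    return ExternalBuffers(jnp.zeros_like(buf.external_force), jnp.zeros_like(buf.external_force_com))


N, dim = 3, 2
buf = ExternalBuffers(jnp.zeros((N, dim)), jnp.zeros((N, dim)))
buf = add_force(buf, jnp.ones((N, dim)))
buf = add_force(buf, jnp.ones((N, dim)))                    # accumulates: now 2.0 everywhere
buf = add_force(buf, jnp.full((N, dim), 5.0), is_com=True)  # goes to the COM buffer
```

Because is_com is a plain Python bool, the branch is resolved at trace time and does not need jax.lax.cond.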

static add_force_at(state: State, system: System, force: jax.Array, idx: jax.Array, *, is_com: bool = False) → System[source]#

Add an external force to particles with ID=idx.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

  • force (jax.Array) – External force to be added to particles with ID=idx.

  • idx (jax.Array) – ID of the particles affected by the external force.

  • is_com (bool, optional) – If True, force is applied to Center of Mass (no induced torque). If False (default), force is applied to Particle Position (induces torque).
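Adding a force only to selected particle IDs is naturally a scatter-add with JAX's .at[...].add(), which also accumulates correctly when an ID appears more than once. In this sketch, add_force_at is a hypothetical function, and it assumes that unique_ID is a permutation of 0..N-1 so that argsort maps each ID to its current row; how jaxdem resolves IDs internally is not shown here.

```python
import jax.numpy as jnp


def add_force_at(external_force, unique_ID, force, idx):
    """Hypothetical sketch: scatter-add `force` (K, dim) onto the rows holding IDs `idx` (K,)."""
    # Assumption: unique_ID is a permutation of 0..N-1, so argsort gives the row of each ID.
    row_of_id = jnp.argsort(unique_ID)
    rows = row_of_id[idx]
    # .at[].add accumulates, so repeated IDs in idx sum their contributions.
    return external_force.at[rows].add(force)


unique_ID = jnp.array([2, 0, 1])  # particle with ID 0 currently sits in row 1
buffer = jnp.zeros((3, 2))
buffer = add_force_at(buffer, unique_ID, jnp.array([[1.0, 0.0], [1.0, 0.0]]), jnp.array([0, 0]))
```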

static add_torque(state: State, system: System, torque: jax.Array) → System[source]#

Accumulate an external torque on all particles; it is applied on the next apply() call.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

  • torque (jax.Array) – External torque to be added to all particles in the current order.

static add_torque_at(state: State, system: System, torque: jax.Array, idx: jax.Array) → System[source]#

Add an external torque to particles with ID=idx.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

  • torque (jax.Array) – External torque to be added to particles with ID=idx.

  • idx (jax.Array) – ID of the particles affected by the external torque.

static apply(state: State, system: System) → Tuple[State, System][source]#

Accumulate the managed per-particle contributions on top of the collider/contact forces, then perform the final clump aggregation and broadcast the result back to the clump members.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

The updated state and system after one time step.

Return type:

Tuple[State, System]
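The "clump aggregation + broadcast" step can be pictured as a segment sum over clump IDs followed by a gather back to the members. This is a 3D-only sketch under stated assumptions, not jaxdem's implementation: aggregate_and_broadcast is hypothetical, com is assumed to hold, for every particle, the centre of mass of its clump, and clump_ID is assumed to be in [0, num_clumps).

```python
import jax
import jax.numpy as jnp


def aggregate_and_broadcast(force, torque, force_com, pos, com, clump_ID, num_clumps):
    """Hypothetical 3D sketch. force/torque/force_com/pos/com: (N, 3); clump_ID: (N,) ints."""
    # Forces applied at particle positions induce a torque about the clump centre of mass.
    total_torque = torque + jnp.cross(pos - com, force)
    # Forces applied at the centre of mass add force but no torque.
    total_force = force + force_com
    # Sum over the members of each clump ...
    F = jax.ops.segment_sum(total_force, clump_ID, num_segments=num_clumps)
    T = jax.ops.segment_sum(total_torque, clump_ID, num_segments=num_clumps)
    # ... and broadcast the clump totals back to every member.
    return F[clump_ID], T[clump_ID]


pos = jnp.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [5.0, 5.0, 5.0]])
com = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [5.0, 5.0, 5.0]])
force = jnp.array([[0.0, 1.0, 0.0], [0.0, -1.0, 0.0], [1.0, 0.0, 0.0]])  # a pure couple on clump 0
zeros = jnp.zeros((3, 3))
F, T = aggregate_and_broadcast(force, zeros, zeros, pos, com, jnp.array([0, 0, 1]), num_clumps=2)
```

In the usage above the two members of clump 0 receive equal and opposite forces, so the net clump force vanishes while both members share the resulting torque.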

static compute_potential_energy(state: State, system: System) → jax.Array[source]#

Compute the total potential energy of the system.

Notes

  • The energy of clump members is divided by the number of spheres in the clump.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A scalar JAX array representing the total potential energy of each particle.

Return type:

jax.Array
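The note about clump members can be illustrated as follows: if every member of a clump carries the same clump-level energy value, dividing by the clump size prevents that energy from being counted once per sphere. The sketch below is hypothetical (clump_weighted_energy is not a jaxdem function) and assumes the per-particle energies are already available and that clump_ID is in [0, num_clumps).

```python
import jax
import jax.numpy as jnp


def clump_weighted_energy(per_particle_energy, clump_ID, num_clumps):
    """Hypothetical sketch: divide each member's energy by its clump size, then sum."""
    counts = jax.ops.segment_sum(
        jnp.ones_like(per_particle_energy), clump_ID, num_segments=num_clumps
    )
    weighted = per_particle_energy / counts[clump_ID]
    return weighted, jnp.sum(weighted)


# Two spheres share clump 0 (each reports 3.0); clump 1 is a single sphere.
weighted, total = clump_weighted_energy(jnp.array([3.0, 3.0, 1.0]), jnp.array([0, 0, 1]), 2)
```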

class jaxdem.DeformableParticleContainer(elements: Array | None, edges: Array | None, element_adjacency: Array | None, element_adjacency_edges: Array | None, elements_ID: Array | None, edges_ID: Array | None, element_adjacency_ID: Array | None, num_bodies: int, initial_body_contents: Array | None, initial_element_measures: Array | None, initial_edge_lengths: Array | None, initial_bending: Array | None, inv_ref_shape: Array | None, inv_ref_tet_shape: Array | None, initial_tet_volumes: Array | None, em: Array | None, ec: Array | None, eb: Array | None, el: Array | None, gamma: Array | None, lame_lambda: Array | None, lame_mu: Array | None, use_tetrahedral_svk: bool)[source]#

Bases: object

Registry holding topology and reference configuration for deformable particles.

This container manages the mesh connectivity (elements, edges, etc.) and reference properties (initial measures, contents, lengths, angles) required to compute forces. It supports both 3D (volumetric bodies bounded by triangles) and 2D (planar bodies bounded by line segments).

Indices in elements, edges, etc. correspond to the ID of particles in State.

The general form of the potential energy of a deformable particle (body) K is:

\[E_K = E_{K,\text{measure}} + E_{K,\text{content}} + E_{K,\text{bending}} + E_{K,\text{edge}} + E_{K,\text{strain}}\]

Definitions per Dimension:

  • 3D: Measure ($\mathcal{M}$) is Face Area; Content ($\mathcal{C}$) is Volume; Elements are Triangles.

  • 2D: Measure ($\mathcal{M}$) is Segment Length; Content ($\mathcal{C}$) is Enclosed Area; Elements are Segments.

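Strain energy (StVK) on elements (triangles):

\[W = A_0 \cdot \left( \mu \, \mathrm{tr}(E^2) + \frac{\lambda}{2} (\mathrm{tr}\, E)^2 \right)\]

where A_0 is the reference element area, E is the Green strain of the element, and λ and μ are the Lamé parameters (lame_lambda, lame_mu).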
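A worked version of the StVK formula for a single triangle embedded in 3D, using the standard construction F = D · inv_ref_shape and E = ½(FᵀF − I). This is a sketch, not the library's code: it assumes inv_ref_shape is the inverse of the 2×2 reference edge matrix expressed in the element's local frame (consistent with the documented (M, 2, 2) shape), and stvk_triangle_energy is a hypothetical name.

```python
import jax.numpy as jnp


def stvk_triangle_energy(x, inv_ref_shape, area0, lam, mu):
    """x: (3, 3) current vertex positions (rows); inv_ref_shape: (2, 2); returns W."""
    D = jnp.stack([x[1] - x[0], x[2] - x[0]], axis=1)  # (3, 2) current edge matrix
    F = D @ inv_ref_shape                              # (3, 2) deformation gradient
    E = 0.5 * (F.T @ F - jnp.eye(2))                   # (2, 2) Green strain
    return area0 * (mu * jnp.trace(E @ E) + 0.5 * lam * jnp.trace(E) ** 2)


# Reference: right triangle with unit legs in the xy-plane, so the reference edge matrix is I.
x0 = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
inv_ref = jnp.eye(2)
W_rest = stvk_triangle_energy(x0, inv_ref, 0.5, 1.0, 1.0)
W_stretched = stvk_triangle_energy(2.0 * x0, inv_ref, 0.5, 1.0, 1.0)  # uniform 2x stretch
```

For the uniform 2× stretch, E = 1.5 I, so tr(E²) = 4.5 and (tr E)² = 9, giving W = 0.5 · (4.5 + 4.5) = 4.5. Applying jax.vmap over the element axis would evaluate all M elements at once.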

The mesh properties of all bodies are concatenated along axis 0; the *_ID arrays record which body each entry belongs to.

elements: Array | None#

Array of vertex indices forming the boundary elements. Shape: (M, 3) for 3D (Triangles) or (M, 2) for 2D (Segments). Indices refer to the particle unique ID corresponding to the State.pos array.

edges: Array | None#

Array of vertex indices forming the unique edges (wireframe). Shape: (E, 2). Each row contains the indices of the two vertices forming the edge. Note: In 2D, the set of edges often overlaps with the set of elements (segments).

element_adjacency: Array | None#

Array of element adjacency pairs (for bending/dihedral angles). Shape: (A, 2). Each row contains the indices of the two elements sharing a connection.

element_adjacency_edges: Array | None#

Array of vertex IDs forming the shared edge for each adjacency. Shape: (A, 2).

elements_ID: Array | None#

Array of body IDs for each boundary element. Shape: (M,). elements_ID[i] == k means element i belongs to body k.

edges_ID: Array | None#

Array of body IDs for each unique edge. Shape: (E,). edges_ID[e] == k means edge e belongs to body k.

element_adjacency_ID: Array | None#

Array of body IDs for each adjacency (bending hinge). Shape: (A,). element_adjacency_ID[a] == k means adjacency a belongs to body k.

num_bodies: int#

Total number of distinct deformable bodies in the container (K).

initial_body_contents: Array | None#

Array of reference (stress-free) bulk content for each body. Shape: (K,). Represents Volume in 3D or Area in 2D.

initial_element_measures: Array | None#

Array of reference (stress-free) measures for each element. Shape: (M,). Represents Area in 3D or Length in 2D.

initial_edge_lengths: Array | None#

Array of reference (stress-free) lengths for each unique edge. Shape: (E,).

initial_bending: Array | None#

Array of reference (stress-free) bending angles for each adjacency. Shape: (A,). Represents Dihedral Angle in 3D or Vertex Angle in 2D.

inv_ref_shape: Array | None#

Inverse of the reference shape matrix for each element. Shape: (M, 2, 2) for triangles, or (M, 1, 1) for segments. Used to compute the deformation gradient F or Green strain E.

inv_ref_tet_shape: Array | None#

Inverse of the reference shape matrix for tetrahedra formed by each boundary triangle and the corresponding body center. Shape: (M, 3, 3).

initial_tet_volumes: Array | None#

Reference volumes for tetrahedra formed by each boundary triangle and the corresponding body center. Shape: (M,).
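For the tetrahedral StVK option, each boundary triangle together with its body center forms a tetrahedron. The sketch below shows one plausible way to obtain the per-tetrahedron reference quantities: the inverse of the 3×3 edge matrix and the volume |det| / 6. The helper name tet_reference, the edge ordering, and taking the absolute value of the determinant are assumptions, not jaxdem's documented convention.

```python
import jax.numpy as jnp


def tet_reference(tri, center):
    """Hypothetical helper. tri: (3, 3) triangle vertices (rows); center: (3,) body center."""
    a, b, c = tri
    # Reference edge matrix with the three edges from vertex a as columns.
    Dm = jnp.stack([b - a, c - a, center - a], axis=1)  # (3, 3)
    volume = jnp.abs(jnp.linalg.det(Dm)) / 6.0           # tetrahedron volume
    return jnp.linalg.inv(Dm), volume


tri = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
inv_Dm, vol = tet_reference(tri, jnp.array([0.0, 0.0, 1.0]))
```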

em: Array | None#

Measure elasticity coefficient (Modulus) for each body. Shape: (K,). Controls Area stiffness in 3D and Length stiffness in 2D.

ec: Array | None#

Content elasticity coefficient (Modulus) for each body. Shape: (K,). Controls Volume stiffness in 3D and Area stiffness in 2D.

eb: Array | None#

Bending elasticity coefficient (Rigidity) for each body. Shape: (K,).

el: Array | None#

Edge length elasticity coefficient (Modulus) for each body. Shape: (K,).

gamma: Array | None#

Surface/Line tension coefficient for each body. Shape: (K,).

lame_lambda: Array | None#

First Lamé parameter for the StVK model. Shape: (K,).

lame_mu: Array | None#

Second Lamé parameter (Shear Modulus) for the StVK model. Shape: (K,).

use_tetrahedral_svk: bool#

If True, compute StVK strain energy on tetrahedra formed by each boundary triangle and the mesh center of its body (3D only). If False, use the existing shell-like element StVK model.

static create(vertices: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, elements: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, edges: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, element_adjacency: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, elements_ID: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, edges_ID: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, element_adjacency_ID: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_element_measures: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_body_contents: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_bending: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_edge_lengths: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, em: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, ec: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, eb: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, el: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, gamma: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, lame_lambda: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, lame_mu: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, use_tetrahedral_svk: bool = False) → DeformableParticleContainer[source]#

Factory method to create a new DeformableParticleContainer.

Any initial geometric property (areas, volumes, bending angles, or edge lengths) that is not passed explicitly is computed from the provided vertices.
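For the 2D case, the reference measures (segment lengths) and contents (enclosed areas) can be derived from the vertices with the segment lengths and the shoelace formula, summed per body. This sketch is not the jaxdem routine: initial_measures_2d is hypothetical, it assumes element indices are row indices into vertices, and it assumes each body's segments form a closed, consistently oriented loop.

```python
import jax
import jax.numpy as jnp


def initial_measures_2d(vertices, elements, elements_ID, num_bodies):
    """Hypothetical sketch. vertices: (V, 2); elements: (M, 2) vertex indices; elements_ID: (M,)."""
    a = vertices[elements[:, 0]]
    b = vertices[elements[:, 1]]
    lengths = jnp.linalg.norm(b - a, axis=1)        # reference segment lengths, (M,)
    cross = a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0]   # per-segment shoelace terms
    # Summing the shoelace terms per body gives twice the signed enclosed area.
    areas = 0.5 * jnp.abs(jax.ops.segment_sum(cross, elements_ID, num_segments=num_bodies))
    return lengths, areas


vertices = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0],   # body 0: unit square
                      [2.0, 0.0], [3.0, 0.0], [2.0, 1.0]])              # body 1: right triangle
elements = jnp.array([[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 4]])
elements_ID = jnp.array([0, 0, 0, 0, 1, 1, 1])
lengths, areas = initial_measures_2d(vertices, elements, elements_ID, num_bodies=2)
```

The shoelace sum is translation invariant for a closed loop, which is why body 1 gives the correct area even though it is not located at the origin.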

static merge(c1: DeformableParticleContainer, c2: DeformableParticleContainer) → DeformableParticleContainer[source]#

Merges two DeformableParticleContainer instances.

static add(container: DeformableParticleContainer, vertices: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, elements_ID: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, elements: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_element_measures: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, element_adjacency_ID: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, element_adjacency: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_bending: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, edges_ID: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, edges: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_edge_lengths: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_body_contents: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, em: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, ec: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, eb: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, el: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, gamma: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, lame_lambda: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, lame_mu: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, use_tetrahedral_svk: bool = False) → DeformableParticleContainer[source]#

Factory method to add bodies to a container.

static compute_potential_energy(pos: jax.Array, state: State, _system: System, container: DeformableParticleContainer) → Tuple[jax.Array, Dict[str, jax.Array]][source]#
static create_force_function(container: DeformableParticleContainer) → ForceFunction[source]#
static create_force_energy_functions(container: DeformableParticleContainer) → Tuple[ForceFunction, EnergyFunction][source]#

Modules

analysis

Post-processing / analysis utilities.

colliders

Collision-detection interfaces and implementations.

domains

Simulation domains and boundary-condition implementations.

factory

The factory defines and instantiates specific simulation components.

forces

Force-law interfaces and concrete implementations.

integrators

Time-integration interfaces and implementations.

material_matchmakers

Material mix rules and implementations.

materials

Interface for defining materials and the MaterialTable.

minimizers

Energy-minimizer interfaces and implementations.

rl

JaxDEM reinforcement learning (RL) module.

state

Defines the simulation State.

system

Defines the simulation configuration and the tooling for driving the simulation.

utils

Utility functions used to set up simulations and analyze the output.

writers

Interface for defining data writers.