The Simulation State
This example focuses on the jaxdem.state.State object,
a core component of JaxDEM that holds all information about the particles
in a simulation.
JaxDEM stores particle data using a Structure-of-Arrays (SoA) architecture, making it efficient for JAX’s vectorised and parallel computations. This layout also simplifies handling trajectories and batched simulations without complex code changes.
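As a rough illustration of the difference, the sketch below stores the same two particles both ways. It uses plain NumPy arrays and Python dicts, not JaxDEM's own classes. In the SoA layout each field is a single array with the particle index on axis 0, so a position update for every particle is one vectorised expression instead of a Python loop over particle records.

```python
import numpy as np

# Array-of-Structures (AoS): one record per particle.
aos = [
    {"pos": [0.0, 0.0], "vel": [1.0, 0.0]},
    {"pos": [2.0, 0.0], "vel": [0.0, 1.0]},
]

# Structure-of-Arrays (SoA): one array per field, particle index on axis 0.
soa = {
    "pos": np.array([p["pos"] for p in aos]),  # shape (N, dim)
    "vel": np.array([p["vel"] for p in aos]),  # shape (N, dim)
}

dt = 0.1

# AoS update: a Python loop over particles.
for p in aos:
    p["pos"] = [x + dt * v for x, v in zip(p["pos"], p["vel"])]

# SoA update: a single vectorised expression over all particles.
soa["pos"] = soa["pos"] + dt * soa["vel"]
```

The SoA expression is also what vectorising libraries such as JAX can map over extra leading axes (batches, time) without any change to the update itself.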
Let’s explore how to create, modify, and extend the simulation state effectively.
State Creation
We’ll start by creating a simple 2D state representing a single particle
located at the origin. By default, jaxdem.state.State.create()
initializes non-specified attributes (like velocity, radius, mass) with
sensible default values.
import jax
import jaxdem as jdem
import jax.numpy as jnp
state = jdem.State.create(pos=jnp.array([[0.0, 0.0]]))
print(f"Dimension of state: {state.dim}")
print(f"Initial position: {state.pos}")
Dimension of state: 2
Initial position: [[0. 0.]]
To create a 3D state, simply pass 3D coordinates. JaxDEM infers the dimension from the position data (only 2D and 3D are supported).
state = jdem.State.create(pos=jnp.array([[0.0, 0.0, 0.0]]))
print(f"Dimension of state: {state.dim}")
print(f"Initial position: {state.pos}")
Dimension of state: 3
Initial position: [[0. 0. 0.]]
Understanding Positions: pos, pos_c, and pos_p
An important detail: state.pos is not a stored field. It is a
computed property defined as pos = pos_c + R(q) @ pos_p, where
R(q) is the rotation given by the particle’s quaternion orientation.
The stored fields are:
- pos_c: the center-of-mass position of each particle (or clump).
- pos_p: the offset from the center of mass in the principal (body) frame. For simple spheres pos_p is zero, so pos == pos_c.
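The relation pos = pos_c + R(q) @ pos_p can be sketched in NumPy for 3D. The sketch assumes a unit quaternion stored as a scalar part w of shape (N, 1) and a vector part xyz of shape (N, 3), the layout visible in the printed State at the end of this page. The helpers quat_to_matrix and sphere_positions are hypothetical and written only for illustration; they are not part of JaxDEM.

```python
import numpy as np

def quat_to_matrix(w, xyz):
    """Rotation matrices of shape (N, 3, 3) from unit quaternions (w: (N, 1), xyz: (N, 3))."""
    w = w[:, 0]
    x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
    rows = [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]
    # Each row becomes an (N, 3) array; stacking rows on axis -2 gives (N, 3, 3).
    return np.stack([np.stack(r, axis=-1) for r in rows], axis=-2)

def sphere_positions(pos_c, w, xyz, pos_p):
    """pos = pos_c + R(q) @ pos_p, evaluated for every sphere at once."""
    return pos_c + np.einsum("nij,nj->ni", quat_to_matrix(w, xyz), pos_p)

# Two spheres of one clump: shared pos_c and q, different body-frame offsets pos_p.
pos_c = np.zeros((2, 3))
pos_p = np.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
s = np.sqrt(0.5)  # quaternion for a 90 degree rotation about z: (cos 45 deg, 0, 0, sin 45 deg)
w = np.full((2, 1), s)
xyz = np.tile([0.0, 0.0, s], (2, 1))
print(sphere_positions(pos_c, w, xyz, pos_p).round(6))
```

Rotating the clump moves the sphere at body offset (1, 0, 0) to (0, 1, 0) and its partner to (0, -1, 0), while pos_c and pos_p themselves stay unchanged.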
For clumps (rigid bodies made of multiple spheres), every sphere in the
same clump shares the same center-of-mass position pos_c, orientation q,
velocity vel, angular velocity ang_vel, force, torque, mass,
volume, inertia, fixed, and clump_id.
The per-sphere fields that can vary within a rigid clump are:
- pos_p: the body-frame offset relative to the center of mass
- rad: the individual sphere radius
- the individual ID fields: mat_id, species_id, and bond_id
This deliberate design allows vectorized operations over all spheres without branching on clump membership.
Modifying State Attributes
We have two primary ways to set or modify particle attributes:
Direct assignment: You can assign new JAX arrays to attributes like state.vel. This is flexible but requires you to ensure shape consistency.
state.vel = jnp.ones_like(state.pos)
print(state.vel)
[[1. 1. 1.]]
Note that JAX arrays are immutable, so in-place item assignment such as
state.vel[i] = jnp.asarray([1, 2, 3], dtype=float)
raises an error. Use the functional .at[...].set(...) update instead:
i = 0
state.vel = state.vel.at[i].set(jnp.asarray([1, 2, 3], dtype=float))
print(state.vel)
[[1. 2. 3.]]
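However, updating particles one at a time like this (for example, in a Python loop) is inefficient, since each .at[].set call returns a new array. Prefer vectorised operations that update all particles at once.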
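To make the recommendation concrete, here is a small JAX sketch on plain arrays (no JaxDEM objects) contrasting a per-particle loop of .at[].set calls with two vectorised alternatives: a single scatter with an index array, and a boolean mask with jnp.where. All three produce the same array; the vectorised forms issue one operation no matter how many particles are updated.

```python
import jax.numpy as jnp

vel = jnp.zeros((5, 3))
new_v = jnp.array([1.0, 2.0, 3.0])
idx = jnp.array([1, 3])  # particles to update

# Discouraged: one .at[].set per particle; each call returns a fresh array.
slow = vel
for i in [1, 3]:
    slow = slow.at[i].set(new_v)

# Vectorised scatter: one .at[].set with an index array (new_v broadcasts to (2, 3)).
fast = vel.at[idx].set(new_v)

# Vectorised masking: take new_v wherever the per-particle mask is True.
mask = jnp.zeros(5, dtype=bool).at[idx].set(True)
masked = jnp.where(mask[:, None], new_v, vel)
```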
Constructor arguments: This is generally the safer approach, as the
jaxdem.state.State.create() constructor automatically validates shapes and types, ensuring consistency across all attributes.
state = jdem.State.create(pos=jnp.zeros((1, 2)), vel=jnp.ones((1, 2)))
print(state.vel)
[[1. 1.]]
Fixed (Immobile) Particles
The boolean field state.fixed marks particles that should not move.
The integrator multiplies velocity updates by (1 - fixed), so
fixed particles keep zero velocity regardless of the forces acting on
them. This is useful for walls, obstacles, or boundary particles.
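A minimal sketch of the masking idea, using a hypothetical semi-implicit Euler step written in NumPy (this is not JaxDEM's integrator): the velocity update is multiplied by (1 - fixed), so a fixed particle that starts at rest stays put while a free particle under the same force accelerates.

```python
import numpy as np

def masked_euler_step(pos, vel, force, mass, fixed, dt):
    """Hypothetical semi-implicit Euler step with velocity updates scaled by (1 - fixed)."""
    free = 1.0 - fixed.astype(pos.dtype)  # 0.0 for fixed particles, 1.0 for free ones
    vel = vel + dt * free[:, None] * force / mass[:, None]
    pos = pos + dt * vel
    return pos, vel

pos = np.array([[0.0, 0.0], [2.0, 0.0]])
vel = np.zeros_like(pos)
force = np.array([[5.0, 0.0], [5.0, 0.0]])  # the same push on both particles
mass = np.array([1.0, 1.0])
fixed = np.array([True, False])

for _ in range(10):
    pos, vel = masked_euler_step(pos, vel, force, mass, fixed, dt=0.01)
```

Because the mask is just an array, the same step applies to every particle without any branching on which ones are fixed.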
state = jdem.State.create(
pos=jnp.array([[0.0, 0.0], [2.0, 0.0]]),
rad=jnp.array([1.0, 1.0]),
fixed=jnp.array([True, False]),
)
print("Fixed mask:", state.fixed)
Fixed mask: [ True False]
Identifier Fields
Each particle carries several integer identifiers:
- clump_id: groups particles into rigid bodies (see Clumps (Rigid Bodies)). Particles with the same clump_id never interact via contact forces and move as one body. By default every particle has a unique clump_id.
- bond_id: connectivity masking array (see Deformable Particles). For each particle, it stores the unique IDs (unique_id) of the neighbor particles it is connected to. Interactions between connected particles are disabled (masked out). It has shape (N, max_num_neighbors) and is padded with -1.
- mat_id: indexes into the MaterialTable to look up material properties (density, Young’s modulus, …).
- species_id: selects which force law applies to a pair when using a ForceRouter (see Force Models).
- unique_id: a per-particle unique identifier (never repeated).
print("clump_id :", state.clump_id)
print("bond_id :", state.bond_id)
print("mat_id :", state.mat_id)
print("species_id:", state.species_id)
clump_id : [0 1]
bond_id : [[-1]
[-1]]
mat_id : [0 0]
species_id: [0 0]
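Index fields such as mat_id and species_id are consumed by gathering from lookup tables. The sketch below uses plain NumPy arrays as a hypothetical stand-in for the MaterialTable and for a species-pair lookup; the table contents, the density values, and the law_table layout are illustrative assumptions, not JaxDEM's data structures.

```python
import numpy as np

# Hypothetical per-material table, indexed by mat_id (illustrative densities).
density = np.array([2500.0, 7800.0])
mat_id = np.array([0, 1, 1, 0])        # per-particle material index
rad = np.array([1.0, 0.5, 0.5, 1.0])

# Gather per-particle properties with a single vectorised indexing operation.
rho = density[mat_id]
mass = rho * (4.0 / 3.0) * np.pi * rad**3

# Hypothetical species-pair table: entry [a, b] is the force-law index for species a vs b.
species_id = np.array([0, 0, 1, 1])
law_table = np.array([[0, 2], [2, 1]])
i = np.array([0, 2, 0])                # first particle of each candidate pair
j = np.array([1, 3, 2])                # second particle of each candidate pair
law = law_table[species_id[i], species_id[j]]
```

The pairs (0, 1), (2, 3), and (0, 2) pick laws 0, 1, and 2: same species 0, same species 1, and a mixed pair.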
Setting Up Connections with bond_id
For each particle, the bond_id array stores the unique IDs (unique_id) of other particles
it is connected to.
By default, these connections are used to ignore contact/non-bonded interactions in the collider
(e.g. to prevent connected particles from colliding with each other). This behavior is controlled
by the interact_same_bond_id parameter in the system creation options (see jaxdem.system.System.create()).
Setting interact_same_bond_id=True allows particles connected via bond_id to still experience contact forces.
This connectivity masking is particularly useful in combination with bonded models (where interactions are permanent), such as deformable particle models and cohesive/bonded networks. For more details on configuring bonded interactions, refer to Deformable Particles and Colliders.
When calling jaxdem.state.State.create(), you can define connections by passing
a list of lists (which can have uneven lengths) for the bond_id argument.
JaxDEM automatically symmetrizes these connections (i.e. if particle A connects to B, then
B connects to A) and pads the array with -1 up to the maximum number of connections.
If a particle has no connections, its corresponding row will only contain -1.
If no connections are provided at all, bond_id defaults to a shape of (N, 1) filled with -1.
# Create a state with 4 particles:
# - Particle 0 connects to 1 and 2
# - Particle 1 connects to 0
# - Particle 2 connects to 0
# - Particle 3 has no connections
positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
state_bonded = jdem.State.create(
pos=positions,
bond_id=[[1, 2], [0], [0], []]
)
print("Bond IDs for each particle:\n", state_bonded.bond_id)
# Default behavior when bond_id is not passed
state_no_bonds = jdem.State.create(pos=positions)
print("Default bond IDs (no bonds):\n", state_no_bonds.bond_id)
Bond IDs for each particle:
[[ 1 2]
[ 0 -1]
[ 0 -1]
[-1 -1]]
Default bond IDs (no bonds):
[[-1]
[-1]
[-1]
[-1]]
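The symmetrize-and-pad behaviour described above can be mimicked in plain Python and NumPy. The helper build_bond_id below is hypothetical: it assumes that unique_id equals the row index (as in this example) and sorts each neighbor list. It reproduces the printed bond_id arrays, but it is not JaxDEM's implementation.

```python
import numpy as np

def build_bond_id(neighbors):
    """Symmetrize ragged neighbor lists and pad them with -1 into an (N, width) array."""
    n = len(neighbors)
    sets = [set() for _ in range(n)]
    for i, nbrs in enumerate(neighbors):
        for j in nbrs:
            sets[i].add(j)  # i -> j
            sets[j].add(i)  # and j -> i (symmetrization)
    width = max(1, max(len(s) for s in sets))  # at least one column, as in the default (N, 1)
    out = np.full((n, width), -1, dtype=int)
    for i, s in enumerate(sets):
        row = sorted(s)
        out[i, :len(row)] = row
    return out

print(build_bond_id([[1, 2], [0], [0], []]))
print(build_bond_id([[], [], [], []]))
```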
Extending the State
Working directly with SoA structures can sometimes feel less intuitive
than Array-of-Structures (AoS) for adding and modifying individual particles. To simplify
this, JaxDEM provides utility methods like jaxdem.state.State.add().
jaxdem.state.State.add() allows you to append new particles to an
existing state, automatically assigning unique clump_ids and checking for dimension
consistency.
state = jdem.State.create(pos=jnp.array([[0.0, 0.0]]), rad=jnp.array([0.5]))
print(f"Initial state (N={state.N}, clump_ids={state.clump_id}):\npos={state.pos}")
state = jdem.State.add(
state,
pos=jnp.array([[1.0, 1.0]]),
vel=2 * jnp.ones((1, 2)),
rad=10 * jnp.ones(1),
)
print(
f"\nState after addition (N={state.N}, clump_ids={state.clump_id}):\npos={state.pos}"
)
print(f"New particle velocity: {state.vel[-1]}")
print(f"New particle radius: {state.rad[-1]}")
Initial state (N=1, clump_ids=[0]):
pos=[[0. 0.]]
State after addition (N=2, clump_ids=[0 1]):
pos=[[0. 0.]
[1. 1.]]
New particle velocity: [2. 2.]
New particle radius: 10.0
You can also add multiple particles at once by providing arrays of the
appropriate shape. jaxdem.state.State.add() will ensure the dimensions
of the new particles match the existing state.
state = jdem.State.add(
state,
pos=jnp.array([[2.0, 0.0], [0.0, 2.0]]),
vel=jnp.zeros((2, 2)),
rad=jnp.array([0.8, 0.3]),
clump_id=jnp.array([2, 3]),
)
print(
f"\nState after adding multiple particles (N={state.N}, clump_ids={state.clump_id}):\n{state.pos}"
)
State after adding multiple particles (N=4, clump_ids=[0 1 2 3]):
[[0. 0.]
[1. 1.]
[2. 0.]
[0. 2.]]
Note that we provided explicit clump_id values here.
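jaxdem.state.State.add() offsets the provided IDs relative to
jnp.max(state.clump_id) so that new clumps cannot overlap existing ones. Here
[2, 3] already lie above the existing maximum of 1 and are kept as they are. The resulting sequence is not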
guaranteed to be contiguous, but this is perfectly valid.
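In SoA form, appending particles amounts to concatenating every field array. The sketch below does this for a dictionary of NumPy arrays. The clump_id rule it uses (shift the new IDs so the smallest one lands at max(existing) + 1, keeping any gaps) is a hypothetical, collision-free choice that happens to reproduce the clump_ids printed above; it is not necessarily the exact rule State.add() implements.

```python
import numpy as np

def append_particles(fields, new_fields):
    """Concatenate SoA fields, shifting new clump_ids past the existing maximum.

    Hypothetical rule: the smallest new clump_id maps to max(existing) + 1 and
    gaps between the provided IDs are preserved.
    """
    new_clump = np.asarray(new_fields["clump_id"])
    shift = fields["clump_id"].max() + 1 - new_clump.min()
    new_fields = {**new_fields, "clump_id": new_clump + shift}
    return {k: np.concatenate([fields[k], np.asarray(new_fields[k])]) for k in fields}

state = {
    "pos": np.array([[0.0, 0.0], [1.0, 1.0]]),
    "rad": np.array([0.5, 10.0]),
    "clump_id": np.array([0, 1]),
}
state = append_particles(state, {
    "pos": np.array([[2.0, 0.0], [0.0, 2.0]]),
    "rad": np.array([0.8, 0.3]),
    "clump_id": np.array([2, 3]),
})
print(state["clump_id"])
```

Appending IDs with a gap, such as [0, 5], would keep the gap after shifting, which is one way a non-contiguous clump_id sequence can arise.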
Merging Two States
To combine two State objects, use
jaxdem.state.State.merge(). It concatenates the particles of
the second state onto the first, which is useful for assembling complex initial
configurations from smaller parts.
state_a = jdem.State.create(
pos=jnp.array([[0.0, 0.0], [1.0, 1.0]]),
)
state_b = jdem.State.create(
jnp.array([[2.0, 2.0], [3.0, 3.0], [5.0, 2.0]]),
)
state = jdem.State.merge(state_a, state_b)
print(f"State A (N={state_a.N}, clump_ids={state_a.clump_id}):\npos={state_a.pos}")
print(f"State B (N={state_b.N}, clump_ids={state_b.clump_id}):\npos={state_b.pos}")
print(f"Merged state (N={state.N}, clump_ids={state.clump_id}):\npos={state.pos}")
State A (N=2, clump_ids=[0 1]):
pos=[[0. 0.]
[1. 1.]]
State B (N=3, clump_ids=[2 3 4]):
pos=[[2. 2.]
[3. 3.]
[5. 2.]]
Merged state (N=5, clump_ids=[0 1 2 3 4]):
pos=[[0. 0.]
[1. 1.]
[2. 2.]
[3. 3.]
[5. 2.]]
Stacking States for Trajectories or Batches
One of the features that makes JaxDEM special is its ability to handle batched states. Batches can be interpreted as trajectories (multiple snapshots over time) or as independent simulations (multiple distinct initial conditions).
This matters for performance. JaxDEM is optimised for throughput: if your GPU is not saturated, you are leaving performance on the table. A common DEM task is running parameter sweeps, and JaxDEM lets you run many independent simulations in parallel. Until the GPU is fully utilised, adding more simulations to a batch costs little extra wall-clock time, so a whole sweep can potentially finish in the time it would take to run just one.
Furthermore, trajectory support means you don’t have to interrupt the GPU for I/O (e.g., saving state to disk). You can accumulate a full trajectory in memory and save everything at the end, which often gives much better performance at the cost of a bit more memory.
To manage simulation trajectories or perform batched simulations,
jaxdem.state.State.stack() is available. It takes a sequence of
jaxdem.state.State snapshots and stacks them along a new
leading axis. This creates a multi-dimensional state where the first axis
can represent time steps, batch elements, or other high-level groupings.
Note that stacking does not shift particle clump_ids, as it assumes the
particles are the same entities across the stacked dimension.
jaxdem.state.State.stack() also checks that all snapshots have consistent shapes.
snapshot1 = jdem.State.create(pos=jnp.array([[0.0, 0.0]]), rad=jnp.array([2.0]))
snapshot2 = jdem.State.create(pos=jnp.array([[0.1, 0.0]]), vel=jnp.array([[0.1, 0.0]]))
snapshot3 = jdem.State.create(pos=jnp.array([[0.2, 0.0]]), mass=jnp.array([3.3]))
batched_state = jdem.State.stack([snapshot1, snapshot2, snapshot3])
print(f"Shape of stacked positions (B, N, dim): {batched_state.pos.shape}")
print(f"Batch size: {batched_state.batch_size}")
Shape of stacked positions (B, N, dim): (3, 1, 2)
Batch size: 3
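Since a State can be returned from a jax.vmap-ed function, it is a JAX pytree, so stacking can be pictured as stacking every leaf array along a new axis 0. The sketch below does exactly that with jax.tree_util.tree_map on plain dictionaries standing in for states. It illustrates the idea only and is not a claim about how State.stack() is implemented.

```python
import jax
import jax.numpy as jnp

# Three single-particle "snapshots": dictionaries of SoA arrays standing in for State objects.
snapshots = [
    {"pos": jnp.array([[0.0, 0.0]]), "rad": jnp.array([2.0])},
    {"pos": jnp.array([[0.1, 0.0]]), "rad": jnp.array([1.0])},
    {"pos": jnp.array([[0.2, 0.0]]), "rad": jnp.array([1.0])},
]

# tree_map walks the snapshots in parallel and stacks matching leaves along a new axis 0.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *snapshots)
print(stacked["pos"].shape)  # (3, 1, 2)
print(stacked["rad"].shape)  # (3, 1)
```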
Another way of creating batched states is to use JAX's jax.vmap:
# Build 4 independent single-particle states; state i has its particle at (i, i).
batched_state = jax.vmap(lambda i: jdem.State.create(i * jnp.ones((1, 2))))(jnp.arange(4))
print(f"Shape of stacked positions (B, N, dim): {batched_state.pos.shape}")
print(f"Batch size: {batched_state.batch_size}")
print(f"Position at batch 0: {batched_state.pos[0]}")
print(f"Position at batch 1: {batched_state.pos[1]}")
print(f"Position at batch 2: {batched_state.pos[2]}")
Shape of stacked positions (B, N, dim): (4, 1, 2)
Batch size: 4
Position at batch 0: [[0. 0.]]
Position at batch 1: [[1. 1.]]
Position at batch 2: [[2. 2.]]
A more realistic way in which you could encounter a batched state is the following:
def initialize(i: jax.Array) -> tuple[jdem.State, jdem.System]:
state = jdem.State.create(i * jnp.ones((4, 2)))
system = jdem.System.create(state.shape)
return state, system
N_batches = 10
state, system = jax.vmap(initialize)(jnp.arange(N_batches))
Then, to run this simulation:
state, system = system.step(state, system, n=10)
print(f"Shape of positions (B, N, dim): {state.pos.shape}")
Shape of positions (B, N, dim): (10, 4, 2)
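Note that the System can change over time (System.step returns an updated system together with the state), so each state in the batch needs its own System. This is why initialize is vmapped to return both.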
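To see why each batch element carries its own system, the toy JAX sketch below batches a per-simulation parameter set together with the state and advances all simulations with one jax.vmap call, which is the parameter-sweep pattern described earlier. toy_step and its params dictionary (time step dt, gravity g) are hypothetical stand-ins for System.step and System, not JaxDEM APIs.

```python
import jax
import jax.numpy as jnp

def toy_step(pos, vel, params, n):
    """Hypothetical stand-in for System.step: n semi-implicit Euler steps under gravity."""
    def body(_, carry):
        pos, vel = carry
        vel = vel + params["dt"] * params["g"]  # every particle feels the same gravity
        pos = pos + params["dt"] * vel
        return pos, vel
    return jax.lax.fori_loop(0, n, body, (pos, vel))

B, N = 4, 3
pos = jnp.zeros((B, N, 2))
vel = jnp.zeros((B, N, 2))
# One parameter set per batch element: the "system" is batched along with the state.
params = {
    "dt": jnp.full((B,), 0.01),
    "g": jnp.array([[0.0, -1.0], [0.0, -2.0], [0.0, -5.0], [0.0, -9.81]]),
}

# vmap maps axis 0 of pos, vel, and every leaf of params; n is fixed by the closure.
run = jax.vmap(lambda p, v, prm: toy_step(p, v, prm, 100))
pos, vel = run(pos, vel, params)
print(pos.shape)  # (4, 3, 2)
```

Each simulation advances under its own gravity, yet the whole sweep is a single vectorised computation.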
Trajectories of Batches
JaxDEM’s state handling capabilities extend beyond just batches or single trajectories. We can also accumulate trajectories of batched states.
This is useful for scenarios like parameter sweeps, where you run multiple independent simulations (a batch) and want to capture their full time evolution (a trajectory) without frequent I/O. It allows highly efficient data collection.
jaxdem.writers.VTKWriter.save() understands these
multi-dimensional states.
By convention, when dealing with State attributes of shape (…, N, dim):
- For a single snapshot (no batch, no trajectory), the shape is (N, dim).
- For a batched state (a single snapshot across multiple independent simulations), the shape is (B, N, dim).
- For a single trajectory (multiple snapshots over time of a single simulation), the shape is (T, N, dim).
- For a trajectory of batches (multiple snapshots over time of multiple parallel simulations), the shape is (T, B, N, dim).
In JaxDEM, the batch dimension B (if present) is always located at shape[-3] (the axis just before the particle dimension N at shape[-2]). Hence, batch_size returns shape[-3] when ndim >= 3, which correctly yields B for both (B, N, dim) and (T, B, N, dim) shapes.
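The batch_size convention can be written as a one-line rule. The hypothetical helper below applies it to plain shape tuples; returning 1 for unbatched (N, dim) shapes is an assumption made here. The last call shows a consequence of the rule: for a single trajectory (T, N, dim), the time axis sits at shape[-3] and is reported as the batch size.

```python
def batch_size(shape):
    """Batch size under the 'B sits at shape[-3]' convention (hypothetical helper)."""
    # Unbatched (N, dim) shapes are assumed to report 1 in this sketch.
    return shape[-3] if len(shape) >= 3 else 1

print(batch_size((4, 1, 2)))     # (B, N, dim)
print(batch_size((3, 4, 1, 2)))  # (T, B, N, dim)
print(batch_size((10, 1, 2)))    # (T, N, dim): the time axis is reported as B
```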
When collecting trajectories (via trajectory_rollout()), each snapshot is
stacked along the first axis (axis 0), producing a state of shape
(T, B, N, dim) for batched trajectories.
jaxdem.writers.VTKWriter.save() understands these layouts. By
default (trajectory=False) all leading axes are treated as independent
batches. Pass trajectory=True to tell the writer which axis is time
(trajectory_axis, default 0); the writer swaps that axis to the front,
keeps it as T, and flattens any remaining leading axes into a single
batch axis B, yielding (T, B, N, dim) internally.
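The layout handling described for jaxdem.writers.VTKWriter.save() can be mimicked with NumPy's moveaxis and reshape. to_writer_layout is a hypothetical helper that only reproduces the described shape transformation; it does not write files and is not the writer's actual code.

```python
import numpy as np

def to_writer_layout(x, trajectory=False, trajectory_axis=0):
    """Reshape an (..., N, dim) array following the layout rules described above.

    trajectory=False: flatten all leading axes into one batch axis -> (B, N, dim).
    trajectory=True:  move `trajectory_axis` to the front as T and flatten the
                      remaining leading axes into B -> (T, B, N, dim).
    """
    tail = x.shape[-2:]  # (N, dim)
    if not trajectory:
        return x.reshape((-1,) + tail)
    x = np.moveaxis(x, trajectory_axis, 0)
    return x.reshape((x.shape[0], -1) + tail)

traj = np.zeros((10, 9, 4, 2))  # (T, B, N, dim), as produced by trajectory_rollout below
print(to_writer_layout(traj, trajectory=True).shape)
print(to_writer_layout(traj).shape)
print(to_writer_layout(np.zeros((2, 3, 5, 4, 2)), trajectory=True, trajectory_axis=1).shape)
```

In the last call, axis 1 (length 3) becomes T and the two remaining leading axes (2 and 5) are flattened into a batch of 10.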
batched_state = jdem.State.stack([batched_state, batched_state, batched_state])
print(f"Shape of stacked positions (T, B, N, dim): {batched_state.pos.shape}")
print(f"Batch size: {batched_state.batch_size}")
Shape of stacked positions (T, B, N, dim): (3, 4, 1, 2)
Batch size: 4
Following the example of the previous section, you might encounter a trajectory of batches in the following way:
N_batches = 9
state, system = jax.vmap(initialize)(jnp.arange(N_batches))
state, system, (state_traj, system_traj) = system.trajectory_rollout(
state, system, n=10
)
print(f"Shape of positions (T, B, N, dim): {state_traj.pos.shape}")
Shape of positions (T, B, N, dim): (10, 9, 4, 2)
Utilities
JaxDEM includes utility functions in jaxdem.utils for
quickly setting up simulations. For example, you can create a state
with randomised attributes:
from jaxdem import utils
state = utils.random_state(dim=3, N=10)
print(state)
State(pos_c=Array([[9.69947998, 3.1518186 , 6.6109639 ],
[9.91386812, 2.82521542, 3.55488773],
[5.25257116, 3.2260835 , 0.58677741],
[6.15111094, 9.54725682, 8.38766458],
[7.92837724, 1.78128245, 3.73091685],
[7.69204925, 8.30547677, 3.51584082],
[7.63565438, 9.26869417, 0.50223628],
[4.82317485, 2.7366294 , 7.91500943],
[3.35842191, 9.16958235, 2.76594463],
[1.19191249, 7.09289908, 7.46735741]], dtype=float64), pos_p=Array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float64), vel=Array([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]], dtype=float64), force=Array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float64), q=Quaternion(w=Array([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]], dtype=float64), xyz=Array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float64)), ang_vel=Array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float64), torque=Array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float64), rad=Array([10., 10., 10., 10., 10., 10., 10., 10., 10., 10.], dtype=float64), volume=Array([4188.79020479, 4188.79020479, 4188.79020479, 4188.79020479,
4188.79020479, 4188.79020479, 4188.79020479, 4188.79020479,
4188.79020479, 4188.79020479], dtype=float64), mass=Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float64), inertia=Array([[40., 40., 40.],
[40., 40., 40.],
[40., 40., 40.],
[40., 40., 40.],
[40., 40., 40.],
[40., 40., 40.],
[40., 40., 40.],
[40., 40., 40.],
[40., 40., 40.],
[40., 40., 40.]], dtype=float64), clump_id=Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64), bond_id=Array([[-1],
[-1],
[-1],
[-1],
[-1],
[-1],
[-1],
[-1],
[-1],
[-1]], dtype=int64), unique_id=Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64), mat_id=Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64), species_id=Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64), fixed=Array([False, False, False, False, False, False, False, False, False,
False], dtype=bool), _pos_p_rot=Array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float64))