jaxdem.state

Defines the simulation State.

Functions

get_real_pos(pos_c, pos_p, q)

Classes

State(pos_c, pos_p, vel, force, q, angVel, ...)

Represents the complete simulation state for a system of N particles in 2D or 3D.

jaxdem.state.get_real_pos(pos_c: Array, pos_p: Array, q: Quaternion) Array
final class jaxdem.state.State(pos_c: Array, pos_p: Array, vel: Array, force: Array, q: Quaternion, angVel: Array, torque: Array, rad: Array, volume: Array, mass: Array, inertia: Array, ID: Array, mat_id: Array, species_id: Array, fixed: Array)

Bases: object

Represents the complete simulation state for a system of N particles in 2D or 3D.

Notes

State is designed to support various data layouts:

  • Single snapshot:

    pos.shape = (N, dim) for particle properties (e.g., pos, vel, force), and (N,) for scalar properties (e.g., rad, mass). In this case, batch_size is 1.

  • Batched states:

    pos.shape = (B, N, dim) for particle properties, and (B, N) for scalar properties. Here, B is the batch dimension (batch_size = pos.shape[0]).

  • Trajectories of a single simulation:

    pos.shape = (T, N, dim) for particle properties, and (T, N) for scalar properties. Here, T is the trajectory dimension.

  • Trajectories of batched states:

    pos.shape = (B, T_1, T_2, …, T_k, N, dim) for particle properties, and (B, T_1, T_2, …, T_k, N) for scalar properties.

    • The first dimension (i.e., pos.shape[0]) is always interpreted as the batch dimension (`B`).

    • The remaining leading dimensions (T_1, T_2, …, T_k), which lie between the batch dimension and the particle dimension, are interpreted as trajectory dimensions. If there is more than one trajectory dimension, they are flattened into a single one at save time.
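The layouts above differ only in their leading axes. The following minimal sketch uses plain jax.numpy arrays rather than a State to illustrate the stated shape rules. The helpers batch_size and flatten_trajectory_axes are hypothetical, and the reshape used to merge several trajectory axes is an assumption about how flattening could be done, not the library's save code.

```python
import jax.numpy as jnp

N, dim = 4, 2
B, T1, T2 = 3, 5, 7

# Single snapshot: vectors are (N, dim), scalars are (N,)
pos_single = jnp.zeros((N, dim))
rad_single = jnp.ones((N,))

# Batched states: vectors are (B, N, dim), scalars are (B, N)
pos_batched = jnp.zeros((B, N, dim))

# Batched trajectories with two trajectory axes: (B, T_1, T_2, N, dim)
pos_traj = jnp.zeros((B, T1, T2, N, dim))


def batch_size(pos):
    # Hypothetical helper: a bare snapshot (N, dim) counts as batch size 1;
    # otherwise the first axis is the batch axis.
    return 1 if pos.ndim == 2 else pos.shape[0]


def flatten_trajectory_axes(pos):
    # Hypothetical helper: merge T_1, ..., T_k into one axis,
    # (B, T_1, ..., T_k, N, dim) -> (B, T_1 * ... * T_k, N, dim).
    b, n, d = pos.shape[0], pos.shape[-2], pos.shape[-1]
    return pos.reshape(b, -1, n, d)
```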

The class is final and cannot be subclassed.

Example

Creating a simple 2D state for 4 particles:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>> import jax
>>>
>>> positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
>>> state = jdem.State.create(pos=positions)
>>>
>>> print(f"Number of particles (N): {state.N}")
>>> print(f"Spatial dimension (dim): {state.dim}")
>>> print(f"Positions: {state.pos}")

Creating a batched state:

>>> batched_state = jax.vmap(lambda _: jdem.State.create(pos=positions))(jnp.arange(10))
>>>
>>> print(f"Batch size: {batched_state.batch_size}")  # 10
>>> print(f"Positions shape: {batched_state.pos.shape}")
pos_c: Array

Array of particle center of mass positions. Shape is (…, N, dim).

pos_p: Array

Vector relative to the center of mass (pos_p = pos - pos_c) in the principal reference frame. This data is constant. Shape is (…, N, dim).

vel: Array

Array of particle center of mass velocities. Shape is (…, N, dim).

force: Array

Array of particle forces. Shape is (…, N, dim).

q: Quaternion

Quaternion representing the orientation of the particle.

angVel: Array

Array of particle center of mass angular velocities. Shape is (…, N, 1 | 3) depending on 2D or 3D simulations.

torque: Array

Array of particle torques. Shape is (…, N, 1 | 3) depending on 2D or 3D simulations.

rad: Array

Array of particle radii. Shape is (…, N).

volume: Array

Array of particle volumes (or areas if 2D). Shape is (…, N).

mass: Array

Array of particle masses. Shape is (…, N).

inertia: Array

Moments of inertia in the principal axis frame. Shape is (…, N, 1 | 3) depending on 2D or 3D simulations.

ID: Array

Array of unique particle identifiers. Shape is (…, N).

mat_id: Array

Array of material IDs for each particle. Shape is (…, N).

species_id: Array

Array of species IDs for each particle. Shape is (…, N).

fixed: Array

Boolean array indicating if a particle is fixed (immobile). Shape is (…, N).

property N: int

Number of particles in the state.

property dim: int

Spatial dimension of the simulation.

property shape: Tuple[int, ...]

Shape of the state's per-particle vector arrays, (…, N, dim).

property batch_size: int

Return the batch size of the state.

property pos: Array

Returns the lab-frame position of each sphere in the state. pos_c is the center of mass and pos_p is the vector from the center of mass to the sphere, expressed in the principal reference frame, such that pos = pos_c + pos_p once pos_p has been rotated into the lab frame by the orientation q.
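The rotation of pos_p into the lab frame can be sketched with the standard unit-quaternion rotation v' = v + 2w(u × v) + 2u × (u × v), where q = (w, u). This is a minimal sketch, not jaxdem's get_real_pos: it assumes q is a plain array of shape (…, N, 4) ordered (w, x, y, z) as described for State.create, that the quaternions are normalized, and that a 2D offset is rotated about the z axis by padding it with a zero z component. The names rotate_vectors and real_pos_sketch are hypothetical.

```python
import jax.numpy as jnp


def rotate_vectors(q, v):
    """Rotate vectors v (..., 3) by unit quaternions q (..., 4) ordered (w, x, y, z)."""
    w = q[..., :1]               # scalar part, kept as (..., 1) for broadcasting
    u = q[..., 1:]               # vector part (..., 3)
    t = 2.0 * jnp.cross(u, v)    # 2 (u x v)
    return v + w * t + jnp.cross(u, t)


def real_pos_sketch(pos_c, pos_p, q):
    """Hypothetical lab-frame positions: pos_c + R(q) pos_p."""
    if pos_c.shape[-1] == 2:
        # Assumption: in 2D, embed the offset in the xy plane and drop z afterwards.
        pad = jnp.zeros(pos_p.shape[:-1] + (1,), dtype=pos_p.dtype)
        v = jnp.concatenate([pos_p, pad], axis=-1)
        return pos_c + rotate_vectors(q, v)[..., :2]
    return pos_c + rotate_vectors(q, pos_p)
```

With the identity quaternion (1, 0, 0, 0) the sketch reduces to pos = pos_c + pos_p, which matches the unrotated case.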

property is_valid: bool

Check if the internal representation of the State is consistent.

Verifies that:

  • The spatial dimension (dim) is either 2 or 3.

  • All position-like arrays (pos, vel, force) have the same shape.

  • All scalar-per-particle arrays (rad, mass, ID, mat_id, species_id, fixed) have a shape consistent with pos.shape[:-1].

Raises:

AssertionError – If any shape inconsistency is found.
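The checks listed above can be sketched as a standalone function on plain arrays. This hypothetical shapes_consistent returns a bool instead of raising, and it is an illustration of the listed rules rather than the library's is_valid implementation.

```python
import jax.numpy as jnp


def shapes_consistent(pos, vel, force, rad, mass, ID, mat_id, species_id, fixed):
    """Hypothetical re-statement of the is_valid shape rules."""
    # Rule 1: spatial dimension must be 2 or 3
    if pos.shape[-1] not in (2, 3):
        return False
    # Rule 2: position-like arrays share one shape (..., N, dim)
    if not (pos.shape == vel.shape == force.shape):
        return False
    # Rule 3: per-particle scalars have shape pos.shape[:-1], i.e. (..., N)
    lead = pos.shape[:-1]
    return all(a.shape == lead for a in (rad, mass, ID, mat_id, species_id, fixed))
```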

static create(pos: ArrayLike, *, vel: ArrayLike | None = None, force: ArrayLike | None = None, q: Quaternion | ArrayLike | None = None, angVel: ArrayLike | None = None, torque: ArrayLike | None = None, rad: ArrayLike | None = None, volume: ArrayLike | None = None, mass: ArrayLike | None = None, inertia: ArrayLike | None = None, ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None, mat_table: 'MaterialTable' | None = None) State

Factory method to create a new State instance.

This method handles default values and ensures consistent array shapes for all state attributes.

Parameters:
  • pos (jax.typing.ArrayLike) – Initial positions of particles. Expected shape: (…, N, dim).

  • vel (jax.typing.ArrayLike or None, optional) – Initial velocities of particles. If None, defaults to zeros. Expected shape: (…, N, dim).

  • force (jax.typing.ArrayLike or None, optional) – Initial forces on particles. If None, defaults to zeros. Expected shape: (…, N, dim).

  • q (Quaternion or array-like, optional) – Initial particle orientations. If None, defaults to identity quaternions. Accepted shapes: quaternion objects or arrays of shape (…, N, 4) with components ordered as (w, x, y, z).

  • angVel (jax.typing.ArrayLike or None, optional) – Initial angular velocities of particles. If None, defaults to zeros. Expected shape: (…, N, 1) in 2D or (…, N, 3) in 3D.

  • torque (jax.typing.ArrayLike or None, optional) – Initial torques on particles. If None, defaults to zeros. Expected shape: (…, N, 1) in 2D or (…, N, 3) in 3D.

  • rad (jax.typing.ArrayLike or None, optional) – Radii of particles. If None, defaults to ones. Expected shape: (…, N).

  • volume (jax.typing.ArrayLike or None, optional) – Volume of particles (or area in 2D). If None, defaults to hypersphere volumes of the radii. Expected shape: (…, N).

  • mass (jax.typing.ArrayLike or None, optional) – Masses of particles. If None, defaults to ones. Ignored when mat_table is provided. Expected shape: (…, N).

  • inertia (jax.typing.ArrayLike or None, optional) – Moments of inertia in the principal axes frame. If None, defaults to solid disks (2D) or spheres (3D). Expected shape: (…, N, 1) in 2D or (…, N, 3) in 3D.

  • ID (jax.typing.ArrayLike or None, optional) – Unique identifiers for particles. If None, defaults to jnp.arange(N). Expected shape: (…, N).

  • mat_id (jax.typing.ArrayLike or None, optional) – Material IDs for particles. If None, defaults to zeros. Expected shape: (…, N).

  • species_id (jax.typing.ArrayLike or None, optional) – Species IDs for particles. If None, defaults to zeros. Expected shape: (…, N).

  • fixed (jax.typing.ArrayLike or None, optional) – Boolean array indicating fixed particles. If None, defaults to all False. Expected shape: (…, N).

  • mat_table (MaterialTable or None, optional) – Optional material table providing per-material densities. When provided, the mass argument is ignored and particle masses are computed from density and particle volume.

Returns:

A new State instance with all attributes correctly initialized and shaped.

Return type:

State

Raises:

ValueError – If the created State is not valid.

Example

Creating a 3D state for 5 particles:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> my_pos = jnp.array([[0.,0.,0.], [1.,0.,0.], [0.,1.,0.], [0.,0.,1.], [1.,1.,1.]])
>>> my_rad = jnp.array([0.5, 0.5, 0.5, 0.5, 0.5])
>>> my_mass = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0])
>>>
>>> state_5_particles = jdem.State.create(pos=my_pos, rad=my_rad, mass=my_mass)
>>> print(f"Shape of positions: {state_5_particles.pos.shape}")
>>> print(f"Radii: {state_5_particles.rad}")
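The defaults described for volume and inertia, and the mass computed from a material density, follow standard formulas: area πr² or volume (4/3)πr³, a solid disk with I = (1/2)mr² about its normal axis, and a solid sphere with I = (2/5)mr² about each principal axis. The sketch below assumes these expressions; whether jaxdem computes its defaults identically is not confirmed here. The density_table array is a hypothetical stand-in for a MaterialTable, and all function names are hypothetical.

```python
import jax.numpy as jnp


def default_volume(rad, dim):
    # Hypersphere measure: area pi r^2 in 2D, volume (4/3) pi r^3 in 3D
    return jnp.pi * rad**2 if dim == 2 else (4.0 / 3.0) * jnp.pi * rad**3


def default_inertia(mass, rad, dim):
    if dim == 2:
        # Solid disk about its normal axis, shape (N, 1)
        return (0.5 * mass * rad**2)[..., None]
    # Solid sphere, equal on all three principal axes, shape (N, 3)
    moment = 0.4 * mass * rad**2
    return jnp.stack([moment, moment, moment], axis=-1)


def mass_from_density(density_table, mat_id, volume):
    # Hypothetical MaterialTable stand-in: a 1D array of densities indexed by mat_id
    return density_table[mat_id] * volume
```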
static merge(state1: State, state2: State) State

Merges two State instances into a single new State.

This method concatenates the particles from state2 onto state1. Particle IDs in state2 are shifted to ensure uniqueness in the merged state.

Parameters:
  • state1 (State) – The first State instance. Its particles will appear first in the merged state.

  • state2 (State) – The second State instance. Its particles will be appended to the first.

Returns:

A new State instance containing all particles from both input states.

Return type:

State

Raises:
  • AssertionError – If either input state is invalid, or if there is a mismatch in spatial dimension (dim) or batch size (batch_size).

  • ValueError – If the resulting merged state is invalid.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state_a = jdem.State.create(pos=jnp.array([[0.0, 0.0], [1.0, 1.0]]), ID=jnp.array([0, 1]))
>>> state_b = jdem.State.create(pos=jnp.array([[2.0, 2.0], [3.0, 3.0]]), ID=jnp.array([0, 1]))
>>> merged_state = jdem.State.merge(state_a, state_b)
>>>
>>> print(f"Merged state N: {merged_state.N}")  # Expected: 4
>>> print(f"Merged state positions:\n{merged_state.pos}")
>>> print(f"Merged state IDs: {merged_state.ID}")  # Expected: [0, 1, 2, 3]
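The merge described above amounts to concatenating along the particle axis (axis -2 for per-particle vectors, axis -1 for per-particle scalars) and shifting the second state's IDs. This is a minimal sketch on plain arrays; shifting by one more than the largest ID in state1 is an assumption that reproduces the expected [0, 1, 2, 3] in the example, and merge_ids and merge_vectors are hypothetical names.

```python
import jax.numpy as jnp


def merge_ids(id1, id2):
    # Assumption: offset state2's IDs past the largest ID of state1,
    # computed per batch entry so (B, N) inputs also work.
    offset = jnp.max(id1, axis=-1, keepdims=True) + 1
    return jnp.concatenate([id1, id2 + offset], axis=-1)


def merge_vectors(v1, v2):
    # Per-particle vectors (..., N, dim) are joined along the particle axis.
    return jnp.concatenate([v1, v2], axis=-2)
```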
static add(state: State, pos: ArrayLike, *, vel: ArrayLike | None = None, force: ArrayLike | None = None, q: Quaternion | ArrayLike | None = None, angVel: ArrayLike | None = None, torque: ArrayLike | None = None, rad: ArrayLike | None = None, volume: ArrayLike | None = None, mass: ArrayLike | None = None, inertia: ArrayLike | None = None, ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None, mat_table: 'MaterialTable' | None = None) State

Adds new particles to an existing State instance, returning a new State.

Parameters:
  • state (State) – The existing State to which particles will be added.

  • pos (jax.typing.ArrayLike) – Positions of the new particle(s). Shape (…, N_new, dim).

  • vel (jax.typing.ArrayLike or None, optional) – Velocities of the new particle(s). Defaults to zeros.

  • force (jax.typing.ArrayLike or None, optional) – Forces of the new particle(s). Defaults to zeros.

  • q (Quaternion or array-like, optional) – Initial orientations of the new particle(s). Defaults to identity quaternions.

  • angVel (jax.typing.ArrayLike or None, optional) – Angular velocities of the new particle(s). Defaults to zeros.

  • torque (jax.typing.ArrayLike or None, optional) – Torques of the new particle(s). Defaults to zeros.

  • rad (jax.typing.ArrayLike or None, optional) – Radii of the new particle(s). Defaults to ones.

  • volume (jax.typing.ArrayLike or None, optional) – Volume of the new particle(s) (or area in 2D). Defaults to hypersphere volumes of the radii.

  • mass (jax.typing.ArrayLike or None, optional) – Masses of the new particle(s). Defaults to ones. Ignored when a mat_table is provided.

  • inertia (jax.typing.ArrayLike or None, optional) – Moments of inertia of the new particle(s). Defaults to solid disks (2D) or spheres (3D).

  • ID (jax.typing.ArrayLike or None, optional) – IDs of the new particle(s). If None, new IDs are generated.

  • mat_id (jax.typing.ArrayLike or None, optional) – Material IDs of the new particle(s). Defaults to zeros.

  • species_id (jax.typing.ArrayLike or None, optional) – Species IDs of the new particle(s). Defaults to zeros.

  • fixed (jax.typing.ArrayLike or None, optional) – Fixed status of the new particle(s). Defaults to all False.

  • mat_table (MaterialTable or None, optional) – Optional material table providing per-material densities. When provided, masses are computed from density and particle volume.

Returns:

A new State instance containing all particles from the original state plus the newly added particles.

Return type:

State

Raises:
  • ValueError – If the created new particle state or the merged state is invalid.

  • AssertionError – If the batch size or dimension of the new particles does not match the existing state.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Initial state with 4 particles
>>> state = jdem.State.create(pos=jnp.zeros((4, 2)))
>>> print(f"Original state N: {state.N}, IDs: {state.ID}")
>>>
>>> # Add a single new particle
>>> state_with_added_particle = jdem.State.add(
...     state,
...     pos=jnp.array([[10.0, 10.0]]),
...     rad=jnp.array([0.5]),
...     mass=jnp.array([2.0]),
... )
>>> print(f"New state N: {state_with_added_particle.N}, IDs: {state_with_added_particle.ID}")
>>> print(f"New particle position: {state_with_added_particle.pos[-1]}")
>>>
>>> # Add multiple new particles
>>> state_multiple_added = jdem.State.add(
...     state,
...     pos=jnp.array([[10.0, 10.0], [11.0, 11.0], [12.0, 12.0]]),
... )
>>> print(f"State with multiple added N: {state_multiple_added.N}, IDs: {state_multiple_added.ID}")
static stack(states: Sequence[State]) State

Stacks a sequence of State snapshots along a new leading axis (axis 0) to form a trajectory or batch.

Use this method to collect simulation snapshots over time into a single State whose leading dimension represents time, or to assemble a batched state.

Parameters:

states (Sequence[State]) – A sequence (e.g., list, tuple) of State instances to be stacked.

Returns:

A new State instance where each attribute is a JAX array with an additional leading dimension representing the stacked trajectory. For example, if input pos was (N, dim), output pos will be (T, N, dim).

Return type:

State

Raises:
  • ValueError – If the input sequence of states is empty, or if the stacked State is invalid.

  • AssertionError – If any input state is invalid, or if there is a mismatch in spatial dimension (dim), batch size (batch_size), or number of particles (N) between the states in the sequence.

Notes

  • No ID shifting is performed because the leading axis represents time (or another batch dimension), not new particles.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create a sequence of 3 simple 2D snapshots
>>> snapshot1 = jdem.State.create(pos=jnp.array([[0.,0.], [1.,1.]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot2 = jdem.State.create(pos=jnp.array([[0.1,0.], [1.,1.1]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot3 = jdem.State.create(pos=jnp.array([[0.2,0.], [1.,1.2]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>>
>>> trajectory_state = jdem.State.stack([snapshot1, snapshot2, snapshot3])
>>>
>>> print(f"Trajectory positions shape: {trajectory_state.pos.shape}") # Expected: (3, 2, 2)
>>> print(f"Positions at time step 0:\n{trajectory_state.pos[0]}")
>>> print(f"Positions at time step 1:\n{trajectory_state.pos[1]}")
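Stacking many snapshots field by field is the usual JAX pytree pattern: apply jnp.stack to each leaf across all states, which adds a new leading axis and leaves IDs untouched. The sketch below uses a hypothetical two-field NamedTuple in place of State and assumes State behaves like a JAX pytree; it is not the library's stack implementation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MiniState(NamedTuple):
    # Hypothetical two-field stand-in for State
    pos: jnp.ndarray
    ID: jnp.ndarray


def stack_states(states):
    if not states:
        raise ValueError("cannot stack an empty sequence of states")
    # jnp.stack adds a new leading axis to every leaf; IDs are not shifted.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *states)
```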
static add_clump(state: State, pos: ArrayLike, *, vel: ArrayLike | None = None, force: ArrayLike | None = None, q: Quaternion | ArrayLike | None = None, angVel: ArrayLike | None = None, torque: ArrayLike | None = None, rad: ArrayLike | None = None, volume: ArrayLike | None = None, mass: ArrayLike | None = None, inertia: ArrayLike | None = None, mat_id: int | None = None, species_id: int | None = None, fixed: int | None = None) State

Adds a new clump of particles to an existing State instance, returning a new State.

Parameters:
  • state (State) – The existing State to which the clump will be added.

  • pos (jax.typing.ArrayLike) – Positions of the new particle(s). Shape (…, N_new, dim).

  • vel (jax.typing.ArrayLike or None, optional) – Velocities of the new particle(s). Defaults to zeros.

  • force (jax.typing.ArrayLike or None, optional) – Forces of the new particle(s). Defaults to zeros.

  • q (Quaternion or array-like, optional) – Initial orientations of the new particle(s). Defaults to identity quaternions.

  • angVel (jax.typing.ArrayLike or None, optional) – Angular velocities of the new particle(s). Defaults to zeros.

  • torque (jax.typing.ArrayLike or None, optional) – Torques of the new particle(s). Defaults to zeros.

  • rad (jax.typing.ArrayLike or None, optional) – Radii of the new particle(s). Defaults to ones.

  • volume (jax.typing.ArrayLike or None, optional) – Volume of the new particle(s) (or area in 2D). Defaults to hypersphere volumes of the radii.

  • mass (jax.typing.ArrayLike or None, optional) – Masses of the new particle(s). Defaults to ones.

  • inertia (jax.typing.ArrayLike or None, optional) – Moments of inertia of the new particle(s). Defaults to solid disks (2D) or spheres (3D).

  • mat_id (int or None, optional) – Material ID of the new particle(s). Defaults to 0.

  • species_id (int or None, optional) – Species ID of the new particle(s). Defaults to 0.

  • fixed (int or None, optional) – Fixed status of the new particle(s). Defaults to False (not fixed).

Returns:

A new State instance containing all particles from the original state plus the newly added particles.

Return type:

State

Raises:
  • ValueError – If the created new particle state or the merged state is invalid.

  • AssertionError – If the batch size or dimension of the new particles does not match the existing state.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Initial state with 4 particles
>>> state = jdem.State.create(pos=jnp.zeros((4, 2)))
>>> print(f"Original state N: {state.N}, IDs: {state.ID}")
>>>
>>> # Add a clump made of three spheres
>>> state_with_clump = jdem.State.add_clump(
...     state,
...     pos=jnp.array([[10.0, 10.0], [11.0, 10.0], [12.0, 10.0]]),
...     rad=jnp.array([0.5, 0.5, 0.5]),
... )
>>> print(f"State with clump N: {state_with_clump.N}, IDs: {state_with_clump.ID}")
>>> print(f"Clump sphere positions: {state_with_clump.pos[-3:]}")
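Given the pos_c and pos_p fields (pos_p = pos - pos_c), a clump's sphere positions can be split into a shared center of mass and per-sphere offsets. This is a hypothetical sketch: it assumes the center of mass is the mass-weighted mean of the sphere positions, that every sphere of the clump stores that same pos_c, and that the clump starts with identity orientation, so it skips the rotation into the principal-axis frame. clump_frame is not a jaxdem function.

```python
import jax.numpy as jnp


def clump_frame(pos, mass):
    """Hypothetical: split sphere positions (n, dim) into a shared COM and offsets."""
    # Mass-weighted center of mass of the clump, shape (dim,)
    com = jnp.sum(mass[:, None] * pos, axis=0) / jnp.sum(mass)
    # Assumption: every sphere of the clump stores the clump's center of mass
    pos_c = jnp.broadcast_to(com, pos.shape)
    # Offsets from the center of mass: pos_p = pos - pos_c (identity orientation)
    pos_p = pos - pos_c
    return pos_c, pos_p
```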