jaxdem.state
Defines the simulation State.
Classes
- State: Represents the complete simulation state for a system of N particles in 2D or 3D.
- final class jaxdem.state.State(pos_c: Array, pos_p: Array, vel: Array, force: Array, q: Quaternion, angVel: Array, torque: Array, rad: Array, volume: Array, mass: Array, inertia: Array, clump_ID: Array, deformable_ID: Array, unique_ID: Array, mat_id: Array, species_id: Array, fixed: Array)
Bases: object
Represents the complete simulation state for a system of N particles in 2D or 3D.
Notes
State is designed to support various data layouts:
- Single snapshot:
pos.shape = (N, dim) for particle properties (e.g., pos, vel, force), and (N,) for scalar properties (e.g., rad, mass). In this case, batch_size is 1.
- Batched states:
pos.shape = (B, N, dim) for particle properties, and (B, N) for scalar properties. Here, B is the batch dimension (batch_size = pos.shape[0]).
- Trajectories of a single simulation:
pos.shape = (T, N, dim) for particle properties, and (T, N) for scalar properties. Here, T is the trajectory dimension.
- Trajectories of batched states:
pos.shape = (B, T_1, T_2, …, T_k, N, dim) for particle properties, and (B, T_1, T_2, …, T_k, N) for scalar properties.
  - The first dimension (i.e., pos.shape[0]) is always interpreted as the batch dimension (B).
  - The remaining leading dimensions (T_1, T_2, …, T_k), between B and N, are interpreted as trajectory dimensions. If there is more than one trajectory dimension, they are flattened into a single one at save time.
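As a rough illustration of the save-time flattening described above, the sketch below collapses the trajectory axes T_1, …, T_k of a batched trajectory array into one axis with a plain reshape. The helper name `flatten_trajectory_dims` and its `n_trailing` argument are hypothetical; this is not jaxdem's save routine, and it assumes row-major (C-order) flattening.

```python
import jax.numpy as jnp


def flatten_trajectory_dims(x, n_trailing):
    """Collapse T_1..T_k of a (B, T_1, ..., T_k, *trailing) array into one axis.

    n_trailing is the number of per-particle trailing axes:
    2 for (N, dim) vector fields, 1 for (N,) scalar fields.
    """
    lead = x.shape[: x.ndim - n_trailing]       # (B, T_1, ..., T_k)
    trailing = x.shape[x.ndim - n_trailing :]   # (N, dim) or (N,)
    if len(lead) <= 2:
        # Zero or one trajectory dimension: nothing to flatten.
        return x
    batch = lead[0]
    return x.reshape((batch, -1) + trailing)


pos = jnp.zeros((2, 3, 5, 4, 3))  # B=2, T_1=3, T_2=5, N=4, dim=3
rad = jnp.ones((2, 3, 5, 4))      # B=2, T_1=3, T_2=5, N=4
print(flatten_trajectory_dims(pos, 2).shape)  # (2, 15, 4, 3)
print(flatten_trajectory_dims(rad, 1).shape)  # (2, 15, 4)
```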
The class is final and cannot be subclassed.
Example
Creating a simple 2D state for 4 particles:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>> import jax
>>>
>>> positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
>>> state = jdem.State.create(pos=positions)
>>>
>>> print(f"Number of particles (N): {state.N}")
>>> print(f"Spatial dimension (dim): {state.dim}")
>>> print(f"Positions: {state.pos}")
Creating a batched state:
>>> batched_state = jax.vmap(lambda _: jdem.State.create(pos=positions))(jnp.arange(10))
>>>
>>> print(f"Batch size: {batched_state.batch_size}")  # 10
>>> print(f"Positions shape: {batched_state.pos.shape}")
- pos_c: Array
Array of particle center of mass positions. Shape is (…, N, dim).
- pos_p: Array
Position of each sphere relative to its center of mass, expressed in the principal reference frame, so that pos = pos_c + R(q) @ pos_p. This field should be constant. Shape is (…, N, dim).
- vel: Array
Array of particle center of mass velocities. Shape is (…, N, dim).
- force: Array
Array of particle forces. Shape is (…, N, dim).
- q: Quaternion
Quaternion representing the orientation of the particle.
- angVel: Array
Array of particle angular velocities about the center of mass. Shape is (…, N, 1) in 2D or (…, N, 3) in 3D.
- torque: Array
Array of particle torques. Shape is (…, N, 1) in 2D or (…, N, 3) in 3D.
- rad: Array
Array of particle radii. Shape is (…, N).
- volume: Array
Array of particle volumes (or areas in 2D). Shape is (…, N).
- mass: Array
Array of particle masses. Shape is (…, N).
- inertia: Array
Moments of inertia in the principal axis frame. Shape is (…, N, 1) in 2D or (…, N, 3) in 3D.
- clump_ID: Array
Array of clump identifiers. Spheres with the same clump_ID are treated as parts of the same rigid body. Shape is (…, N). IDs need to be between 0 and N.
- deformable_ID: Array
Array of deformable particle identifiers. Spheres (nodes) with the same deformable_ID are treated as part of the same deformable particle for collision masking purposes. Shape is (…, N). IDs need to be between 0 and N.
- unique_ID: Array
Array of unique particle identifiers. No ID should be repeated. Shape is (…, N). IDs need to be between 0 and N.
- mat_id: Array
Array of material IDs for each particle. Shape is (…, N).
- species_id: Array
Array of species IDs for each particle. Shape is (…, N).
- fixed: Array
Boolean array indicating if a particle is fixed (immobile). Shape is (…, N).
- property shape: Tuple[int, ...]
Shape of the position array pos_c, e.g. (N, dim) or (B, N, dim).
- property pos: Array
Returns the position of each sphere in the state.
pos_c is the center of mass and pos_p is the vector relative to the center of mass in the principal reference frame, such that pos = pos_c + R(q) @ pos_p, where R(q) rotates pos_p to the lab frame.
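The relation pos = pos_c + R(q) @ pos_p can be sketched in plain JAX for the 3D case. This is a minimal sketch, assuming unit quaternions stored as arrays ordered (w, x, y, z) as described for State.create; it does not use jaxdem's Quaternion type, and the helper names `rotate` and `sphere_positions` are hypothetical. The 2D case (rotation in the plane) is not covered.

```python
import jax.numpy as jnp


def rotate(q, v):
    """Rotate vectors v (..., 3) by unit quaternions q (..., 4) ordered (w, x, y, z)."""
    w = q[..., :1]                # scalar part, kept as (..., 1) for broadcasting
    u = q[..., 1:]                # vector part (x, y, z)
    t = 2.0 * jnp.cross(u, v)     # t = 2 (u x v)
    return v + w * t + jnp.cross(u, t)   # v' = v + w t + u x t


def sphere_positions(pos_c, pos_p, q):
    """Lab-frame sphere positions: pos = pos_c + R(q) @ pos_p."""
    return pos_c + rotate(q, pos_p)


pos_c = jnp.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
pos_p = jnp.array([[0.5, 0.0, 0.0], [-0.5, 0.0, 0.0]])
s = jnp.sqrt(0.5)
q = jnp.array([[s, 0.0, 0.0, s], [s, 0.0, 0.0, s]])  # 90 degree rotation about z
print(sphere_positions(pos_c, pos_p, q))  # approximately [[1, 0.5, 0], [1, -0.5, 0]]
```

With the identity quaternion (1, 0, 0, 0), R(q) is the identity and the result reduces to pos_c + pos_p.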
- property is_valid: bool
Check if the internal representation of the State is consistent.
Verifies that:
  - The spatial dimension (dim) is either 2 or 3.
  - All position-like arrays (pos_c, pos_p, vel, force) have the same shape.
  - All angular-like arrays (angVel, torque, inertia) have the same shape.
  - All scalar-per-particle arrays (rad, mass, clump_ID, deformable_ID, mat_id, species_id, fixed) have a shape consistent with pos.shape[:-1].
- Raises:
AssertionError – If any shape inconsistency is found.
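The checks listed above can be mirrored on a simplified container. This is a hedged sketch: `MiniState` is a hypothetical NamedTuple holding only a subset of State's fields, and `check_shapes` reproduces the documented rules rather than jaxdem's actual is_valid code.

```python
from typing import NamedTuple

import jax.numpy as jnp


class MiniState(NamedTuple):
    """Hypothetical stand-in holding a subset of State's fields."""
    pos_c: jnp.ndarray
    pos_p: jnp.ndarray
    vel: jnp.ndarray
    force: jnp.ndarray
    angVel: jnp.ndarray
    torque: jnp.ndarray
    inertia: jnp.ndarray
    rad: jnp.ndarray
    mass: jnp.ndarray


def check_shapes(s: MiniState) -> bool:
    """Apply the documented consistency rules; raise AssertionError on mismatch."""
    dim = s.pos_c.shape[-1]
    assert dim in (2, 3), f"dim must be 2 or 3, got {dim}"
    # Position-like arrays share one shape.
    for name in ("pos_p", "vel", "force"):
        assert getattr(s, name).shape == s.pos_c.shape, name
    # Angular-like arrays share one shape.
    for name in ("torque", "inertia"):
        assert getattr(s, name).shape == s.angVel.shape, name
    # Scalar-per-particle arrays match pos.shape[:-1].
    for name in ("rad", "mass"):
        assert getattr(s, name).shape == s.pos_c.shape[:-1], name
    return True


def make(n: int, dim: int) -> MiniState:
    """Build a consistent MiniState with n particles in `dim` dimensions."""
    ang = 1 if dim == 2 else 3
    v = jnp.zeros((n, dim))
    a = jnp.zeros((n, ang))
    return MiniState(v, v, v, v, a, a, jnp.ones((n, ang)), jnp.ones(n), jnp.ones(n))


print(check_shapes(make(4, 2)))  # True
```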
- static create(pos: ArrayLike, *, pos_p: ArrayLike | None = None, vel: ArrayLike | None = None, force: ArrayLike | None = None, q: Quaternion | ArrayLike | None = None, angVel: ArrayLike | None = None, torque: ArrayLike | None = None, rad: ArrayLike | None = None, volume: ArrayLike | None = None, mass: ArrayLike | None = None, inertia: ArrayLike | None = None, clump_ID: ArrayLike | None = None, deformable_ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None, mat_table: MaterialTable | None = None) → State
Factory method to create a new State instance.
This method handles default values and ensures consistent array shapes for all state attributes.
- Parameters:
pos (jax.typing.ArrayLike) – Array of particle center of mass positions, equivalent to state.pos_c. Expected shape: (…, N, dim).
pos_p (jax.typing.ArrayLike or None, optional) – Position of each sphere relative to its center of mass, expressed in the principal reference frame (pos = pos_c + R(q) @ pos_p). This field should be constant. Shape is (…, N, dim).
vel (jax.typing.ArrayLike or None, optional) – Initial velocities of particles. If None, defaults to zeros. Expected shape: (…, N, dim).
force (jax.typing.ArrayLike or None, optional) – Initial forces on particles. If None, defaults to zeros. Expected shape: (…, N, dim).
q (Quaternion or array-like, optional) – Initial particle orientations. If None, defaults to identity quaternions. Accepted shapes: quaternion objects or arrays of shape (…, N, 4) with components ordered as (w, x, y, z).
angVel (jax.typing.ArrayLike or None, optional) – Initial angular velocities of particles. If None, defaults to zeros. Expected shape: (…, N, 1) in 2D or (…, N, 3) in 3D.
torque (jax.typing.ArrayLike or None, optional) – Initial torques on particles. If None, defaults to zeros. Expected shape: (…, N, 1) in 2D or (…, N, 3) in 3D.
rad (jax.typing.ArrayLike or None, optional) – Radii of particles. If None, defaults to ones. Expected shape: (…, N).
volume (jax.typing.ArrayLike or None, optional) – Volume of particles (or area in 2D). If None, defaults to hypersphere volumes of the radii. Expected shape: (…, N).
mass (jax.typing.ArrayLike or None, optional) – Masses of particles. If None, defaults to ones. Ignored when mat_table is provided. Expected shape: (…, N).
inertia (jax.typing.ArrayLike or None, optional) – Moments of inertia in the principal axes frame. If None, defaults to solid disks (2D) or spheres (3D). Expected shape: (…, N, 1) in 2D or (…, N, 3) in 3D.
clump_ID (jax.typing.ArrayLike or None, optional) – Clump identifiers; spheres sharing a clump_ID form one rigid body. If None, defaults to jnp.arange(), i.e. each particle gets its own ID. Expected shape: (…, N).
deformable_ID (jax.typing.ArrayLike or None, optional) – Identifiers for deformable particles. If None, defaults to jnp.arange(), i.e. each particle gets its own ID. Expected shape: (…, N).
mat_id (jax.typing.ArrayLike or None, optional) – Material IDs for particles. If None, defaults to zeros. Expected shape: (…, N).
species_id (jax.typing.ArrayLike or None, optional) – Species IDs for particles. If None, defaults to zeros. Expected shape: (…, N).
fixed (jax.typing.ArrayLike or None, optional) – Boolean array indicating fixed particles. If None, defaults to all False. Expected shape: (…, N).
mat_table (MaterialTable or None, optional) – Optional material table providing per-material densities. When provided, the mass argument is ignored and particle masses are computed from density and particle volume.
- Returns:
A new State instance with all attributes correctly initialized and shaped.
- Return type:
State
- Raises:
ValueError – If the created State is not valid.
Example
Creating a 3D state for 5 particles:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> my_pos = jnp.array([[0.,0.,0.], [1.,0.,0.], [0.,1.,0.], [0.,0.,1.], [1.,1.,1.]])
>>> my_rad = jnp.array([0.5, 0.5, 0.5, 0.5, 0.5])
>>> my_mass = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0])
>>>
>>> state_5_particles = jdem.State.create(pos=my_pos, rad=my_rad, mass=my_mass)
>>> print(f"Shape of positions: {state_5_particles.pos.shape}")
>>> print(f"Radii: {state_5_particles.rad}")
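The volume, inertia, and material-table mass defaults listed in the parameters can be approximated as follows. This is a sketch under stated assumptions, not jaxdem's code: it assumes inertia is computed from each particle's own mass and radius (I = m r²/2 for a solid disk, I = 2 m r²/5 for a solid sphere), and that densities are looked up per material via mat_id. The function names and the plain `density` array are hypothetical stand-ins; they are not MaterialTable's actual interface.

```python
import math

import jax.numpy as jnp


def default_volume(rad, dim):
    """Hypersphere measure: disk area in 2D, ball volume in 3D."""
    if dim == 2:
        return math.pi * rad**2
    return (4.0 / 3.0) * math.pi * rad**3


def default_inertia(mass, rad, dim):
    """Solid disk (2D, shape (N, 1)) or solid sphere (3D, shape (N, 3))."""
    if dim == 2:
        return (0.5 * mass * rad**2)[..., None]
    i = 0.4 * mass * rad**2
    return jnp.repeat(i[..., None], 3, axis=-1)  # isotropic: same value on all 3 axes


def mass_from_density(density, mat_id, volume):
    """Per-particle mass = density of the particle's material * particle volume."""
    return density[mat_id] * volume


rad = jnp.array([0.5, 1.0])
vol = default_volume(rad, dim=3)
mass = mass_from_density(jnp.array([2.0, 3.0]), jnp.array([0, 1]), vol)
inertia = default_inertia(mass, rad, dim=3)
print(inertia.shape)  # (2, 3)
```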
- static merge(state1: State, state2: State | Sequence[State]) → State
Merges multiple State instances into a single new State.
This method concatenates the particles from the provided state(s) onto state1. Particle clump_IDs, deformable_IDs, and unique_IDs are shifted to ensure uniqueness across the merged system.
- Parameters:
state1 (State) – The first State instance. Its particles will appear first in the merged state.
state2 (State or Sequence[State]) – The second State, or a list/tuple of State instances to append.
- Returns:
A new State instance containing all particles from the input states.
- Return type:
State
- Raises:
AssertionError – If any input state is invalid, or if there is a mismatch in spatial dimension (dim) or batch size (batch_size).
ValueError – If the resulting merged state is invalid.
Example
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state_a = jdem.State.create(pos=jnp.array([[0.0, 0.0], [1.0, 1.0]]), clump_ID=jnp.array([0, 1]))
>>> state_b = jdem.State.create(pos=jnp.array([[2.0, 2.0], [3.0, 3.0]]), clump_ID=jnp.array([0, 1]))
>>>
>>> # Merge a sequence of states
>>> merged_many = jdem.State.merge(state_a, [state_b, state_b, state_b])
>>> print(f"Merged (sequence) N: {merged_many.N}")  # Expected: 8
>>>
>>> # Merge two states
>>> merged_state = jdem.State.merge(state_a, state_b)
>>> print(f"Merged state N: {merged_state.N}")  # Expected: 4
>>> print(f"Merged state positions:\n{merged_state.pos}")
>>> print(f"Merged state clump_IDs: {merged_state.clump_ID}")  # Expected: [0, 1, 2, 3]
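The ID shifting performed by merge can be sketched on bare ID arrays. This sketch assumes the offset applied to the appended IDs is the largest existing ID plus one, taken per batch row; the library may use a different offset (for example the particle count of the first state), although both choices give [0, 1, 2, 3] for the two-state example. `merge_ids` is a hypothetical helper, and a sequence of states is handled by folding it pairwise.

```python
from functools import reduce

import jax.numpy as jnp


def merge_ids(ids1, ids2):
    """Append ids2 to ids1 along the particle axis, offset past the largest ID in ids1."""
    offset = jnp.max(ids1, axis=-1, keepdims=True) + 1  # per batch row if batched
    return jnp.concatenate([ids1, ids2 + offset], axis=-1)


a = jnp.array([0, 1])
b = jnp.array([0, 1])
print(merge_ids(a, b))                  # [0 1 2 3]
print(reduce(merge_ids, [a, b, b, b]))  # [0 1 2 3 4 5 6 7]
```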
- static add(state: State, pos: ArrayLike, *, pos_p: ArrayLike | None = None, vel: ArrayLike | None = None, force: ArrayLike | None = None, q: Quaternion | ArrayLike | None = None, angVel: ArrayLike | None = None, torque: ArrayLike | None = None, rad: ArrayLike | None = None, volume: ArrayLike | None = None, mass: ArrayLike | None = None, inertia: ArrayLike | None = None, clump_ID: ArrayLike | None = None, deformable_ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None, mat_table: MaterialTable | None = None) → State
Adds new particles to an existing State instance, returning a new State.
- Parameters:
state (State) – The existing State to which particles will be added.
pos (jax.typing.ArrayLike) – Array of particle center of mass positions, equivalent to state.pos_c. Expected shape: (…, N, dim).
pos_p (jax.typing.ArrayLike or None, optional) – Position of each sphere relative to its center of mass, expressed in the principal reference frame (pos = pos_c + R(q) @ pos_p). This field should be constant. Shape is (…, N, dim).
vel (jax.typing.ArrayLike or None, optional) – Velocities of the new particle(s). Defaults to zeros.
force (jax.typing.ArrayLike or None, optional) – Forces of the new particle(s). Defaults to zeros.
q (Quaternion or array-like, optional) – Initial orientations of the new particle(s). Defaults to identity quaternions.
angVel (jax.typing.ArrayLike or None, optional) – Angular velocities of the new particle(s). Defaults to zeros.
torque (jax.typing.ArrayLike or None, optional) – Torques of the new particle(s). Defaults to zeros.
rad (jax.typing.ArrayLike or None, optional) – Radii of the new particle(s). Defaults to ones.
volume (jax.typing.ArrayLike or None, optional) – Volume of the new particle(s) (or area in 2D). Defaults to hypersphere volumes of the radii.
mass (jax.typing.ArrayLike or None, optional) – Masses of the new particle(s). Defaults to ones. Ignored when a mat_table is provided.
inertia (jax.typing.ArrayLike or None, optional) – Moments of inertia of the new particle(s). Defaults to solid disks (2D) or spheres (3D).
clump_ID (jax.typing.ArrayLike or None, optional) – clump_IDs of the new clump(s). If None, new IDs are generated.
deformable_ID (jax.typing.ArrayLike or None, optional) – Unique identifiers for deformable particles. If None, defaults to jnp.arange(). Expected shape: (…, N).
mat_id (jax.typing.ArrayLike or None, optional) – Material IDs of the new particle(s). Defaults to zeros.
species_id (jax.typing.ArrayLike or None, optional) – Species IDs of the new particle(s). Defaults to zeros.
fixed (jax.typing.ArrayLike or None, optional) – Fixed status of the new particle(s). Defaults to all False.
mat_table (MaterialTable or None, optional) – Optional material table providing per-material densities. When provided, masses are computed from density and particle volume.
- Returns:
A new State instance containing all particles from the original state plus the newly added particles.
- Return type:
State
- Raises:
ValueError – If the created new particle state or the merged state is invalid.
AssertionError – If batch size or dimension mismatch between existing state and new particles.
Example
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Initial state with 4 particles
>>> state = jdem.State.create(pos=jnp.zeros((4, 2)))
>>> print(f"Original state N: {state.N}, clump_IDs: {state.clump_ID}")
>>>
>>> # Add a single new particle
>>> state_with_added_particle = jdem.State.add(
...     state,
...     pos=jnp.array([[10.0, 10.0]]),
...     rad=jnp.array([0.5]),
...     mass=jnp.array([2.0]),
... )
>>> print(f"New state N: {state_with_added_particle.N}, clump_IDs: {state_with_added_particle.clump_ID}")
>>> print(f"New particle position: {state_with_added_particle.pos[-1]}")
>>>
>>> # Add multiple new particles
>>> state_multiple_added = jdem.State.add(
...     state,
...     pos=jnp.array([[10.0, 10.0], [11.0, 11.0], [12.0, 12.0]]),
... )
>>> print(f"State with multiple added N: {state_multiple_added.N}, clump_IDs: {state_multiple_added.clump_ID}")
- static stack(states: Sequence[State]) → State
Stacks a sequence of State snapshots along a new leading axis (axis 0) to form a trajectory or batch.
This method is useful for collecting simulation snapshots over time into a single State object whose leading dimension represents time, or for preparing a batched state.
- Parameters:
states (Sequence[State]) – A sequence (e.g., list, tuple) of State instances to be stacked.
- Returns:
A new State instance where each attribute is a JAX array with an additional leading dimension representing the stacked trajectory. For example, if the input pos has shape (N, dim), the output pos has shape (T, N, dim).
- Return type:
State
- Raises:
ValueError – If the input states sequence is empty, or if the stacked State is invalid.
AssertionError – If any input state is invalid, or if there is a mismatch in spatial dimension (dim), batch size (batch_size), or number of particles (N) between the states in the sequence.
Notes
No clump_ID shifting is performed because the leading axis represents time (or another batch dimension), not new particles.
Example
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create a sequence of 3 simple 2D snapshots
>>> snapshot1 = jdem.State.create(pos=jnp.array([[0.,0.], [1.,1.]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot2 = jdem.State.create(pos=jnp.array([[0.1,0.], [1.,1.1]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot3 = jdem.State.create(pos=jnp.array([[0.2,0.], [1.,1.2]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>>
>>> trajectory_state = jdem.State.stack([snapshot1, snapshot2, snapshot3])
>>>
>>> print(f"Trajectory positions shape: {trajectory_state.pos.shape}")  # Expected: (3, 2, 2)
>>> print(f"Positions at time step 0:\n{trajectory_state.pos[0]}")
>>> print(f"Positions at time step 1:\n{trajectory_state.pos[1]}")
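Conceptually, stacking a sequence of pytrees means calling jnp.stack leaf by leaf. The sketch below shows that idea with jax.tree_util.tree_map on a hypothetical two-field `Snap` NamedTuple standing in for State; it omits the validity and dimension checks that State.stack performs, and is an illustration of the technique rather than jaxdem's implementation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Snap(NamedTuple):
    """Hypothetical two-field stand-in for State."""
    pos: jnp.ndarray
    vel: jnp.ndarray


def stack(snaps):
    """Stack matching pytrees leaf by leaf along a new axis 0."""
    if len(snaps) == 0:
        raise ValueError("cannot stack an empty sequence")
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs, axis=0), *snaps)


snaps = [Snap(pos=jnp.full((2, 2), float(t)), vel=jnp.zeros((2, 2))) for t in range(3)]
traj = stack(snaps)
print(traj.pos.shape)  # (3, 2, 2)
```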
- static unstack(state: State) → list[State]
Split a stacked/batched State along the leading axis into a Python list.
This is the inverse of State.stack(): if stacked = State.stack([s0, s1, …]), then State.unstack(stacked) returns [s0, s1, …].
Notes
The split is performed along axis 0 (the leading axis).
A single snapshot State (e.g. pos.shape == (N, dim)) cannot be unstacked with this method, because axis 0 would refer to particles, not snapshots.
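The inverse operation can be sketched by indexing every leaf at position i along axis 0. As before, `Snap` is a hypothetical stand-in for State, and the sketch assumes all leaves share the same leading axis length (taken here from the first leaf); it is not jaxdem's implementation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Snap(NamedTuple):
    """Hypothetical two-field stand-in for State."""
    pos: jnp.ndarray
    vel: jnp.ndarray


def unstack(stacked):
    """Split every leaf along axis 0 and return one pytree per index."""
    n = jax.tree_util.tree_leaves(stacked)[0].shape[0]
    return [jax.tree_util.tree_map(lambda x, i=i: x[i], stacked) for i in range(n)]


stacked = Snap(pos=jnp.arange(12.0).reshape(3, 2, 2), vel=jnp.zeros((3, 2, 2)))
parts = unstack(stacked)
print(len(parts), parts[0].pos.shape)  # 3 (2, 2)
```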
- static add_clump(state: State, pos: ArrayLike, *, pos_p: ArrayLike | None = None, vel: ArrayLike | None = None, force: ArrayLike | None = None, q: Quaternion | ArrayLike | None = None, angVel: ArrayLike | None = None, torque: ArrayLike | None = None, rad: ArrayLike | None = None, volume: ArrayLike | None = None, mass: ArrayLike | None = None, inertia: ArrayLike | None = None, deformable_ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None) → State
Adds a new clump consisting of multiple spheres to an existing State. Rigid-body properties (velocity, mass, material, etc.) are broadcast to all spheres in the new clump. The only per-sphere properties that vary within a rigid body are pos_c, pos_p, and rad.
TODO: broadcast the quaternion (q is not yet broadcast to the spheres of the clump).
- Parameters:
state (State) – The existing State to which particles will be added.
pos (jax.typing.ArrayLike) – Array of particle center of mass positions, equivalent to state.pos_c. Expected shape: (…, N, dim).
pos_p (jax.typing.ArrayLike or None, optional) – Position of each sphere relative to its center of mass, expressed in the principal reference frame (pos = pos_c + R(q) @ pos_p). This field should be constant. Shape is (…, N, dim).
vel (jax.typing.ArrayLike or None, optional) – Velocities of the new particle(s). Defaults to zeros.
force (jax.typing.ArrayLike or None, optional) – Forces of the new particle(s). Defaults to zeros.
q (Quaternion or array-like, optional) – Initial orientations of the new particle(s). Defaults to identity quaternions.
angVel (jax.typing.ArrayLike or None, optional) – Angular velocities of the new particle(s). Defaults to zeros.
torque (jax.typing.ArrayLike or None, optional) – Torques of the new particle(s). Defaults to zeros.
rad (jax.typing.ArrayLike or None, optional) – Radii of the new particle(s). Defaults to ones.
volume (jax.typing.ArrayLike or None, optional) – Volume of the new particle(s) (or area in 2D). Defaults to hypersphere volumes of the radii.
mass (jax.typing.ArrayLike or None, optional) – Masses of the new particle(s). Defaults to ones.
inertia (jax.typing.ArrayLike or None, optional) – Moments of inertia of the new particle(s). Defaults to solid disks (2D) or spheres (3D).
deformable_ID (jax.typing.ArrayLike or None, optional) – Unique identifiers for deformable particles. If None, defaults to jnp.arange(). Expected shape: (…, N).
mat_id (jax.typing.ArrayLike or None, optional) – Material IDs of the new particle(s). Defaults to zeros.
species_id (jax.typing.ArrayLike or None, optional) – Species IDs of the new particle(s). Defaults to zeros.
fixed (jax.typing.ArrayLike or None, optional) – Fixed status of the new particle(s). Defaults to all False.
- Returns:
A new State instance containing all particles from the original state plus the newly added particles.
- Return type:
State
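The broadcasting described for add_clump can be pictured with a small, hypothetical helper that prepares per-sphere arrays for one clump. It assumes an identity orientation (so pos_p is just the offset from the centre, following pos = pos_c + R(q) @ pos_p), an equal-weight centre rather than a true mass-weighted one, and that the clump mass is simply repeated on every sphere. None of this is taken from add_clump's implementation; `clump_fields` is an illustrative name.

```python
import jax.numpy as jnp


def clump_fields(sphere_pos, sphere_rad, clump_mass):
    """Hypothetical helper: per-sphere arrays for one rigid clump, assuming q = identity."""
    n = sphere_pos.shape[0]
    com = jnp.mean(sphere_pos, axis=0)               # equal-weight centre (simplifying assumption)
    pos_c = jnp.broadcast_to(com, sphere_pos.shape)  # every sphere carries the clump's centre
    pos_p = sphere_pos - com                         # body-frame offsets (valid because q = identity)
    mass = jnp.full((n,), clump_mass)                # rigid-body property repeated on all spheres
    return pos_c, pos_p, sphere_rad, mass


sphere_pos = jnp.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
pos_c, pos_p, rad, mass = clump_fields(sphere_pos, jnp.array([0.5, 0.5]), 3.0)
print(pos_p)
```

With this layout, pos_c + pos_p reproduces the original sphere positions, which is the identity-orientation case of the pos property.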