jaxdem.state
Defines the simulation State.
Functions
- get_real_pos(pos_c, pos_p, q)
Classes
- State: Represents the complete simulation state for a system of N particles in 2D or 3D.
- jaxdem.state.get_real_pos(pos_c: Array, pos_p: Array, q: Quaternion) -> Array
- final class jaxdem.state.State(pos_c: Array, pos_p: Array, vel: Array, force: Array, q: Quaternion, angVel: Array, torque: Array, rad: Array, volume: Array, mass: Array, inertia: Array, ID: Array, mat_id: Array, species_id: Array, fixed: Array)
Bases: object
Represents the complete simulation state for a system of N particles in 2D or 3D.
Notes
State is designed to support various data layouts:
- Single snapshot:
pos.shape = (N, dim) for particle properties (e.g., pos, vel, force), and (N,) for scalar properties (e.g., rad, mass). In this case, batch_size is 1.
- Batched states:
pos.shape = (B, N, dim) for particle properties, and (B, N) for scalar properties. Here, B is the batch dimension (batch_size = pos.shape[0]).
- Trajectories of a single simulation:
pos.shape = (T, N, dim) for particle properties, and (T, N) for scalar properties. Here, T is the trajectory dimension.
- Trajectories of batched states:
pos.shape = (B, T_1, T_2, …, T_k, N, dim) for particle properties, and (B, T_1, T_2, …, T_k, N) for scalar properties.
- The first dimension (i.e., pos.shape[0]) is always interpreted as the batch dimension (B).
- The remaining leading dimensions between B and N (T_1, T_2, …, T_k) are interpreted as trajectory dimensions. If there is more than one trajectory dimension, they are flattened into one at save time.
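The layout rules above can be illustrated on plain arrays. The sketch below is a hypothetical illustration, not jaxdem's own code: `batch_size` and `flatten_trajectory_axes` are made-up helper names, and the sketch assumes only the rules stated in these notes (batch_size is pos.shape[0] when there are leading axes, otherwise 1; the trajectory axes between B and N collapse into one).

```python
import jax.numpy as jnp


def batch_size(pos):
    # Hypothetical helper: axis 0 is the batch axis whenever pos has more
    # than the two (N, dim) axes; a single snapshot has batch_size 1.
    return pos.shape[0] if pos.ndim > 2 else 1


def flatten_trajectory_axes(x, n_trailing):
    # Collapse every axis between the batch axis (axis 0) and the trailing
    # per-particle axes into a single trajectory axis.
    # n_trailing=2 for (N, dim) properties, n_trailing=1 for (N,) properties.
    B = x.shape[0]
    trailing = x.shape[x.ndim - n_trailing:]
    return x.reshape((B, -1) + trailing)


pos = jnp.zeros((4, 5, 6, 10, 3))  # (B, T_1, T_2, N, dim)
rad = jnp.ones((4, 5, 6, 10))      # (B, T_1, T_2, N)
print(batch_size(pos))                        # 4
print(flatten_trajectory_axes(pos, 2).shape)  # (4, 30, 10, 3)
print(flatten_trajectory_axes(rad, 1).shape)  # (4, 30, 10)
```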
The class is final and cannot be subclassed.
Example
Creating a simple 2D state for 4 particles:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>> import jax
>>>
>>> positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
>>> state = jdem.State.create(pos=positions)
>>>
>>> print(f"Number of particles (N): {state.N}")
>>> print(f"Spatial dimension (dim): {state.dim}")
>>> print(f"Positions: {state.pos}")
Creating a batched state:
>>> batched_state = jax.vmap(lambda _: jdem.State.create(pos=positions))(jnp.arange(10))
>>>
>>> print(f"Batch size: {batched_state.batch_size}")  # 10
>>> print(f"Positions shape: {batched_state.pos.shape}")  # (10, 4, 2)
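The vmap pattern above works because a JAX pytree returned from a vmapped function gets every leaf stacked along a new leading axis. The sketch below shows that mechanism with a hypothetical `MiniState` NamedTuple standing in for State (it is not jaxdem's class). Unlike the example above, it also gives each batch entry different positions.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MiniState(NamedTuple):
    # Hypothetical stand-in for a State-like pytree (not jaxdem's class).
    pos: jnp.ndarray  # (N, dim)
    rad: jnp.ndarray  # (N,)


def make_state(pos):
    # Default radius of one per particle, mirroring the documented create() default.
    return MiniState(pos=pos, rad=jnp.ones(pos.shape[:-1]))


positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
offsets = jnp.arange(10.0)
# One vmapped call builds ten states, each shifted by a different offset.
batched = jax.vmap(lambda o: make_state(positions + o))(offsets)
print(batched.pos.shape)  # (10, 4, 2)
print(batched.rad.shape)  # (10, 4)
```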
- pos_c: Array
Array of particle center-of-mass positions. Shape is (…, N, dim).
- pos_p: Array
Offset of each sphere from its center of mass, expressed in the principal (body) reference frame. In the lab frame, this offset equals pos - pos_c. This data is constant. Shape is (…, N, dim).
- vel: Array
Array of particle center-of-mass velocities. Shape is (…, N, dim).
- force: Array
Array of particle forces. Shape is (…, N, dim).
- q: Quaternion
Quaternion representing the orientation of the particle.
- angVel: Array
Array of particle angular velocities. Shape is (…, N, 1) in 2D or (…, N, 3) in 3D.
- torque: Array
Array of particle torques. Shape is (…, N, 1) in 2D or (…, N, 3) in 3D.
- rad: Array
Array of particle radii. Shape is (…, N).
- volume: Array
Array of particle volumes (or areas in 2D). Shape is (…, N).
- mass: Array
Array of particle masses. Shape is (…, N).
- inertia: Array
Principal moments of inertia (the diagonal of the inertia tensor in the principal-axis frame). Shape is (…, N, 1) in 2D or (…, N, 3) in 3D.
- ID: Array
Array of unique particle identifiers. Shape is (…, N).
- mat_id: Array
Array of material IDs for each particle. Shape is (…, N).
- species_id: Array
Array of species IDs for each particle. Shape is (…, N).
- fixed: Array
Boolean array indicating whether a particle is fixed (immobile). Shape is (…, N).
- property pos: Array
Returns the lab-frame position of each sphere in the state. pos_c is the center of mass and pos_p is the offset from the center of mass expressed in the principal reference frame, so pos_p must first be rotated into the lab frame by the orientation q: pos = pos_c + R(q) pos_p.
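The relation pos = pos_c + R(q) pos_p can be sketched with plain arrays. This is a minimal 3D sketch, not jaxdem's get_real_pos. It assumes unit quaternions stored as arrays in (w, x, y, z) order, the component order documented for State.create, and uses the standard identity v' = v + w t + u × t with t = 2 u × v. `rotate` and `lab_positions` are hypothetical names.

```python
import jax.numpy as jnp


def rotate(q, v):
    # Rotate vectors v (..., 3) by unit quaternions q (..., 4) in (w, x, y, z) order.
    w = q[..., :1]   # scalar part, kept as (..., 1) for broadcasting
    u = q[..., 1:]   # vector part (x, y, z)
    t = 2.0 * jnp.cross(u, v)
    return v + w * t + jnp.cross(u, t)


def lab_positions(pos_c, pos_p, q):
    # Rotate the body-frame offset into the lab frame, then add the center of mass.
    return pos_c + rotate(q, pos_p)


s = jnp.sqrt(0.5)
q = jnp.array([[s, 0.0, 0.0, s]])      # 90 degree rotation about z
pos_c = jnp.array([[1.0, 2.0, 3.0]])
pos_p = jnp.array([[1.0, 0.0, 0.0]])
print(lab_positions(pos_c, pos_p, q))  # approximately [[1. 3. 3.]]
```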
- property is_valid: bool
Check whether the internal representation of the State is consistent.
Verifies that:
- The spatial dimension (dim) is either 2 or 3.
- All position-like arrays (pos, vel, force) have the same shape.
- All scalar-per-particle arrays (rad, mass, ID, mat_id, species_id, fixed) have a shape consistent with pos.shape[:-1].
- Raises:
AssertionError – If any shape inconsistency is found.
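The checks listed above amount to a few shape comparisons. Below is a hypothetical restatement on raw arrays (`check_shapes` is not a jaxdem function). Like the documented property, it raises AssertionError on the first inconsistency it finds.

```python
import jax.numpy as jnp


def check_shapes(pos, vel, force, rad, mass, ID, mat_id, species_id, fixed):
    # Hypothetical restatement of the documented consistency checks.
    dim = pos.shape[-1]
    assert dim in (2, 3), f"dim must be 2 or 3, got {dim}"
    # Position-like arrays must match pos exactly.
    for name, arr in (("vel", vel), ("force", force)):
        assert arr.shape == pos.shape, f"{name} shape {arr.shape} != pos shape {pos.shape}"
    # Scalar-per-particle arrays must match pos without its last (dim) axis.
    for name, arr in (("rad", rad), ("mass", mass), ("ID", ID),
                      ("mat_id", mat_id), ("species_id", species_id), ("fixed", fixed)):
        assert arr.shape == pos.shape[:-1], f"{name} shape {arr.shape} != {pos.shape[:-1]}"
    return True


pos = jnp.zeros((3, 2))
per_particle = jnp.zeros(3)
print(check_shapes(pos, pos, pos, per_particle, per_particle, per_particle,
                   per_particle, per_particle, per_particle.astype(bool)))  # True
```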
- static create(pos: ArrayLike, *, vel: ArrayLike | None = None, force: ArrayLike | None = None, q: Quaternion | ArrayLike | None = None, angVel: ArrayLike | None = None, torque: ArrayLike | None = None, rad: ArrayLike | None = None, volume: ArrayLike | None = None, mass: ArrayLike | None = None, inertia: ArrayLike | None = None, ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None, mat_table: 'MaterialTable' | None = None) -> State
Factory method to create a new State instance.
This method handles default values and ensures consistent array shapes for all state attributes.
- Parameters:
pos (jax.typing.ArrayLike) – Initial positions of particles. Expected shape: (…, N, dim).
vel (jax.typing.ArrayLike or None, optional) – Initial velocities of particles. If None, defaults to zeros. Expected shape: (…, N, dim).
force (jax.typing.ArrayLike or None, optional) – Initial forces on particles. If None, defaults to zeros. Expected shape: (…, N, dim).
q (Quaternion or array-like, optional) – Initial particle orientations. If None, defaults to identity quaternions. Accepted shapes: quaternion objects or arrays of shape (…, N, 4) with components ordered as (w, x, y, z).
angVel (jax.typing.ArrayLike or None, optional) – Initial angular velocities of particles. If None, defaults to zeros. Expected shape: (…, N, 1) in 2D or (…, N, 3) in 3D.
torque (jax.typing.ArrayLike or None, optional) – Initial torques on particles. If None, defaults to zeros. Expected shape: (…, N, 1) in 2D or (…, N, 3) in 3D.
rad (jax.typing.ArrayLike or None, optional) – Radii of particles. If None, defaults to ones. Expected shape: (…, N).
volume (jax.typing.ArrayLike or None, optional) – Volume of particles (or area in 2D). If None, defaults to hypersphere volumes of the radii. Expected shape: (…, N).
mass (jax.typing.ArrayLike or None, optional) – Masses of particles. If None, defaults to ones. Ignored when mat_table is provided. Expected shape: (…, N).
inertia (jax.typing.ArrayLike or None, optional) – Moments of inertia in the principal axes frame. If None, defaults to solid disks (2D) or spheres (3D). Expected shape: (…, N, 1) in 2D or (…, N, 3) in 3D.
ID (jax.typing.ArrayLike or None, optional) – Unique identifiers for particles. If None, defaults to jnp.arange(). Expected shape: (…, N).
mat_id (jax.typing.ArrayLike or None, optional) – Material IDs for particles. If None, defaults to zeros. Expected shape: (…, N).
species_id (jax.typing.ArrayLike or None, optional) – Species IDs for particles. If None, defaults to zeros. Expected shape: (…, N).
fixed (jax.typing.ArrayLike or None, optional) – Boolean array indicating fixed particles. If None, defaults to all False. Expected shape: (…, N).
mat_table (MaterialTable or None, optional) – Optional material table providing per-material densities. When provided, the mass argument is ignored and particle masses are computed from density and particle volume.
- Returns:
A new State instance with all attributes correctly initialized and shaped.
- Return type:
State
- Raises:
ValueError – If the created State is not valid.
Example
Creating a 3D state for 5 particles:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> my_pos = jnp.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.], [1., 1., 1.]])
>>> my_rad = jnp.array([0.5, 0.5, 0.5, 0.5, 0.5])
>>> my_mass = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0])
>>>
>>> state_5_particles = jdem.State.create(pos=my_pos, rad=my_rad, mass=my_mass)
>>> print(f"Shape of positions: {state_5_particles.pos.shape}")
>>> print(f"Radii: {state_5_particles.rad}")
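The volume, inertia, and mat_table defaults described in the parameter list correspond to standard formulas. The sketch below assumes the textbook expressions: area πr² and volume 4/3 πr³, I = m r²/2 for a solid disk, I = 2 m r²/5 for a solid sphere, and mass = density × volume. It does not reproduce jaxdem's code. The `density` array indexed by mat_id is a hypothetical stand-in for a MaterialTable.

```python
import jax.numpy as jnp


def default_volume(rad, dim):
    # Hypersphere measure: area pi r^2 in 2D, volume 4/3 pi r^3 in 3D.
    return jnp.pi * rad**2 if dim == 2 else (4.0 / 3.0) * jnp.pi * rad**3


def default_inertia(mass, rad, dim):
    # Solid disk (2D): I = m r^2 / 2, shape (..., N, 1).
    # Solid sphere (3D): I = 2 m r^2 / 5 about each principal axis, shape (..., N, 3).
    if dim == 2:
        return (0.5 * mass * rad**2)[..., None]
    return jnp.repeat((0.4 * mass * rad**2)[..., None], 3, axis=-1)


# Hypothetical per-material densities indexed by mat_id (stand-in for a MaterialTable).
density = jnp.array([1.0, 2.5])
rad = jnp.array([0.5, 1.0])
mat_id = jnp.array([0, 1])
vol = default_volume(rad, dim=3)
mass = density[mat_id] * vol  # mass = density * volume
inertia = default_inertia(mass, rad, dim=3)
print(inertia.shape)  # (2, 3)
```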
- static merge(state1: State, state2: State) -> State
Merges two State instances into a single new State.
This method concatenates the particles from state2 onto state1. Particle IDs in state2 are shifted to ensure uniqueness in the merged state.
- Parameters:
state1 (State) – The first State instance. Its particles will appear first in the merged state.
state2 (State) – The second State instance. Its particles will be appended to the first.
- Returns:
A new State instance containing all particles from both input states.
- Return type:
State
- Raises:
AssertionError – If either input state is invalid, or if there is a mismatch in spatial dimension (dim) or batch size (batch_size).
ValueError – If the resulting merged state is invalid.
Example
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state_a = jdem.State.create(pos=jnp.array([[0.0, 0.0], [1.0, 1.0]]), ID=jnp.array([0, 1]))
>>> state_b = jdem.State.create(pos=jnp.array([[2.0, 2.0], [3.0, 3.0]]), ID=jnp.array([0, 1]))
>>> merged_state = jdem.State.merge(state_a, state_b)
>>>
>>> print(f"Merged state N: {merged_state.N}")  # Expected: 4
>>> print(f"Merged state positions:\n{merged_state.pos}")
>>> print(f"Merged state IDs: {merged_state.ID}")  # Expected: [0, 1, 2, 3]
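The merge described above concatenates along the particle axis and shifts the second state's IDs. The sketch below shows that idea on a dict of arrays. It assumes IDs in the second state are offset by (max ID of the first) + 1. That rule is consistent with the expected [0, 1, 2, 3] above, but jaxdem's actual shifting rule may differ. `merge_arrays` is a hypothetical name.

```python
import jax.numpy as jnp


def merge_arrays(a, b):
    # Hypothetical sketch on dicts of arrays: the particle axis is second to last
    # for vector properties (pos) and last for scalar properties (ID).
    # Assumption: b's IDs are shifted by max(a's IDs) + 1 so they stay unique.
    shift = jnp.max(a["ID"]) + 1
    return {
        "pos": jnp.concatenate([a["pos"], b["pos"]], axis=-2),
        "ID": jnp.concatenate([a["ID"], b["ID"] + shift], axis=-1),
    }


state_a = {"pos": jnp.array([[0.0, 0.0], [1.0, 1.0]]), "ID": jnp.array([0, 1])}
state_b = {"pos": jnp.array([[2.0, 2.0], [3.0, 3.0]]), "ID": jnp.array([0, 1])}
merged = merge_arrays(state_a, state_b)
print(merged["ID"])         # [0 1 2 3]
print(merged["pos"].shape)  # (4, 2)
```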
- static add(state: State, pos: ArrayLike, *, vel: ArrayLike | None = None, force: ArrayLike | None = None, q: Quaternion | ArrayLike | None = None, angVel: ArrayLike | None = None, torque: ArrayLike | None = None, rad: ArrayLike | None = None, volume: ArrayLike | None = None, mass: ArrayLike | None = None, inertia: ArrayLike | None = None, ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None, mat_table: 'MaterialTable' | None = None) -> State
Adds new particles to an existing State instance, returning a new State.
- Parameters:
state (State) – The existing State to which particles will be added.
pos (jax.typing.ArrayLike) – Positions of the new particle(s). Shape (…, N_new, dim).
vel (jax.typing.ArrayLike or None, optional) – Velocities of the new particle(s). Defaults to zeros.
force (jax.typing.ArrayLike or None, optional) – Forces of the new particle(s). Defaults to zeros.
q (Quaternion or array-like, optional) – Initial orientations of the new particle(s). Defaults to identity quaternions.
angVel (jax.typing.ArrayLike or None, optional) – Angular velocities of the new particle(s). Defaults to zeros.
torque (jax.typing.ArrayLike or None, optional) – Torques of the new particle(s). Defaults to zeros.
rad (jax.typing.ArrayLike or None, optional) – Radii of the new particle(s). Defaults to ones.
volume (jax.typing.ArrayLike or None, optional) – Volume of the new particle(s) (or area in 2D). Defaults to hypersphere volumes of the radii.
mass (jax.typing.ArrayLike or None, optional) – Masses of the new particle(s). Defaults to ones. Ignored when a mat_table is provided.
inertia (jax.typing.ArrayLike or None, optional) – Moments of inertia of the new particle(s). Defaults to solid disks (2D) or spheres (3D).
ID (jax.typing.ArrayLike or None, optional) – IDs of the new particle(s). If None, new IDs are generated.
mat_id (jax.typing.ArrayLike or None, optional) – Material IDs of the new particle(s). Defaults to zeros.
species_id (jax.typing.ArrayLike or None, optional) – Species IDs of the new particle(s). Defaults to zeros.
fixed (jax.typing.ArrayLike or None, optional) – Fixed status of the new particle(s). Defaults to all False.
mat_table (MaterialTable or None, optional) – Optional material table providing per-material densities. When provided, masses are computed from density and particle volume.
- Returns:
A new State instance containing all particles from the original state plus the newly added particles.
- Return type:
State
- Raises:
ValueError – If the created new particle state or the merged state is invalid.
AssertionError – If the batch size or spatial dimension of the new particles does not match the existing state.
Example
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Initial state with 4 particles
>>> state = jdem.State.create(pos=jnp.zeros((4, 2)))
>>> print(f"Original state N: {state.N}, IDs: {state.ID}")
>>>
>>> # Add a single new particle
>>> state_with_added_particle = jdem.State.add(
...     state,
...     pos=jnp.array([[10.0, 10.0]]),
...     rad=jnp.array([0.5]),
...     mass=jnp.array([2.0]),
... )
>>> print(f"New state N: {state_with_added_particle.N}, IDs: {state_with_added_particle.ID}")
>>> print(f"New particle position: {state_with_added_particle.pos[-1]}")
>>>
>>> # Add multiple new particles
>>> state_multiple_added = jdem.State.add(
...     state,
...     pos=jnp.array([[10.0, 10.0], [11.0, 11.0], [12.0, 12.0]]),
... )
>>> print(f"State with multiple added N: {state_multiple_added.N}, IDs: {state_multiple_added.ID}")
- static stack(states: Sequence[State]) -> State
Stacks a sequence of State snapshots along a new leading axis (axis 0) to form a trajectory or a batch.
Use it to collect simulation snapshots over time into a single State whose leading dimension represents time, or to prepare a batched state.
- Parameters:
states (Sequence[State]) – A sequence (e.g., list, tuple) of State instances to be stacked.
- Returns:
A new State instance where each attribute is a JAX array with an additional leading dimension representing the stacked trajectory. For example, if the input pos has shape (N, dim), the output pos has shape (T, N, dim).
- Return type:
State
- Raises:
ValueError – If the input sequence is empty or the stacked State is invalid.
AssertionError – If any input state is invalid, or if there is a mismatch in spatial dimension (dim), batch size (batch_size), or number of particles (N) between the states in the sequence.
Notes
No ID shifting is performed because the leading axis represents time (or another batch dimension), not new particles.
Example
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create a sequence of 3 simple 2D snapshots
>>> snapshot1 = jdem.State.create(pos=jnp.array([[0., 0.], [1., 1.]]), vel=jnp.array([[0.1, 0.], [0.0, 0.1]]))
>>> snapshot2 = jdem.State.create(pos=jnp.array([[0.1, 0.], [1., 1.1]]), vel=jnp.array([[0.1, 0.], [0.0, 0.1]]))
>>> snapshot3 = jdem.State.create(pos=jnp.array([[0.2, 0.], [1., 1.2]]), vel=jnp.array([[0.1, 0.], [0.0, 0.1]]))
>>>
>>> trajectory_state = jdem.State.stack([snapshot1, snapshot2, snapshot3])
>>>
>>> print(f"Trajectory positions shape: {trajectory_state.pos.shape}")  # Expected: (3, 2, 2)
>>> print(f"Positions at time step 0:\n{trajectory_state.pos[0]}")
>>> print(f"Positions at time step 1:\n{trajectory_state.pos[1]}")
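Stacking snapshots is a leaf-wise jnp.stack over matching pytrees. The sketch below shows the pattern with jax.tree_util.tree_map on a hypothetical `Snapshot` NamedTuple, not jaxdem's State, and is not necessarily how State.stack is implemented. As the Notes say, no IDs are shifted, because the new axis indexes time rather than additional particles.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Snapshot(NamedTuple):
    # Hypothetical State-like pytree used only for illustration.
    pos: jnp.ndarray
    vel: jnp.ndarray


def stack(snapshots):
    # Stack matching leaves along a new leading axis (time or batch).
    if len(snapshots) == 0:
        raise ValueError("cannot stack an empty sequence")
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *snapshots)


vel = jnp.array([[0.1, 0.0], [0.0, 0.1]])
snaps = [Snapshot(jnp.array([[0.1 * k, 0.0], [1.0, 1.0 + 0.1 * k]]), vel) for k in range(3)]
traj = stack(snaps)
print(traj.pos.shape)  # (3, 2, 2)
```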
- static add_clump(state: State, pos: ArrayLike, *, vel: ArrayLike | None = None, force: ArrayLike | None = None, q: Quaternion | ArrayLike | None = None, angVel: ArrayLike | None = None, torque: ArrayLike | None = None, rad: ArrayLike | None = None, volume: ArrayLike | None = None, mass: ArrayLike | None = None, inertia: ArrayLike | None = None, mat_id: int | None = None, species_id: int | None = None, fixed: int | None = None) -> State
Adds a new clump to an existing State instance, returning a new State.
- Parameters:
state (State) – The existing State to which particles will be added.
pos (jax.typing.ArrayLike) – Positions of the new particle(s). Shape (…, N_new, dim).
vel (jax.typing.ArrayLike or None, optional) – Velocities of the new particle(s). Defaults to zeros.
force (jax.typing.ArrayLike or None, optional) – Forces of the new particle(s). Defaults to zeros.
q (Quaternion or array-like, optional) – Initial orientations of the new particle(s). Defaults to identity quaternions.
angVel (jax.typing.ArrayLike or None, optional) – Angular velocities of the new particle(s). Defaults to zeros.
torque (jax.typing.ArrayLike or None, optional) – Torques of the new particle(s). Defaults to zeros.
rad (jax.typing.ArrayLike or None, optional) – Radii of the new particle(s). Defaults to ones.
volume (jax.typing.ArrayLike or None, optional) – Volume of the new particle(s) (or area in 2D). Defaults to hypersphere volumes of the radii.
mass (jax.typing.ArrayLike or None, optional) – Masses of the new particle(s). Defaults to ones.
inertia (jax.typing.ArrayLike or None, optional) – Moments of inertia of the new particle(s). Defaults to solid disks (2D) or spheres (3D).
mat_id (int or None, optional) – Material ID of the new clump. Defaults to 0.
species_id (int or None, optional) – Species ID of the new clump. Defaults to 0.
fixed (int or None, optional) – Fixed status of the new clump. Defaults to not fixed.
- Returns:
A new State instance containing all particles from the original state plus the newly added particles.
- Return type:
State
- Raises:
ValueError – If the created new particle state or the merged state is invalid.
AssertionError – If the batch size or spatial dimension of the new particles does not match the existing state.
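A clump shares one center of mass pos_c across its spheres, and each sphere keeps its own offset pos_p. The sketch below splits lab-frame sphere positions into those two parts, so that pos_c + pos_p recovers the input. It is a hypothetical illustration (`clump_frame` is not a jaxdem function). It assumes identity initial orientation, ignores sphere overlap in the mass-weighted average, and skips diagonalizing the clump's inertia tensor. As a result, pos_p is a true principal-frame offset only if the clump's principal axes happen to align with the lab axes.

```python
import jax.numpy as jnp


def clump_frame(sphere_pos, sphere_mass):
    # Hypothetical sketch: mass-weighted center of mass of one clump
    # (overlap between spheres is ignored).
    m = sphere_mass[:, None]
    pos_c = jnp.sum(m * sphere_pos, axis=0) / jnp.sum(sphere_mass)
    # With identity orientation, body-frame offsets equal lab-frame offsets.
    pos_p = sphere_pos - pos_c
    # Every sphere of the clump carries the same pos_c.
    return jnp.broadcast_to(pos_c, sphere_pos.shape), pos_p


spheres = jnp.array([[0.0, 0.0], [2.0, 0.0]])
pos_c, pos_p = clump_frame(spheres, jnp.array([1.0, 1.0]))
print(pos_c)  # every row equals [1. 0.]
print(pos_p)  # rows [-1. 0.] and [1. 0.]
```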