jaxdem.state
Defines the simulation State.
Classes
State – Represents the complete simulation state for a system of N particles in 2D or 3D.
- final class jaxdem.state.State(pos: Array, vel: Array, accel: Array, rad: Array, mass: Array, ID: Array, mat_id: Array, species_id: Array, fixed: Array)
Bases: object
Represents the complete simulation state for a system of N particles in 2D or 3D.
Notes
State is designed to support various data layouts:
- Single snapshot:
pos.shape = (N, dim) for particle properties (e.g., pos, vel, accel), and (N,) for scalar properties (e.g., rad, mass). In this case, batch_size is 1.
- Batched states:
pos.shape = (B, N, dim) for particle properties, and (B, N) for scalar properties. Here, B is the batch dimension (batch_size = pos.shape[0]).
- Trajectories of a single simulation:
pos.shape = (T, N, dim) for particle properties, and (T, N) for scalar properties. Here, T is the trajectory dimension.
- Trajectories of batched states:
pos.shape = (B, T_1, T_2, …, T_k, N, dim) for particle properties, and (B, T_1, T_2, …, T_k, N) for scalar properties.
The first dimension (i.e., pos.shape[0]) is always interpreted as the batch dimension (B).
- The remaining leading dimensions (T_1, T_2, …, T_k) are interpreted as trajectory dimensions. If there is more than one trajectory dimension, they are flattened into a single dimension at save time.
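The layout rules above can be sketched as a small helper that works on a shape tuple alone. This is only a reading of the notes, not jaxdem's own code: the function name describe_layout and the returned keys are hypothetical. Note that a (T, N, dim) array cannot be told apart from a (B, N, dim) array by shape, so the sketch reports the first leading axis as the batch axis in both cases.

```python
from math import prod


def describe_layout(pos_shape):
    """Split the shape of a position array into the roles described in the notes.

    Hypothetical helper for illustration; it is not part of jaxdem.
    """
    if len(pos_shape) < 2:
        raise ValueError("pos needs at least the (N, dim) axes")
    *leading, n, dim = pos_shape
    if not leading:
        # Single snapshot: no leading axes, so the batch size is 1.
        batch_size, traj_dims = 1, ()
    else:
        # The first leading axis is the batch; the rest are trajectory axes.
        batch_size, traj_dims = leading[0], tuple(leading[1:])
    # More than one trajectory axis is collapsed into one axis at save time.
    saved_traj = (prod(traj_dims),) if len(traj_dims) > 1 else traj_dims
    saved_shape = tuple(leading[:1]) + saved_traj + (n, dim)
    return {
        "batch_size": batch_size,
        "traj_dims": traj_dims,
        "N": n,
        "dim": dim,
        "saved_shape": saved_shape,
    }


print(describe_layout((4, 2)))            # single snapshot
print(describe_layout((10, 4, 2)))        # batched states
print(describe_layout((10, 5, 3, 4, 2)))  # batch with two trajectory axes
```
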
The class is final and cannot be subclassed.
Example
Creating a simple 2D state for 4 particles:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
>>> state = jdem.State.create(pos=positions)
>>>
>>> print(f"Number of particles (N): {state.N}")
>>> print(f"Spatial dimension (dim): {state.dim}")
>>> print(f"Positions:\n{state.pos}")
Creating a batched state:
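>>> import jax
>>> import jax.numpy as jnp
>>> import jaxdem as jdem
>>>
>>> positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
>>> batched_state = jax.vmap(lambda _: jdem.State.create(pos=positions))(jnp.arange(10))
>>>
>>> print(f"Batch size: {batched_state.batch_size}")  # 10
>>> print(f"Positions shape: {batched_state.pos.shape}")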
- pos: Array
Array of particle positions. Shape is (…, N, dim).
- vel: Array
Array of particle velocities. Shape is (…, N, dim).
- accel: Array
Array of particle accelerations. Shape is (…, N, dim).
- rad: Array
Array of particle radii. Shape is (…, N).
- mass: Array
Array of particle masses. Shape is (…, N).
- ID: Array
Array of unique particle identifiers. Shape is (…, N).
- mat_id: Array
Array of material IDs for each particle. Shape is (…, N).
- species_id: Array
Array of species IDs for each particle. Shape is (…, N).
- fixed: Array
Boolean array indicating if a particle is fixed (immobile). Shape is (…, N).
- property is_valid: bool
Check if the internal representation of the State is consistent.
Verifies that:
The spatial dimension (dim) is either 2 or 3.
All position-like arrays (pos, vel, accel) have the same shape.
All scalar-per-particle arrays (rad, mass, ID, mat_id, species_id, fixed) have a shape consistent with pos.shape[:-1].
- Raises:
AssertionError – If any shape inconsistency is found.
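The three checks listed for is_valid can be sketched on plain shape tuples. This is an illustrative stand-in, not the library's implementation: the function shapes_are_consistent is hypothetical, and it returns a bool instead of raising AssertionError as the property is documented to do.

```python
def shapes_are_consistent(vector_shapes, scalar_shapes):
    """Apply the three documented consistency checks to shape tuples.

    vector_shapes maps "pos", "vel", "accel" to their shapes;
    scalar_shapes maps the per-particle scalar fields to theirs.
    Hypothetical helper for illustration; it is not part of jaxdem.
    """
    pos_shape = vector_shapes["pos"]
    # 1. The spatial dimension (last axis of pos) must be 2 or 3.
    if len(pos_shape) < 2 or pos_shape[-1] not in (2, 3):
        return False
    # 2. pos, vel and accel must all share the same shape.
    if any(shape != pos_shape for shape in vector_shapes.values()):
        return False
    # 3. Scalar-per-particle arrays must match pos.shape[:-1].
    return all(shape == pos_shape[:-1] for shape in scalar_shapes.values())


vec = {name: (4, 2) for name in ("pos", "vel", "accel")}
sca = {name: (4,) for name in ("rad", "mass", "ID", "mat_id", "species_id", "fixed")}
print(shapes_are_consistent(vec, sca))

# A mass array with the wrong number of particles fails the third check.
print(shapes_are_consistent(vec, {**sca, "mass": (3,)}))
```
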
- static create(pos: ArrayLike, *, vel: ArrayLike | None = None, accel: ArrayLike | None = None, rad: ArrayLike | None = None, mass: ArrayLike | None = None, ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None) -> State
Factory method to create a new State instance. This method handles default values and ensures consistent array shapes for all state attributes.
- Parameters:
pos (jax.typing.ArrayLike) – Initial positions of particles. Expected shape: (…, N, dim).
vel (jax.typing.ArrayLike or None, optional) – Initial velocities of particles. If None, defaults to zeros. Expected shape: (…, N, dim).
accel (jax.typing.ArrayLike or None, optional) – Initial accelerations of particles. If None, defaults to zeros. Expected shape: (…, N, dim).
rad (jax.typing.ArrayLike or None, optional) – Radii of particles. If None, defaults to ones. Expected shape: (…, N).
mass (jax.typing.ArrayLike or None, optional) – Masses of particles. If None, defaults to ones. Expected shape: (…, N).
ID (jax.typing.ArrayLike or None, optional) – Unique identifiers for particles. If None, defaults to jnp.arange(N). Expected shape: (…, N).
mat_id (jax.typing.ArrayLike or None, optional) – Material IDs for particles. If None, defaults to zeros. Expected shape: (…, N).
species_id (jax.typing.ArrayLike or None, optional) – Species IDs for particles. If None, defaults to zeros. Expected shape: (…, N).
fixed (jax.typing.ArrayLike or None, optional) – Boolean array indicating fixed particles. If None, defaults to all False. Expected shape: (…, N).
- Returns:
A new State instance with all attributes correctly initialized and shaped.
- Return type:
State
- Raises:
ValueError – If the created State is not valid.
Example
Creating a 3D state for 5 particles:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> my_pos = jnp.array([[0.,0.,0.], [1.,0.,0.], [0.,1.,0.], [0.,0.,1.], [1.,1.,1.]])
>>> my_rad = jnp.array([0.5, 0.5, 0.5, 0.5, 0.5])
>>> my_mass = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0])
>>>
>>> state_5_particles = jdem.State.create(pos=my_pos, rad=my_rad, mass=my_mass)
>>> print(f"Shape of positions: {state_5_particles.pos.shape}")
>>> print(f"Radii: {state_5_particles.rad}")
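The documented defaults for create can be summarised in a short sketch. It uses NumPy as a stand-in for jax.numpy and a plain dict instead of a State, and the helper fill_defaults is hypothetical. How the default ID is shaped for batched input is an assumption here (jnp.arange(N) broadcast to pos.shape[:-1]); the library may handle it differently.

```python
import numpy as np


def fill_defaults(pos, **given):
    """Return a dict of fields, filling the documented defaults for omitted ones.

    Hypothetical helper for illustration; it is not part of jaxdem.
    """
    pos = np.asarray(pos, dtype=float)
    n = pos.shape[-2]              # number of particles
    scalar_shape = pos.shape[:-1]  # shape of per-particle scalar fields
    defaults = {
        "vel": np.zeros_like(pos),
        "accel": np.zeros_like(pos),
        "rad": np.ones(scalar_shape),
        "mass": np.ones(scalar_shape),
        # Assumption: arange(N) is repeated across any leading axes.
        "ID": np.broadcast_to(np.arange(n), scalar_shape),
        "mat_id": np.zeros(scalar_shape, dtype=int),
        "species_id": np.zeros(scalar_shape, dtype=int),
        "fixed": np.zeros(scalar_shape, dtype=bool),
    }
    unknown = set(given) - set(defaults)
    if unknown:
        raise TypeError(f"unexpected fields: {sorted(unknown)}")
    fields = {"pos": pos}
    for name, default in defaults.items():
        value = given.get(name)
        # None means "not provided", matching the documented signature.
        fields[name] = default if value is None else np.asarray(value)
    return fields


fields = fill_defaults(np.zeros((5, 3)), rad=np.full(5, 0.5))
for name, value in fields.items():
    print(name, value.shape, value.dtype)
```
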
- static merge(state1: State, state2: State) -> State
Merges two State instances into a single new State. This method concatenates the particles from state2 onto state1. Particle IDs in state2 are shifted to ensure uniqueness in the merged state.
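- Parameters:
state1 (State) – The first State instance.
state2 (State) – The second State instance, whose particles are appended to those of state1.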
- Returns:
A new State instance containing all particles from both input states.
- Return type:
State
- Raises:
AssertionError – If either input state is invalid, or if there is a mismatch in spatial dimension (dim) or batch size (batch_size).
ValueError – If the resulting merged state is invalid.
Example
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state_a = jdem.State.create(pos=jnp.array([[0.,0.], [1.,1.]]), ID=jnp.array([0, 1]))
>>> state_b = jdem.State.create(pos=jnp.array([[2.,2.], [3.,3.]]), ID=jnp.array([0, 1]))
>>> merged_state = jdem.State.merge(state_a, state_b)
>>>
>>> print(f"Merged state N: {merged_state.N}")  # Expected: 4
>>> print(f"Merged state positions:\n{merged_state.pos}")
>>> print(f"Merged state IDs: {merged_state.ID}")  # Expected: [0, 1, 2, 3]
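The ID shifting described for merge can be sketched with NumPy standing in for jax.numpy. The helper merge_fields is hypothetical, and the offset rule (largest ID in the first state plus one) is an assumption that happens to reproduce the expected IDs in the example above; the library's actual rule is not stated here.

```python
import numpy as np


def merge_fields(pos1, id1, pos2, id2):
    """Concatenate two particle sets and keep their IDs unique.

    Hypothetical helper for illustration; it is not part of jaxdem.
    """
    pos1, pos2 = np.asarray(pos1), np.asarray(pos2)
    id1, id2 = np.asarray(id1), np.asarray(id2)
    # Particles live on the second-to-last axis of pos and the last axis of ID.
    pos = np.concatenate([pos1, pos2], axis=-2)
    # Assumed rule: shift the second set past the largest ID of the first.
    offset = id1.max() + 1
    ids = np.concatenate([id1, id2 + offset], axis=-1)
    return pos, ids


merged_pos, merged_ids = merge_fields(
    [[0.0, 0.0], [1.0, 1.0]], [0, 1],
    [[2.0, 2.0], [3.0, 3.0]], [0, 1],
)
print(merged_pos.shape)
print(merged_ids)
```
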
- static add(state: State, pos: ArrayLike, *, vel: ArrayLike | None = None, accel: ArrayLike | None = None, rad: ArrayLike | None = None, mass: ArrayLike | None = None, ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None) -> State
Adds new particles to an existing State instance, returning a new State.
- Parameters:
state (State) – The existing State to which particles will be added.
pos (jax.typing.ArrayLike) – Positions of the new particle(s). Shape (…, N_new, dim).
vel (jax.typing.ArrayLike or None, optional) – Velocities of the new particle(s). Defaults to zeros.
accel (jax.typing.ArrayLike or None, optional) – Accelerations of the new particle(s). Defaults to zeros.
rad (jax.typing.ArrayLike or None, optional) – Radii of the new particle(s). Defaults to ones.
mass (jax.typing.ArrayLike or None, optional) – Masses of the new particle(s). Defaults to ones.
ID (jax.typing.ArrayLike or None, optional) – IDs of the new particle(s). If None, new IDs are generated.
mat_id (jax.typing.ArrayLike or None, optional) – Material IDs of the new particle(s). Defaults to zeros.
species_id (jax.typing.ArrayLike or None, optional) – Species IDs of the new particle(s). Defaults to zeros.
fixed (jax.typing.ArrayLike or None, optional) – Fixed status of the new particle(s). Defaults to all False.
- Returns:
A new State instance containing all particles from the original state plus the newly added particles.
- Return type:
State
- Raises:
ValueError – If the created new particle state or the merged state is invalid.
AssertionError – If the batch size or spatial dimension of the new particles does not match the existing state.
Example
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Initial state with 4 particles
>>> state = jdem.State.create(jnp.zeros((4, 2)))
>>> print(f"Original state N: {state.N}, IDs: {state.ID}")
>>>
>>> # Add a single new particle
>>> state_with_added_particle = jdem.State.add(state,
...     pos=jnp.array([[10., 10.]]),
...     rad=jnp.array([0.5]),
...     mass=jnp.array([2.0]))
>>> print(f"New state N: {state_with_added_particle.N}, IDs: {state_with_added_particle.ID}")
>>> print(f"New particle position: {state_with_added_particle.pos[-1]}")
Adding multiple new particles:
>>> state_multiple_added = jdem.State.add(state,
...     pos=jnp.array([[10., 10.], [11., 11.], [12., 12.]]))
>>> print(f"State with multiple added N: {state_multiple_added.N}, IDs: {state_multiple_added.ID}")
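How new IDs might be generated when ID is None can be sketched as follows, again with NumPy standing in for jax.numpy and a dict with only pos and ID standing in for a State. The helper add_particles is hypothetical and handles unbatched input only; continuing the numbering after the largest existing ID is an assumption, not a documented rule.

```python
import numpy as np


def add_particles(state, pos):
    """Return a new dict with the extra particles appended; the input is untouched.

    Hypothetical helper for illustration; it is not part of jaxdem.
    """
    pos = np.atleast_2d(np.asarray(pos, dtype=float))
    n_new = pos.shape[-2]
    # Assumed rule: new IDs continue after the largest existing ID.
    next_id = state["ID"].max() + 1
    new_ids = np.arange(next_id, next_id + n_new)
    return {
        "pos": np.concatenate([state["pos"], pos], axis=-2),
        "ID": np.concatenate([state["ID"], new_ids], axis=-1),
    }


state = {"pos": np.zeros((4, 2)), "ID": np.arange(4)}
one_added = add_particles(state, [[10.0, 10.0]])
three_added = add_particles(state, [[10.0, 10.0], [11.0, 11.0], [12.0, 12.0]])
print(one_added["ID"])
print(three_added["ID"])
```

Returning a new dict rather than modifying the input mirrors the documented behaviour of add, which returns a new State.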
- static stack(states: Sequence[State]) -> State
Stacks a sequence of State snapshots into a trajectory along a new leading axis (axis 0). This method is useful for collecting simulation snapshots over time into a single State object where the leading dimension represents time.
- Parameters:
states (Sequence[State]) – A sequence (e.g., list, tuple) of State instances to be stacked.
- Returns:
A new State instance where each attribute is a JAX array with an additional leading dimension representing the stacked trajectory. For example, if input pos was (N, dim), output pos will be (T, N, dim).
- Return type:
State
- Raises:
ValueError – If the input states sequence is empty, or if the stacked State is invalid.
AssertionError – If any input state is invalid, or if there is a mismatch in spatial dimension (dim), batch size (batch_size), or number of particles (N) between the states in the sequence.
Notes
No ID shifting is performed because the leading axis represents time (or another batch dimension), not new particles.
Example
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create a sequence of 3 simple 2D snapshots
>>> snapshot1 = jdem.State.create(pos=jnp.array([[0.,0.], [1.,1.]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot2 = jdem.State.create(pos=jnp.array([[0.1,0.], [1.,1.1]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot3 = jdem.State.create(pos=jnp.array([[0.2,0.], [1.,1.2]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>>
>>> trajectory_state = jdem.State.stack([snapshot1, snapshot2, snapshot3])
>>>
>>> print(f"Trajectory positions shape: {trajectory_state.pos.shape}")  # Expected: (3, 2, 2)
>>> print(f"Positions at time step 0:\n{trajectory_state.pos[0]}")
>>> print(f"Positions at time step 1:\n{trajectory_state.pos[1]}")
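The difference between stacking and merging can be sketched with NumPy standing in for jax.numpy and dicts standing in for State objects. The helper stack_snapshots is hypothetical. Every field gains a new leading axis and IDs are copied unchanged, as the note on ID shifting describes.

```python
import numpy as np


def stack_snapshots(snapshots):
    """Stack each field of the snapshots along a new leading axis.

    Hypothetical helper for illustration; it is not part of jaxdem.
    """
    if not snapshots:
        raise ValueError("need at least one snapshot")
    names = snapshots[0].keys()
    # np.stack raises ValueError if the per-snapshot shapes differ.
    return {
        name: np.stack([snap[name] for snap in snapshots], axis=0)
        for name in names
    }


snapshots = [
    {"pos": np.array([[0.0, 0.0], [1.0, 1.0]]), "ID": np.arange(2)},
    {"pos": np.array([[0.1, 0.0], [1.0, 1.1]]), "ID": np.arange(2)},
    {"pos": np.array([[0.2, 0.0], [1.0, 1.2]]), "ID": np.arange(2)},
]
trajectory = stack_snapshots(snapshots)
print(trajectory["pos"].shape)
print(trajectory["ID"])
```
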