jaxdem.state

Defines the simulation State.

Classes

State(pos, vel, accel, rad, mass, ID, ...)

Represents the complete simulation state for a system of N particles in 2D or 3D.

final class jaxdem.state.State(pos: Array, vel: Array, accel: Array, rad: Array, mass: Array, ID: Array, mat_id: Array, species_id: Array, fixed: Array)

Bases: object

Represents the complete simulation state for a system of N particles in 2D or 3D.

Notes

State is designed to support various data layouts:

  • Single snapshot:

    pos.shape = (N, dim) for particle properties (e.g., pos, vel, accel), and (N,) for scalar properties (e.g., rad, mass). In this case, batch_size is 1.

  • Batched states:

    pos.shape = (B, N, dim) for particle properties, and (B, N) for scalar properties. Here, B is the batch dimension (batch_size = pos.shape[0]).

  • Trajectories of a single simulation:

    pos.shape = (T, N, dim) for particle properties, and (T, N) for scalar properties. Here, T is the trajectory dimension.

  • Trajectories of batched states:

    pos.shape = (B, T_1, T_2, …, T_k, N, dim) for particle properties, and (B, T_1, T_2, …, T_k, N) for scalar properties.

    • The first dimension (i.e., pos.shape[0]) is always interpreted as the batch dimension (B).

    • All remaining leading dimensions (T_1, T_2, …, T_k) are interpreted as trajectory dimensions; if there is more than one trajectory dimension, they are flattened at save time.

The class is final and cannot be subclassed.
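The trajectory-flattening rule above can be illustrated with plain NumPy arrays standing in for JAX arrays. The helper flatten_trajectory_dims is hypothetical and not part of jaxdem; the notes do not say how the flattening is implemented, so this sketch assumes a row-major reshape that keeps the batch axis and the per-particle axes and merges everything in between.

```python
import numpy as np


def flatten_trajectory_dims(x: np.ndarray, per_particle_ndim: int) -> np.ndarray:
    """Collapse trajectory axes T_1, ..., T_k into a single axis.

    Hypothetical helper (not part of jaxdem) illustrating the layout rule
    (B, T_1, ..., T_k, N, dim) -> (B, T_1 * ... * T_k, N, dim).
    per_particle_ndim is 2 for (N, dim) fields and 1 for (N,) fields.
    """
    n_leading = x.ndim - per_particle_ndim  # batch axis + trajectory axes
    if n_leading <= 2:
        # At most one trajectory axis: nothing to flatten.
        return x
    # Keep the batch axis and the per-particle axes; merge everything between.
    return x.reshape((x.shape[0], -1) + x.shape[n_leading:])


pos = np.zeros((2, 3, 5, 4, 2))  # B=2, T_1=3, T_2=5, N=4, dim=2
rad = np.zeros((2, 3, 5, 4))     # scalar property with the same leading axes
print(flatten_trajectory_dims(pos, 2).shape)  # (2, 15, 4, 2)
print(flatten_trajectory_dims(rad, 1).shape)  # (2, 15, 4)
```

jax.numpy.reshape accepts the same arguments, so the sketch carries over unchanged to JAX arrays.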

Example

Creating a simple 2D state for 4 particles:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
>>> state = jdem.State.create(pos=positions)
>>>
>>> print(f"Number of particles (N): {state.N}")
>>> print(f"Spatial dimension (dim): {state.dim}")
>>> print(f"Positions:\n{state.pos}")

Creating a batched state:

>>> import jax
>>> batched_state = jax.vmap(lambda _: jdem.State.create(pos=positions))(jnp.arange(10))
>>>
>>> print(f"Batch size: {batched_state.batch_size}") # 10
>>> print(f"Positions shape: {batched_state.pos.shape}")

pos: Array

Array of particle positions. Shape is (…, N, dim).

vel: Array

Array of particle velocities. Shape is (…, N, dim).

accel: Array

Array of particle accelerations. Shape is (…, N, dim).

rad: Array

Array of particle radii. Shape is (…, N).

mass: Array

Array of particle masses. Shape is (…, N).

ID: Array

Array of unique particle identifiers. Shape is (…, N).

mat_id: Array

Array of material IDs for each particle. Shape is (…, N).

species_id: Array

Array of species IDs for each particle. Shape is (…, N).

fixed: Array

Boolean array indicating whether each particle is fixed (immobile). Shape is (…, N).

property N: int

Number of particles in the state.

property dim: int

Spatial dimension of the simulation.

property batch_size: int

Batch size of the state: 1 for a single snapshot of shape (N, dim), otherwise pos.shape[0].
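All three properties can be read off the shape of pos. A minimal sketch, assuming only the layout rules from the class notes; layout_properties is a hypothetical helper that takes a NumPy array rather than a State.

```python
import numpy as np


def layout_properties(pos: np.ndarray) -> tuple:
    """Return (N, dim, batch_size) for a pos array of shape (..., N, dim).

    Hypothetical helper following the documented layout rules: a single
    snapshot (N, dim) has batch_size 1; otherwise the batch size is the
    size of the first axis.
    """
    n, dim = pos.shape[-2], pos.shape[-1]
    batch_size = 1 if pos.ndim == 2 else pos.shape[0]
    return n, dim, batch_size


print(layout_properties(np.zeros((4, 2))))      # (4, 2, 1)
print(layout_properties(np.zeros((10, 4, 3))))  # (4, 3, 10)
```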

property is_valid: bool

Check if the internal representation of the State is consistent.

Verifies that:

  • The spatial dimension (dim) is either 2 or 3.

  • All position-like arrays (pos, vel, accel) have the same shape.

  • All scalar-per-particle arrays (rad, mass, ID, mat_id, species_id, fixed) have a shape consistent with pos.shape[:-1].

Raises:

AssertionError – If any shape inconsistency is found.
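The three checks listed above can be mirrored on a plain dict of arrays. This is a sketch under stated assumptions: check_state_shapes is hypothetical, it works on NumPy arrays keyed by the State field names instead of State attributes, and, like the documented property, it raises AssertionError on the first mismatch.

```python
import numpy as np

VECTOR_FIELDS = ("pos", "vel", "accel")
SCALAR_FIELDS = ("rad", "mass", "ID", "mat_id", "species_id", "fixed")


def check_state_shapes(fields: dict) -> bool:
    """Hypothetical stand-in for State.is_valid on a dict of arrays."""
    pos = fields["pos"]
    # The spatial dimension is the last axis of pos.
    assert pos.shape[-1] in (2, 3), "dim must be 2 or 3"
    # Position-like arrays must match pos exactly.
    for name in VECTOR_FIELDS:
        assert fields[name].shape == pos.shape, f"{name} shape mismatch"
    # Per-particle scalars must match pos without its last axis.
    for name in SCALAR_FIELDS:
        assert fields[name].shape == pos.shape[:-1], f"{name} shape mismatch"
    return True


n = 4
good = {name: np.zeros((n, 2)) for name in VECTOR_FIELDS}
good.update({name: np.zeros(n) for name in SCALAR_FIELDS})
print(check_state_shapes(good))  # True
```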

static create(pos: ArrayLike, *, vel: ArrayLike | None = None, accel: ArrayLike | None = None, rad: ArrayLike | None = None, mass: ArrayLike | None = None, ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None) → State

Factory method to create a new State instance.

This method handles default values and ensures consistent array shapes for all state attributes.

Parameters:
  • pos (jax.typing.ArrayLike) – Initial positions of particles. Expected shape: (…, N, dim).

  • vel (jax.typing.ArrayLike or None, optional) – Initial velocities of particles. If None, defaults to zeros. Expected shape: (…, N, dim).

  • accel (jax.typing.ArrayLike or None, optional) – Initial accelerations of particles. If None, defaults to zeros. Expected shape: (…, N, dim).

  • rad (jax.typing.ArrayLike or None, optional) – Radii of particles. If None, defaults to ones. Expected shape: (…, N).

  • mass (jax.typing.ArrayLike or None, optional) – Masses of particles. If None, defaults to ones. Expected shape: (…, N).

  • ID (jax.typing.ArrayLike or None, optional) – Unique identifiers for particles. If None, defaults to jnp.arange(N). Expected shape: (…, N).

  • mat_id (jax.typing.ArrayLike or None, optional) – Material IDs for particles. If None, defaults to zeros. Expected shape: (…, N).

  • species_id (jax.typing.ArrayLike or None, optional) – Species IDs for particles. If None, defaults to zeros. Expected shape: (…, N).

  • fixed (jax.typing.ArrayLike or None, optional) – Boolean array indicating fixed particles. If None, defaults to all False. Expected shape: (…, N).

Returns:

A new State instance with all attributes correctly initialized and shaped.

Return type:

State

Raises:

ValueError – If the created State is not valid.

Example

Creating a 3D state for 5 particles:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> my_pos = jnp.array([[0.,0.,0.], [1.,0.,0.], [0.,1.,0.], [0.,0.,1.], [1.,1.,1.]])
>>> my_rad = jnp.array([0.5, 0.5, 0.5, 0.5, 0.5])
>>> my_mass = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0])
>>>
>>> state_5_particles = jdem.State.create(pos=my_pos, rad=my_rad, mass=my_mass)
>>> print(f"Shape of positions: {state_5_particles.pos.shape}")
>>> print(f"Radii: {state_5_particles.rad}")
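The defaults listed for create can be sketched with NumPy. fill_defaults is a hypothetical helper that returns a dict of arrays rather than a State. The integer dtypes for mat_id and species_id, and broadcasting arange(N) over any leading axes for ID, are assumptions; the parameter list only states the default values.

```python
import numpy as np


def fill_defaults(pos, vel=None, accel=None, rad=None, mass=None,
                  ID=None, mat_id=None, species_id=None, fixed=None):
    """Apply the documented State.create defaults to plain NumPy arrays."""
    pos = np.asarray(pos, dtype=float)
    vec_shape, scalar_shape = pos.shape, pos.shape[:-1]
    n = pos.shape[-2]

    def default(value, fill):
        # Use the caller's value when given, else the documented default.
        return np.asarray(value) if value is not None else fill

    return {
        "pos": pos,
        "vel": default(vel, np.zeros(vec_shape)),
        "accel": default(accel, np.zeros(vec_shape)),
        "rad": default(rad, np.ones(scalar_shape)),
        "mass": default(mass, np.ones(scalar_shape)),
        # Assumption: arange(N) is broadcast over any leading (batch) axes.
        "ID": default(ID, np.broadcast_to(np.arange(n), scalar_shape)),
        # Assumption: integer dtype for the zero-filled ID fields.
        "mat_id": default(mat_id, np.zeros(scalar_shape, dtype=int)),
        "species_id": default(species_id, np.zeros(scalar_shape, dtype=int)),
        "fixed": default(fixed, np.zeros(scalar_shape, dtype=bool)),
    }


fields = fill_defaults(np.zeros((5, 3)), rad=np.full(5, 0.5))
print(fields["ID"])     # [0 1 2 3 4]
print(fields["fixed"])  # [False False False False False]
```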

static merge(state1: State, state2: State) → State

Merges two State instances into a single new State.

This method appends the particles of state2 after those of state1. Particle IDs in state2 are shifted so that every ID in the merged state is unique.

Parameters:
  • state1 (State) – The first State instance. Its particles will appear first in the merged state.

  • state2 (State) – The second State instance. Its particles will be appended to the first.

Returns:

A new State instance containing all particles from both input states.

Return type:

State

Raises:
  • AssertionError – If either input state is invalid, or if there is a mismatch in spatial dimension (dim) or batch size (batch_size).

  • ValueError – If the resulting merged state is invalid.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state_a = jdem.State.create(pos=jnp.array([[0.,0.], [1.,1.]]), ID=jnp.array([0, 1]))
>>> state_b = jdem.State.create(pos=jnp.array([[2.,2.], [3.,3.]]), ID=jnp.array([0, 1]))
>>> merged_state = jdem.State.merge(state_a, state_b)
>>>
>>> print(f"Merged state N: {merged_state.N}") # Expected: 4
>>> print(f"Merged state positions:\n{merged_state.pos}")
>>> print(f"Merged state IDs: {merged_state.ID}") # Expected: [0, 1, 2, 3]
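The concatenation behind merge can be sketched on dicts of NumPy arrays: vector fields of shape (…, N, dim) join on axis -2, per-particle scalars of shape (…, N) on axis -1. merge_fields is hypothetical, and the ID shift it uses (offsetting by the largest ID in the first state plus one) is an assumption; the text above only says the IDs are shifted to stay unique.

```python
import numpy as np

VECTOR_FIELDS = ("pos", "vel", "accel")


def merge_fields(a: dict, b: dict) -> dict:
    """Hypothetical sketch: append the particles of b after those of a."""
    merged = {}
    for name, x in a.items():
        y = b[name]
        if name == "ID":
            # Assumed shift rule: start b's IDs just above a's largest ID.
            y = y + x.max() + 1
        # Particle axis: -2 for (..., N, dim) fields, -1 for (..., N) fields.
        axis = -2 if name in VECTOR_FIELDS else -1
        merged[name] = np.concatenate([x, y], axis=axis)
    return merged


a = {"pos": np.array([[0., 0.], [1., 1.]]), "ID": np.array([0, 1])}
b = {"pos": np.array([[2., 2.], [3., 3.]]), "ID": np.array([0, 1])}
m = merge_fields(a, b)
print(m["pos"].shape)  # (4, 2)
print(m["ID"])         # [0 1 2 3]
```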

static add(state: State, pos: ArrayLike, *, vel: ArrayLike | None = None, accel: ArrayLike | None = None, rad: ArrayLike | None = None, mass: ArrayLike | None = None, ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None) → State

Adds new particles to an existing State instance, returning a new State.

Parameters:
  • state (State) – The existing State to which particles will be added.

  • pos (jax.typing.ArrayLike) – Positions of the new particle(s). Shape (…, N_new, dim).

  • vel (jax.typing.ArrayLike or None, optional) – Velocities of the new particle(s). Defaults to zeros.

  • accel (jax.typing.ArrayLike or None, optional) – Accelerations of the new particle(s). Defaults to zeros.

  • rad (jax.typing.ArrayLike or None, optional) – Radii of the new particle(s). Defaults to ones.

  • mass (jax.typing.ArrayLike or None, optional) – Masses of the new particle(s). Defaults to ones.

  • ID (jax.typing.ArrayLike or None, optional) – IDs of the new particle(s). If None, new IDs are generated.

  • mat_id (jax.typing.ArrayLike or None, optional) – Material IDs of the new particle(s). Defaults to zeros.

  • species_id (jax.typing.ArrayLike or None, optional) – Species IDs of the new particle(s). Defaults to zeros.

  • fixed (jax.typing.ArrayLike or None, optional) – Fixed status of the new particle(s). Defaults to all False.

Returns:

A new State instance containing all particles from the original state plus the newly added particles.

Return type:

State

Raises:
  • ValueError – If the created new particle state or the merged state is invalid.

  • AssertionError – If the batch size or spatial dimension of the new particles does not match that of the existing state.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Initial state with 4 particles
>>> state = jdem.State.create(jnp.zeros((4, 2)))
>>> print(f"Original state N: {state.N}, IDs: {state.ID}")
>>>
>>> # Add a single new particle
>>> state_with_added_particle = jdem.State.add(state,
...                                            pos=jnp.array([[10., 10.]]),
...                                            rad=jnp.array([0.5]),
...                                            mass=jnp.array([2.0]))
>>> print(f"New state N: {state_with_added_particle.N}, IDs: {state_with_added_particle.ID}")
>>> print(f"New particle position: {state_with_added_particle.pos[-1]}")

Adding multiple new particles:

>>> state_multiple_added = jdem.State.add(state,
...                                       pos=jnp.array([[10., 10.], [11., 11.], [12., 12.]]))
>>> print(f"State with multiple added N: {state_multiple_added.N}, IDs: {state_multiple_added.ID}")

static stack(states: Sequence[State]) → State

Stacks a sequence of State snapshots into a trajectory along a new leading axis (axis 0).

This method is useful for collecting simulation snapshots over time into a single State object where the leading dimension represents time.

Parameters:

states (Sequence[State]) – A sequence (e.g., list, tuple) of State instances to be stacked.

Returns:

A new State instance where each attribute is a JAX array with an additional leading dimension representing the stacked trajectory. For example, if input pos was (N, dim), output pos will be (T, N, dim).

Return type:

State

Raises:
  • ValueError – If the input sequence of states is empty, or if the stacked State is invalid.

  • AssertionError – If any input state is invalid, or if there is a mismatch in spatial dimension (dim), batch size (batch_size), or number of particles (N) between the states in the sequence.

Notes

  • No ID shifting is performed because the leading axis represents time (or another batch dimension), not new particles.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create a sequence of 3 simple 2D snapshots
>>> snapshot1 = jdem.State.create(pos=jnp.array([[0.,0.], [1.,1.]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot2 = jdem.State.create(pos=jnp.array([[0.1,0.], [1.,1.1]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot3 = jdem.State.create(pos=jnp.array([[0.2,0.], [1.,1.2]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>>
>>> trajectory_state = jdem.State.stack([snapshot1, snapshot2, snapshot3])
>>>
>>> print(f"Trajectory positions shape: {trajectory_state.pos.shape}") # Expected: (3, 2, 2)
>>> print(f"Positions at time step 0:\n{trajectory_state.pos[0]}")
>>> print(f"Positions at time step 1:\n{trajectory_state.pos[1]}")
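Stacking differs from merging: every field gains a new leading axis and no IDs change. A minimal sketch on dicts of NumPy arrays, where stack_fields is a hypothetical helper and not jaxdem's implementation. For JAX pytrees the same leaf-wise operation is commonly written as jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *states); whether jaxdem does this internally is not stated here.

```python
import numpy as np


def stack_fields(snapshots: list) -> dict:
    """Stack same-shaped field dicts along a new leading (time) axis.

    IDs are not shifted: axis 0 indexes snapshots, not new particles.
    """
    if not snapshots:
        raise ValueError("cannot stack an empty sequence")
    return {name: np.stack([s[name] for s in snapshots]) for name in snapshots[0]}


# Three snapshots of 2 particles in 2D, with unchanged IDs.
snaps = [{"pos": np.full((2, 2), 0.1 * t), "ID": np.arange(2)} for t in range(3)]
traj = stack_fields(snaps)
print(traj["pos"].shape)  # (3, 2, 2)
print(traj["ID"].shape)   # (3, 2)
```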