jaxdem.state

Defines the simulation State.

Classes

State(pos, vel, accel, rad, mass, ID, ...)

Represents the complete simulation state for a system of N particles in 2D or 3D.

final class jaxdem.state.State(pos: Array, vel: Array, accel: Array, rad: Array, mass: Array, ID: Array, mat_id: Array, species_id: Array, fixed: Array)

Bases: object

Represents the complete simulation state for a system of N particles in 2D or 3D.

Notes

State is designed to support various data layouts:

  • Single snapshot:

    pos.shape = (N, dim) for particle properties (e.g., pos, vel, accel), and (N,) for scalar properties (e.g., rad, mass). In this case, batch_size is 1.

  • Batched states:

    pos.shape = (B, N, dim) for particle properties, and (B, N) for scalar properties. Here, B is the batch dimension (batch_size = pos.shape[0]).

  • Trajectories of a single simulation:

    pos.shape = (T, N, dim) for particle properties, and (T, N) for scalar properties. Here, T is the trajectory dimension.

  • Trajectories of batched states:

    pos.shape = (B, T_1, T_2, …, T_k, N, dim) for particle properties, and (B, T_1, T_2, …, T_k, N) for scalar properties.

    • The first dimension (i.e., pos.shape[0]) is always interpreted as the batch dimension (B).

    • All remaining leading dimensions (T_1, T_2, …, T_k), i.e., those between B and N, are interpreted as trajectory dimensions. If there is more than one trajectory dimension, they are flattened at save time.
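
The flattening rule for batched trajectories can be sketched in NumPy as a single reshape. This is a hypothetical helper (flatten_trajectory_dims), not the jaxdem save routine; it assumes the caller knows how many per-particle axes each field has (2 for (N, dim) fields such as pos, 1 for (N,) fields such as rad).

```python
import numpy as np


def flatten_trajectory_dims(x: np.ndarray, trailing: int) -> np.ndarray:
    """Hypothetical helper: collapse T_1, ..., T_k into one trajectory axis.

    `trailing` is the number of per-particle axes: 2 for (..., N, dim)
    fields such as pos/vel/accel, 1 for (..., N) fields such as rad/mass.
    """
    batch = x.shape[0]                        # B, always the first axis
    per_particle = x.shape[x.ndim - trailing:]  # (N, dim) or (N,)
    n_traj_dims = x.ndim - 1 - trailing       # k, the number of T_i axes
    if n_traj_dims <= 1:
        return x  # zero or one trajectory dimension: nothing to flatten
    return x.reshape((batch, -1) + per_particle)


# B=2 batches, T_1=3 and T_2=4 trajectory axes, N=5 particles, dim=2
pos = np.zeros((2, 3, 4, 5, 2))
rad = np.ones((2, 3, 4, 5))
print(flatten_trajectory_dims(pos, trailing=2).shape)  # (2, 12, 5, 2)
print(flatten_trajectory_dims(rad, trailing=1).shape)  # (2, 12, 5)
```

Arrays with at most one trajectory axis are returned unchanged, matching the "more than 1 trajectory dimension" condition in the Notes.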

The class is final and cannot be subclassed.

Example

Creating a simple 2D state for 4 particles:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>> import jax
>>>
>>> positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
>>> state = jdem.State.create(pos=positions)
>>>
>>> print(f"Number of particles (N): {state.N}")
>>> print(f"Spatial dimension (dim): {state.dim}")
>>> print(f"Positions: {state.pos}")

Creating a batched state:

>>> batched_state = jax.vmap(lambda _: jdem.State.create(pos=positions))(jnp.arange(10))
>>>
>>> print(f"Batch size: {batched_state.batch_size}")  # 10
>>> print(f"Positions shape: {batched_state.pos.shape}")

pos: Array

Array of particle positions. Shape is (…, N, dim).

vel: Array

Array of particle velocities. Shape is (…, N, dim).

accel: Array

Array of particle accelerations. Shape is (…, N, dim).

rad: Array

Array of particle radii. Shape is (…, N).

mass: Array

Array of particle masses. Shape is (…, N).

ID: Array

Array of unique particle identifiers. Shape is (…, N).

mat_id: Array

Array of material IDs for each particle. Shape is (…, N).

species_id: Array

Array of species IDs for each particle. Shape is (…, N).

fixed: Array

Boolean array indicating if a particle is fixed (immobile). Shape is (…, N).

property N: int

Number of particles in the state.

property dim: int

Spatial dimension of the simulation.

property batch_size: int

Batch size of the state: pos.shape[0] for batched layouts, and 1 for a single (N, dim) snapshot.
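
The N, dim and batch_size properties follow directly from the axis conventions in the Notes. As a hedged illustration (not jaxdem's code), the hypothetical NumPy helper layout_properties below reads N and dim off the last two axes of pos and treats the first axis as the batch axis whenever pos has more than two dimensions.

```python
import numpy as np


def layout_properties(pos: np.ndarray) -> dict:
    """Hypothetical helper: derive N, dim and batch_size from pos (..., N, dim)."""
    return {
        "N": pos.shape[-2],    # particles sit on the second-to-last axis
        "dim": pos.shape[-1],  # spatial components sit on the last axis
        # A lone (N, dim) snapshot counts as batch_size 1; otherwise the
        # first axis is interpreted as the batch axis.
        "batch_size": 1 if pos.ndim == 2 else pos.shape[0],
    }


print(layout_properties(np.zeros((4, 2))))      # {'N': 4, 'dim': 2, 'batch_size': 1}
print(layout_properties(np.zeros((10, 4, 3))))  # {'N': 4, 'dim': 3, 'batch_size': 10}
```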

property is_valid: bool

Check if the internal representation of the State is consistent.

Verifies that:

  • The spatial dimension (dim) is either 2 or 3.

  • All position-like arrays (pos, vel, accel) have the same shape.

  • All scalar-per-particle arrays (rad, mass, ID, mat_id, species_id, fixed) have a shape consistent with pos.shape[:-1].

Raises:

AssertionError – If any shape inconsistency is found.
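
The three checks listed above can be mirrored with plain assertions. The NumPy sketch below uses a hypothetical function check_state_shapes on bare arrays rather than a State object, and it assumes that "consistent with pos.shape[:-1]" means exact shape equality.

```python
import numpy as np


def check_state_shapes(pos, vel, accel, rad, mass, ID, mat_id, species_id, fixed) -> bool:
    """Hypothetical re-implementation of the is_valid checks; True if all pass."""
    # 1. Spatial dimension must be 2 or 3.
    assert pos.shape[-1] in (2, 3), f"dim must be 2 or 3, got {pos.shape[-1]}"
    # 2. Position-like arrays share the full shape of pos, (..., N, dim).
    for name, arr in (("vel", vel), ("accel", accel)):
        assert arr.shape == pos.shape, f"{name}: {arr.shape} != {pos.shape}"
    # 3. Scalar-per-particle arrays must match pos.shape[:-1], i.e. (..., N).
    scalar_shape = pos.shape[:-1]
    scalars = {"rad": rad, "mass": mass, "ID": ID, "mat_id": mat_id,
               "species_id": species_id, "fixed": fixed}
    for name, arr in scalars.items():
        assert arr.shape == scalar_shape, f"{name}: {arr.shape} != {scalar_shape}"
    return True


pos = np.zeros((4, 2))
ok = check_state_shapes(pos, np.zeros_like(pos), np.zeros_like(pos), np.ones(4), np.ones(4),
                        np.arange(4), np.zeros(4, int), np.zeros(4, int), np.zeros(4, bool))
print(ok)  # True

try:
    # rad has 3 entries while pos describes 4 particles
    check_state_shapes(pos, np.zeros_like(pos), np.zeros_like(pos), np.ones(3), np.ones(4),
                       np.arange(4), np.zeros(4, int), np.zeros(4, int), np.zeros(4, bool))
except AssertionError as err:
    print("invalid:", err)
```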

static create(pos: ArrayLike, *, vel: ArrayLike | None = None, accel: ArrayLike | None = None, rad: ArrayLike | None = None, mass: ArrayLike | None = None, ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None) → State

Factory method to create a new State instance.

This method handles default values and ensures consistent array shapes for all state attributes.

Parameters:
  • pos (jax.typing.ArrayLike) – Initial positions of particles. Expected shape: (…, N, dim).

  • vel (jax.typing.ArrayLike or None, optional) – Initial velocities of particles. If None, defaults to zeros. Expected shape: (…, N, dim).

  • accel (jax.typing.ArrayLike or None, optional) – Initial accelerations of particles. If None, defaults to zeros. Expected shape: (…, N, dim).

  • rad (jax.typing.ArrayLike or None, optional) – Radii of particles. If None, defaults to ones. Expected shape: (…, N).

  • mass (jax.typing.ArrayLike or None, optional) – Masses of particles. If None, defaults to ones. Expected shape: (…, N).

  • ID (jax.typing.ArrayLike or None, optional) – Unique identifiers for particles. If None, defaults to jnp.arange(). Expected shape: (…, N).

  • mat_id (jax.typing.ArrayLike or None, optional) – Material IDs for particles. If None, defaults to zeros. Expected shape: (…, N).

  • species_id (jax.typing.ArrayLike or None, optional) – Species IDs for particles. If None, defaults to zeros. Expected shape: (…, N).

  • fixed (jax.typing.ArrayLike or None, optional) – Boolean array indicating fixed particles. If None, defaults to all False. Expected shape: (…, N).

Returns:

A new State instance with all attributes correctly initialized and shaped.

Return type:

State

Raises:

ValueError – If the created State is not valid.
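
The default-filling behaviour described in the parameter list can be sketched with NumPy. The function fill_defaults is hypothetical and returns a plain dict instead of a State. The dtypes (float for positions, int for IDs, bool for fixed) and the choice to repeat the default IDs 0..N-1 across any leading batch or trajectory axes are assumptions, not confirmed jaxdem behaviour.

```python
import numpy as np


def fill_defaults(pos, vel=None, accel=None, rad=None, mass=None,
                  ID=None, mat_id=None, species_id=None, fixed=None) -> dict:
    """Hypothetical sketch: apply the documented defaults to missing fields."""
    pos = np.asarray(pos, dtype=float)
    scalar_shape = pos.shape[:-1]  # (..., N)
    n = pos.shape[-2]              # number of particles N
    return {
        "pos": pos,
        "vel": np.zeros_like(pos) if vel is None else np.asarray(vel, dtype=float),
        "accel": np.zeros_like(pos) if accel is None else np.asarray(accel, dtype=float),
        "rad": np.ones(scalar_shape) if rad is None else np.asarray(rad, dtype=float),
        "mass": np.ones(scalar_shape) if mass is None else np.asarray(mass, dtype=float),
        # Assumption: default IDs are 0..N-1, repeated across leading axes.
        "ID": np.broadcast_to(np.arange(n), scalar_shape) if ID is None else np.asarray(ID),
        "mat_id": np.zeros(scalar_shape, dtype=int) if mat_id is None else np.asarray(mat_id),
        "species_id": (np.zeros(scalar_shape, dtype=int) if species_id is None
                       else np.asarray(species_id)),
        "fixed": (np.zeros(scalar_shape, dtype=bool) if fixed is None
                  else np.asarray(fixed, dtype=bool)),
    }


fields = fill_defaults(np.zeros((5, 3)), rad=np.full(5, 0.5))
print(fields["ID"])     # [0 1 2 3 4]
print(fields["fixed"])  # [False False False False False]
```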

Example

Creating a 3D state for 5 particles:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> my_pos = jnp.array([[0.,0.,0.], [1.,0.,0.], [0.,1.,0.], [0.,0.,1.], [1.,1.,1.]])
>>> my_rad = jnp.array([0.5, 0.5, 0.5, 0.5, 0.5])
>>> my_mass = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0])
>>>
>>> state_5_particles = jdem.State.create(pos=my_pos, rad=my_rad, mass=my_mass)
>>> print(f"Shape of positions: {state_5_particles.pos.shape}")
>>> print(f"Radii: {state_5_particles.rad}")

static merge(state1: State, state2: State) → State

Merges two State instances into a single new State.

This method appends the particles of state2 after those of state1. Particle IDs in state2 are shifted so that IDs remain unique in the merged state.

Parameters:
  • state1 (State) – The first State instance. Its particles appear first in the merged state.

  • state2 (State) – The second State instance. Its particles are appended after those of state1.

Returns:

A new State instance containing all particles from both input states.

Return type:

State

Raises:
  • AssertionError – If either input state is invalid, or if the two states differ in spatial dimension (dim) or batch size (batch_size).

  • ValueError – If the resulting merged state is invalid.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state_a = jdem.State.create(pos=jnp.array([[0.0, 0.0], [1.0, 1.0]]), ID=jnp.array([0, 1]))
>>> state_b = jdem.State.create(pos=jnp.array([[2.0, 2.0], [3.0, 3.0]]), ID=jnp.array([0, 1]))
>>> merged_state = jdem.State.merge(state_a, state_b)
>>>
>>> print(f"Merged state N: {merged_state.N}")  # Expected: 4
>>> print(f"Merged state positions:\n{merged_state.pos}")
>>> print(f"Merged state IDs: {merged_state.ID}")  # Expected: [0, 1, 2, 3]
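
The concatenation and ID shift performed by merge can be sketched on plain NumPy arrays. merge_fields is a hypothetical helper operating on dicts of fields, not the jaxdem method. The exact ID offset jaxdem uses is not stated here; this sketch assumes an offset of max(state1 IDs) + 1, which reproduces the [0, 1, 2, 3] shown in the example.

```python
import numpy as np


def merge_fields(a: dict, b: dict) -> dict:
    """Hypothetical sketch: join two field dicts along the particle axis."""
    assert a["pos"].shape[-1] == b["pos"].shape[-1], "dim mismatch"
    assert a["pos"].shape[:-2] == b["pos"].shape[:-2], "batch/leading shape mismatch"
    # Assumption: b's IDs are shifted to start one past the largest ID in a.
    offset = a["ID"].max() + 1
    merged = {}
    for key in a:
        other = b[key] + offset if key == "ID" else b[key]
        # Vector fields (..., N, dim) join on axis -2; scalar fields (..., N) on axis -1.
        axis = -2 if a[key].ndim == a["pos"].ndim else -1
        merged[key] = np.concatenate([a[key], other], axis=axis)
    return merged


state_a = {"pos": np.array([[0.0, 0.0], [1.0, 1.0]]), "ID": np.array([0, 1])}
state_b = {"pos": np.array([[2.0, 2.0], [3.0, 3.0]]), "ID": np.array([0, 1])}
merged = merge_fields(state_a, state_b)
print(merged["pos"].shape)  # (4, 2)
print(merged["ID"])         # [0 1 2 3]
```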

static add(state: State, pos: ArrayLike, *, vel: ArrayLike | None = None, accel: ArrayLike | None = None, rad: ArrayLike | None = None, mass: ArrayLike | None = None, ID: ArrayLike | None = None, mat_id: ArrayLike | None = None, species_id: ArrayLike | None = None, fixed: ArrayLike | None = None) → State

Adds new particles to an existing State instance, returning a new State.

Parameters:
  • state (State) – The existing State to which particles will be added.

  • pos (jax.typing.ArrayLike) – Positions of the new particle(s). Shape (…, N_new, dim).

  • vel (jax.typing.ArrayLike or None, optional) – Velocities of the new particle(s). Defaults to zeros.

  • accel (jax.typing.ArrayLike or None, optional) – Accelerations of the new particle(s). Defaults to zeros.

  • rad (jax.typing.ArrayLike or None, optional) – Radii of the new particle(s). Defaults to ones.

  • mass (jax.typing.ArrayLike or None, optional) – Masses of the new particle(s). Defaults to ones.

  • ID (jax.typing.ArrayLike or None, optional) – IDs of the new particle(s). If None, new IDs are generated.

  • mat_id (jax.typing.ArrayLike or None, optional) – Material IDs of the new particle(s). Defaults to zeros.

  • species_id (jax.typing.ArrayLike or None, optional) – Species IDs of the new particle(s). Defaults to zeros.

  • fixed (jax.typing.ArrayLike or None, optional) – Fixed status of the new particle(s). Defaults to all False.

Returns:

A new State instance containing all particles from the original state plus the newly added particles.

Return type:

State

Raises:
  • ValueError – If the created new particle state or the merged state is invalid.

  • AssertionError – If the batch size or spatial dimension of the new particles does not match that of the existing state.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Initial state with 4 particles
>>> state = jdem.State.create(pos=jnp.zeros((4, 2)))
>>> print(f"Original state N: {state.N}, IDs: {state.ID}")
>>>
>>> # Add a single new particle
>>> state_with_added_particle = jdem.State.add(
...     state,
...     pos=jnp.array([[10.0, 10.0]]),
...     rad=jnp.array([0.5]),
...     mass=jnp.array([2.0]),
... )
>>> print(f"New state N: {state_with_added_particle.N}, IDs: {state_with_added_particle.ID}")
>>> print(f"New particle position: {state_with_added_particle.pos[-1]}")
>>>
>>> # Add multiple new particles
>>> state_multiple_added = jdem.State.add(
...     state,
...     pos=jnp.array([[10.0, 10.0], [11.0, 11.0], [12.0, 12.0]]),
... )
>>> print(f"State with multiple added N: {state_multiple_added.N}, IDs: {state_multiple_added.ID}")
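
Conceptually, add amounts to building the new particles with default values and then merging them onto the existing state. The NumPy sketch below illustrates that idea with a hypothetical add_particles helper restricted to unbatched (N, dim) data and a subset of fields (pos, rad, ID); the rule that generated IDs continue after the largest existing ID is an assumption.

```python
import numpy as np


def add_particles(fields: dict, pos, rad=None) -> dict:
    """Hypothetical sketch: append particles to an unbatched {pos, rad, ID} dict."""
    pos = np.asarray(pos, dtype=float)
    assert pos.shape[-1] == fields["pos"].shape[-1], "dim mismatch"
    n_new = pos.shape[-2]
    # Documented default: radius 1 for every new particle.
    rad = np.ones(pos.shape[:-1]) if rad is None else np.asarray(rad, dtype=float)
    # Assumption: generated IDs continue after the largest existing ID.
    start = fields["ID"].max() + 1
    new_ids = np.arange(start, start + n_new)
    return {
        "pos": np.concatenate([fields["pos"], pos], axis=-2),
        "rad": np.concatenate([fields["rad"], rad], axis=-1),
        "ID": np.concatenate([fields["ID"], new_ids], axis=-1),
    }


state = {"pos": np.zeros((4, 2)), "rad": np.ones(4), "ID": np.arange(4)}
grown = add_particles(state, pos=[[10.0, 10.0], [11.0, 11.0]])
print(grown["ID"])       # [0 1 2 3 4 5]
print(grown["pos"][-1])  # [11. 11.]
```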

static stack(states: Sequence[State]) → State

Stacks a sequence of State snapshots along a new leading axis (axis 0) to form a trajectory.

This method is useful for collecting simulation snapshots over time into a single State object where the leading dimension represents time.

Parameters:

states (Sequence[State]) – A sequence (e.g., list, tuple) of State instances to be stacked.

Returns:

A new State instance where each attribute is a JAX array with an additional leading dimension representing the stacked trajectory. For example, if input pos was (N, dim), output pos will be (T, N, dim).

Return type:

State

Raises:
  • ValueError – If the input sequence of states is empty, or if the stacked State is invalid.

  • AssertionError – If any input state is invalid, or if there is a mismatch in spatial dimension (dim), batch size (batch_size), or number of particles (N) between the states in the sequence.

Notes

  • No ID shifting is performed because the leading axis represents time (or another batch dimension), not new particles.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create a sequence of 3 simple 2D snapshots
>>> snapshot1 = jdem.State.create(pos=jnp.array([[0.,0.], [1.,1.]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot2 = jdem.State.create(pos=jnp.array([[0.1,0.], [1.,1.1]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>> snapshot3 = jdem.State.create(pos=jnp.array([[0.2,0.], [1.,1.2]]), vel=jnp.array([[0.1,0.], [0.0,0.1]]))
>>>
>>> trajectory_state = jdem.State.stack([snapshot1, snapshot2, snapshot3])
>>>
>>> print(f"Trajectory positions shape: {trajectory_state.pos.shape}") # Expected: (3, 2, 2)
>>> print(f"Positions at time step 0:\n{trajectory_state.pos[0]}")
>>> print(f"Positions at time step 1:\n{trajectory_state.pos[1]}")
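
The stacking semantics, including the absence of ID shifting, can be sketched with np.stack on each field. stack_fields is a hypothetical helper working on dicts of NumPy arrays rather than State objects; it requires every snapshot to have the same pos shape, which covers the N, dim and batch-size checks listed above.

```python
import numpy as np


def stack_fields(snapshots: list) -> dict:
    """Hypothetical sketch: stack equally shaped snapshots along a new axis 0."""
    if not snapshots:
        raise ValueError("cannot stack an empty sequence")
    ref = snapshots[0]["pos"].shape
    for snap in snapshots[1:]:
        assert snap["pos"].shape == ref, "N, dim or batch shape mismatch"
    # IDs are stacked unchanged: the new axis is time, not extra particles.
    return {key: np.stack([s[key] for s in snapshots], axis=0) for key in snapshots[0]}


snaps = [
    {"pos": np.array([[0.0 + 0.1 * t, 0.0], [1.0, 1.0 + 0.1 * t]]), "ID": np.arange(2)}
    for t in range(3)
]
traj = stack_fields(snaps)
print(traj["pos"].shape)  # (3, 2, 2)
print(traj["ID"].shape)   # (3, 2)
```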