# Source code for jaxdem.state

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Defines the simulation State.
"""

from dataclasses import dataclass, replace
from typing import Optional, final, Sequence

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


[docs] @final @jax.tree_util.register_dataclass @dataclass(slots=True) class State: """ Represents the complete simulation state for a system of N particles in 2D or 3D. Notes ----- `State` is designed to support various data layouts: - **Single snapshot:** `pos.shape = (N, dim)` for particle properties (e.g., `pos`, `vel`, `accel`), and `(N,)` for scalar properties (e.g., `rad`, `mass`). In this case, `batch_size` is 1. - **Batched states:** `pos.shape = (B, N, dim)` for particle properties, and `(B, N)` for scalar properties. Here, `B` is the batch dimension (`batch_size = pos.shape[0]`). - **Trajectories of a single simulation:** `pos.shape = (T, N, dim)` for particle properties, and `(T, N)` for scalar properties. Here, `T` is the trajectory dimension. - **Trajectories of batched states:** `pos.shape = (B, T_1, T_2, ..., T_k, N, dim)` for particle properties, and `(B, T_1, T_2, ..., T_k, N)` for scalar properties. - The first dimension (i.e., `pos.shape[0]`) is always interpreted as the **batch dimension (`B`)**. - All preceding leading dimensions (`T_1, T_2, ... T_k`) are interpreted as **trajectory dimensions** and they are **flattened at save time** if there is more than 1 trajectory dimension. The class is `final` and cannot be subclassed. Example ------- Creating a simple 2D state for 4 particles: >>> import jaxdem as jdem >>> import jax.numpy as jnp >>> >>> positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) >>> state = jdem.State.create(pos=positions) >>> >>> print(f"Number of particles (N): {state.N}") >>> print(f"Spatial dimension (dim): {state.dim}") >>> print(f"Positions:\\n{state.pos}") Creating a batched state: >>> batched_state = jax.vmap(lambda _: State.create(pos=positions))(jnp.arange(10)) >>> >>> print(f"Batch size: {batched_state.batch_size}") # 10 >>> print(f"Positions shape: {batched_state.pos.shape}") """ pos: jax.Array """ Array of particle positions. Shape is `(..., N, dim)`. 
""" vel: jax.Array """ Array of particle velocities. Shape is `(..., N, dim)`. """ accel: jax.Array """ Array of particle accelerations. Shape is `(..., N, dim)`. """ rad: jax.Array """ Array of particle radii. Shape is `(..., N)`. """ mass: jax.Array """ Array of particle masses. Shape is `(..., N)`. """ ID: jax.Array """ Array of unique particle identifiers. Shape is `(..., N)`. """ mat_id: jax.Array """ Array of material IDs for each particle. Shape is `(..., N)`. """ species_id: jax.Array """ Array of species IDs for each particle. Shape is `(..., N)`. """ fixed: jax.Array """ Boolean array indicating if a particle is fixed (immobile). Shape is `(..., N)`. """ @property def N(self) -> int: """ Number of particles in the state. """ return self.pos.shape[-2] @property def dim(self) -> int: """ Spatial dimension of the simulation. """ return self.pos.shape[-1] @property def batch_size(self) -> int: """ Return the batch size of the state. """ return 1 if self.pos.ndim < 3 else self.pos.shape[0] @property def is_valid(self) -> bool: """ Check if the internal representation of the State is consistent. Verifies that: - The spatial dimension (`dim`) is either 2 or 3. - All position-like arrays (`pos`, `vel`, `accel`) have the same shape. - All scalar-per-particle arrays (`rad`, `mass`, `ID`, `mat_id`, `species_id`, `fixed`) have a shape consistent with `pos.shape[:-1]`. Raises ------ AssertionError If any shape inconsistency is found. """ valid = self.dim in (2, 3) assert valid, f"Simulation dimension (pos.shape[-1]={self.dim}) must be 2 or 3." for name in ( "pos", "vel", "accel", ): arr = getattr(self, name) valid = valid and self.pos.shape == arr.shape assert ( valid ), f"{name}.shape={arr.shape} is not equal to pos.shape={self.pos.shape}." 
for name in ( "rad", "mass", "ID", "mat_id", "species_id", "fixed", ): arr = getattr(self, name) valid = valid and self.pos.shape[:-1] == arr.shape assert ( valid ), f"{name}.shape={arr.shape} is not equal to pos.shape[:-1]={self.pos.shape[:-1]}." return valid def __init_subclass__(cls, *args, **kw): raise TypeError(f"{State.__name__} is final and cannot be subclassed")
[docs] @staticmethod def create( pos: ArrayLike, *, vel: Optional[ArrayLike] = None, accel: Optional[ArrayLike] = None, rad: Optional[ArrayLike] = None, mass: Optional[ArrayLike] = None, ID: Optional[ArrayLike] = None, mat_id: Optional[ArrayLike] = None, species_id: Optional[ArrayLike] = None, fixed: Optional[ArrayLike] = None, ) -> "State": """ Factory method to create a new :class:`State` instance. This method handles default values and ensures consistent array shapes for all state attributes. Parameters ---------- pos : jax.typing.ArrayLike Initial positions of particles. Expected shape: `(..., N, dim)`. vel : jax.typing.ArrayLike or None, optional Initial velocities of particles. If `None`, defaults to zeros. Expected shape: `(..., N, dim)`. accel : jax.typing.ArrayLike or None, optional Initial accelerations of particles. If `None`, defaults to zeros. Expected shape: `(..., N, dim)`. rad : jax.typing.ArrayLike or None, optional Radii of particles. If `None`, defaults to ones. Expected shape: `(..., N)`. mass : jax.typing.ArrayLike or None, optional Masses of particles. If `None`, defaults to ones. Expected shape: `(..., N)`. ID : jax.typing.ArrayLike or None, optional Unique identifiers for particles. If `None`, defaults to :func:`jnp.arange(N)`. Expected shape: `(..., N)`. mat_id : jax.typing.ArrayLike or None, optional Material IDs for particles. If `None`, defaults to zeros. Expected shape: `(..., N)`. species_id : jax.typing.ArrayLike or None, optional Species IDs for particles. If `None`, defaults to zeros. Expected shape: `(..., N)`. fixed : jax.typing.ArrayLike or None, optional Boolean array indicating fixed particles. If `None`, defaults to all `False`. Expected shape: `(..., N)`. Returns ------- State A new `State` instance with all attributes correctly initialized and shaped. Raises ------ ValueError If the created `State` is not valid. 
Example ------- Creating a 3D state for 5 particles: >>> import jaxdem as jdem >>> import jax.numpy as jnp >>> >>> my_pos = jnp.array([[0.,0.,0.], [1.,0.,0.], [0.,1.,0.], [0.,0.,1.], [1.,1.,1.]]) >>> my_rad = jnp.array([0.5, 0.5, 0.5, 0.5, 0.5]) >>> my_mass = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0]) >>> >>> state_5_particles = jdem.State.create(pos=my_pos, rad=my_rad, mass=my_mass) >>> print(f"Shape of positions: {state_5_particles.pos.shape}") >>> print(f"Radii: {state_5_particles.rad}") """ pos = jnp.asarray(pos, dtype=float) N = pos.shape[-2] vel = ( jnp.zeros_like(pos, dtype=float) if vel is None else jnp.asarray(vel, dtype=float) ) accel = ( jnp.zeros_like(pos, dtype=float) if accel is None else jnp.asarray(accel, dtype=float) ) rad = ( jnp.ones(pos.shape[:-1], dtype=float) if rad is None else jnp.asarray(rad, dtype=float) ) mass = ( jnp.ones(pos.shape[:-1], dtype=float) if mass is None else jnp.asarray(mass, dtype=float) ) ID = ( jnp.broadcast_to(jnp.arange(N, dtype=int), pos.shape[:-1]) if ID is None else jnp.asarray(ID, dtype=int) ) mat_id = ( jnp.zeros(pos.shape[:-1], dtype=int) if mat_id is None else jnp.asarray(mat_id, dtype=int) ) species_id = ( jnp.zeros(pos.shape[:-1], dtype=int) if species_id is None else jnp.asarray(species_id, dtype=int) ) fixed = ( jnp.zeros(pos.shape[:-1], dtype=bool) if fixed is None else jnp.asarray(fixed, dtype=bool) ) state = State( pos=pos, vel=vel, accel=accel, rad=rad, mass=mass, ID=ID, mat_id=mat_id, species_id=species_id, fixed=fixed, ) if not state.is_valid: raise ValueError(f"State is not valid, state={state}") return state
[docs] @staticmethod def merge(state1: "State", state2: "State") -> "State": """ Merges two :class:`State` instances into a single new :class:`State`. This method concatenates the particles from `state2` onto `state1`. Particle IDs in `state2` are shifted to ensure uniqueness in the merged state. Parameters ---------- state1 : State The first `State` instance. Its particles will appear first in the merged state. state2 : State The second `State` instance. Its particles will be appended to the first. Returns ------- State A new `State` instance containing all particles from both input states. Raises ------ AssertionError If either input state is invalid, or if there is a mismatch in spatial dimension (`dim`) or batch size (`batch_size`). ValueError If the resulting merged state is somehow invalid. Example ------- >>> import jaxdem as jdem >>> import jax.numpy as jnp >>> >>> state_a = jdem.create(pos=jnp.array([[0.,0.], [1.,1.]]), ID=jnp.array([0, 1])) >>> state_b = jdem.State.create(pos=jnp.array([[2.,2.], [3.,3.]]), ID=jnp.array([0, 1])) >>> merged_state = State.merge(state_a, state_b) >>> >>> print(f"Merged state N: {merged_state.N}") # Expected: 4 >>> print(f"Merged state positions:\\n{merged_state.pos}") >>> print(f"Merged state IDs: {merged_state.ID}") # Expected: [0, 1, 2, 3] """ assert state1.is_valid and state2.is_valid, "One of the states is invalid" assert state1.dim == state2.dim, f"dim mismatch: {state1.dim} vs {state2.dim}" assert ( state1.batch_size == state2.batch_size ), f"batch_size mismatch: {state1.batch_size} vs {state2.batch_size}" state2.ID += jnp.max(state1.ID) + 1 # ----------------- tree-wise concatenation -------------------------- # Arrays that have the same rank as `pos` (`pos`, `vel`, `accel`) are # concatenated along axis -2 (particle axis). Everything else # (`rad`, `mass`, `ID`, `mat_id`, `species_id`, `fixed`) is concatenated along axis -1. 
pos_ndim = state1.pos.ndim def cat(a, b): axis = -2 if a.ndim == pos_ndim else -1 return jnp.concatenate((a, b), axis=axis) state = jax.tree_util.tree_map(cat, state1, state2) if not state.is_valid: raise ValueError(f"State is not valid, state={state}") return state
[docs] @staticmethod def add( state: "State", pos: ArrayLike, *, vel: Optional[ArrayLike] = None, accel: Optional[ArrayLike] = None, rad: Optional[ArrayLike] = None, mass: Optional[ArrayLike] = None, ID: Optional[ArrayLike] = None, mat_id: Optional[ArrayLike] = None, species_id: Optional[ArrayLike] = None, fixed: Optional[ArrayLike] = None, ) -> "State": """ Adds new particles to an existing :class:`State` instance, returning a new `State`. Parameters ---------- state : State The existing `State` to which particles will be added. pos : jax.typing.ArrayLike Positions of the new particle(s). Shape `(..., N_new, dim)`. vel : jax.typing.ArrayLike or None, optional Velocities of the new particle(s). Defaults to zeros. accel : jax.typing.ArrayLike or None, optional Accelerations of the new particle(s). Defaults to zeros. rad : jax.typing.ArrayLike or None, optional Radii of the new particle(s). Defaults to ones. mass : jax.typing.ArrayLike or None, optional Masses of the new particle(s). Defaults to ones. ID : jax.typing.ArrayLike or None, optional IDs of the new particle(s). If `None`, new IDs are generated. mat_id : jax.typing.ArrayLike or None, optional Material IDs of the new particle(s). Defaults to zeros. species_id : jax.typing.ArrayLike or None, optional Species IDs of the new particle(s). Defaults to zeros. fixed : jax.typing.ArrayLike or None, optional Fixed status of the new particle(s). Defaults to all `False`. Returns ------- State A new `State` instance containing all particles from the original `state` plus the newly added particles. Raises ------ ValueError If the created new particle state or the merged state is invalid. AssertionError If batch size or dimension mismatch between existing state and new particles. 
Example ------- >>> import jaxdem as jdem >>> import jax.numpy as jnp >>> >>> # Initial state with 4 particles >>> state = jdem.State.create(jnp.zeros((4, 2))) >>> print(f"Original state N: {state.N}, IDs: {state.ID}") >>> >>> # Add a single new particle >>> state_with_added_particle = state.add(state, ... pos=jnp.array([[10., 10.]]), ... rad=jnp.array([0.5]), ... mass=jnp.array([2.0])) >>> print(f"New state N: {state_with_added_particle.N}, IDs: {state_with_added_particle.ID}") >>> print(f"New particle position: {state_with_added_particle.pos[-1]}") Adding multiple new particles: >>> state_multiple_added = state.add(state, ... pos=jnp.array([[10., 10.], [11., 11.], [12., 12.]])) >>> print(f"State with multiple added N: {state_multiple_added.N}, IDs: {state_multiple_added.ID}") """ state2 = State.create( pos, vel=vel, accel=accel, rad=rad, mass=mass, ID=ID, mat_id=mat_id, species_id=species_id, fixed=fixed, ) return State.merge(state, state2)
[docs] @staticmethod def stack(states: Sequence["State"]) -> "State": """ Concatenates a sequence of :class:`State` snapshots into a trajectory along axis 0. This method is useful for collecting simulation snapshots over time into a single `State` object where the leading dimension represents time. Parameters ---------- states : Sequence[State] A sequence (e.g., list, tuple) of :class:`State` instances to be stacked. Returns ------- State A new :class:`State` instance where each attribute is a JAX array with an additional leading dimension representing the stacked trajectory. For example, if input `pos` was `(N, dim)`, output `pos` will be `(T, N, dim)`. Raises ------ ValueError If the input `states` sequence is empty. If the stacked `State` is invalid. AssertionError If any input state is invalid, or if there is a mismatch in spatial dimension (`dim`), batch size (`batch_size`), or number of particles (`N`) between the states in the sequence. Notes ---- - No ID shifting is performed because the leading axis represents **time** (or another batch dimension), not new particles. 
Example ------- >>> import jaxdem as jdem >>> import jax.numpy as jnp >>> >>> # Create a sequence of 3 simple 2D snapshots >>> snapshot1 = jdem.State.create(pos=jnp.array([[0.,0.], [1.,1.]]), vel=jnp.array([[0.1,0.], [0.0,0.1]])) >>> snapshot2 = jdem.State.create(pos=jnp.array([[0.1,0.], [1.,1.1]]), vel=jnp.array([[0.1,0.], [0.0,0.1]])) >>> snapshot3 = jdem.State.create(pos=jnp.array([[0.2,0.], [1.,1.2]]), vel=jnp.array([[0.1,0.], [0.0,0.1]])) >>> >>> trajectory_state = State.stack([snapshot1, snapshot2, snapshot3]) >>> >>> print(f"Trajectory positions shape: {trajectory_state.pos.shape}") # Expected: (3, 2, 2) >>> print(f"Positions at time step 0:\\n{trajectory_state.pos[0]}") >>> print(f"Positions at time step 1:\\n{trajectory_state.pos[1]}") """ states = list(states) if not states: raise ValueError("State.stack() received an empty list") ref = states[0] assert ref.is_valid, "first state is invalid" # ---------- consistency checks --------------------------------- for s in states[1:]: assert s.is_valid, "one state is invalid" assert s.dim == ref.dim, "dimension mismatch" assert s.batch_size == ref.batch_size, "batch size mismatch" assert s.N == ref.N, "particle count mismatch" # ---------- concatenate every leaf ----------------------------- stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *states) if not stacked.is_valid: raise ValueError("stacked State is not valid") return stacked