jaxdem.system

Defines the simulation configuration and the tooling for driving the simulation.

Classes

System(linear_integrator, ...)

Encapsulates the entire simulation configuration.

final class jaxdem.system.System(linear_integrator: LinearIntegrator | LinearMinimizer, rotation_integrator: RotationIntegrator | RotationMinimizer, collider: Collider, domain: Domain, force_manager: ForceManager, force_model: ForceModel, mat_table: MaterialTable, dt: jax.Array, time: jax.Array, dim: jax.Array, step_count: jax.Array, key: jax.Array)

Bases: object

Encapsulates the entire simulation configuration.

Notes

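  • System is a JAX pytree, so functions that take and return it can be compiled with jax.jit() for efficient execution.

  • Because System is passed through jax.jit(), keep its array-valued fields (dt, time, dim, step_count, key) as JAX arrays so that traced functions see consistent shapes and dtypes.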
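The following is a minimal sketch, not jaxdem code, of why array-valued fields matter under jax.jit. ToySystem is a hypothetical two-field stand-in for System, registered as a pytree with jax.tree_util.register_pytree_node_class. Calling the jitted function again with a new dt value of the same shape and dtype reuses the existing trace instead of tracing again.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToySystem:
    """Hypothetical stand-in for System with two array-valued fields."""

    def __init__(self, dt, time):
        self.dt = dt
        self.time = time

    def tree_flatten(self):
        # Both fields are leaves, so jit traces them as abstract arrays.
        return (self.dt, self.time), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


trace_count = 0


@jax.jit
def advance_time(system):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return ToySystem(system.dt, system.time + system.dt)


f32 = jnp.float32
s = advance_time(ToySystem(jnp.asarray(0.01, f32), jnp.asarray(0.0, f32)))
# A different dt value with the same shape and dtype reuses the cached trace.
s = advance_time(ToySystem(jnp.asarray(0.02, f32), s.time))
```

Here trace_count stays at 1 after both calls, because only the values of the leaves changed, not their shapes or dtypes.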

Example

Creating a basic 2D simulation system:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Build a small 2D particle state to size the system
>>> state = jdem.utils.grid_state(n_per_axis=(4, 4), spacing=1.0, radius=0.1)
>>>
>>> # Create a System instance
>>> sim_system = jdem.System.create(
...     state_shape=state.shape,
...     dt=0.001,
...     linear_integrator_type="euler",
...     rotation_integrator_type="spiral",
...     collider_type="naive",
...     domain_type="free",
...     force_model_type="spring",
...     # You can pass keyword arguments to component constructors via '_kw' dicts
...     domain_kw=dict(box_size=jnp.array([5.0, 5.0]), anchor=jnp.array([0.0, 0.0])),
... )
>>>
>>> print(f"Linear integrator: {sim_system.linear_integrator.__class__.__name__}")
>>> print(f"Force model: {sim_system.force_model.__class__.__name__}")
>>> print(f"Domain box size: {sim_system.domain.box_size}")

linear_integrator: LinearIntegrator | LinearMinimizer

Instance of jaxdem.LinearIntegrator (or jaxdem.LinearMinimizer) that evolves the translational degrees of freedom.

rotation_integrator: RotationIntegrator | RotationMinimizer

Instance of jaxdem.RotationIntegrator (or jaxdem.RotationMinimizer) that evolves the rotational degrees of freedom.

collider: Collider

Instance of jaxdem.Collider that performs contact detection and computes inter-particle forces and potential energies.

domain: Domain

Instance of jaxdem.Domain that defines the simulation boundaries, displacement rules, and boundary conditions.

force_manager: ForceManager

Instance of jaxdem.ForceManager that handles per-particle forces, such as external forces, and resets the accumulated forces.

force_model: ForceModel

Instance of jaxdem.ForceModel that defines the physical laws for inter-particle interactions.

mat_table: MaterialTable

Instance of jaxdem.MaterialTable holding material properties and pairwise interaction parameters.

dt: jax.Array

The global simulation time step \(\Delta t\).

time: jax.Array

Elapsed simulation time.

dim: jax.Array

Spatial dimension of the system.

step_count: jax.Array

Number of integration steps that have been performed.

key: jax.Array

JAX PRNG key for stochastic functionality. Always advance it with jax.random.split before drawing new random numbers; reusing the same key reproduces the same numbers.
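A minimal sketch of the split-before-use pattern with plain jax.random, independent of jaxdem: the carried key is replaced by one half of each split, and the other half is consumed for sampling.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Split before every draw: keep `key` for future splits, consume `subkey` now.
key, subkey = jax.random.split(key)
noise_a = jax.random.normal(subkey, (3,))

key, subkey = jax.random.split(key)
noise_b = jax.random.normal(subkey, (3,))  # fresh numbers from a fresh subkey

# Reusing a subkey does NOT give new numbers: it reproduces noise_b exactly.
noise_b_again = jax.random.normal(subkey, (3,))
```

For a System, the same pattern would apply to system.key, with the updated key stored back on the system.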

static create(state_shape: Tuple[int, ...], *, dt: float = 0.005, time: float = 0.0, linear_integrator_type: str = 'verlet', rotation_integrator_type: str = 'verletspiral', collider_type: str = 'naive', domain_type: str = 'free', force_model_type: str = 'spring', force_manager_kw: Dict[str, Any] | None = None, mat_table: MaterialTable | None = None, linear_integrator_kw: Dict[str, Any] | None = None, rotation_integrator_kw: Dict[str, Any] | None = None, collider_kw: Dict[str, Any] | None = None, domain_kw: Dict[str, Any] | None = None, force_model_kw: Dict[str, Any] | None = None, seed: int = 0, key: Array | None = None) → System

Factory method to create a System instance with specified components.

Parameters:
  • state_shape (Tuple) – Shape of the state tensors handled by the simulation. The penultimate dimension corresponds to the number of particles N and the last dimension corresponds to the spatial dimension dim.

  • dt (float, optional) – The global simulation time step. Defaults to 0.005.

  • time (float, optional) – Initial simulation time. Defaults to 0.0.

  • linear_integrator_type (str, optional) – The registered type string for the jaxdem.integrators.LinearIntegrator used to evolve translational degrees of freedom. Defaults to "verlet".

  • rotation_integrator_type (str, optional) – The registered type string for the jaxdem.integrators.RotationIntegrator used to evolve angular degrees of freedom. Defaults to "verletspiral".

  • collider_type (str, optional) – The registered type string for the jaxdem.Collider to use. Defaults to "naive".

  • domain_type (str, optional) – The registered type string for the jaxdem.Domain to use. Defaults to "free".

  • force_model_type (str, optional) – The registered type string for the jaxdem.ForceModel to use. Defaults to "spring".

  • force_manager_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of ForceManager.

  • mat_table (MaterialTable or None, optional) – An optional pre-configured jaxdem.MaterialTable. If None, a default jaxdem.MaterialTable will be created with one generic elastic material and “harmonic” jaxdem.MaterialMatchmaker.

  • linear_integrator_kw (Dict[str, Any] or None, optional) – Keyword arguments forwarded to the constructor of the selected LinearIntegrator type.

  • rotation_integrator_kw (Dict[str, Any] or None, optional) – Keyword arguments forwarded to the constructor of the selected RotationIntegrator type.

  • collider_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Collider type.

  • domain_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Domain type.

  • force_model_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected ForceModel type.

  • seed (int, optional) – Integer seed used for random number generation. Defaults to 0.

  • key (jax.Array, optional) – Key used for JAX random number generation. Defaults to None, in which case a key is created from seed; when provided, it is used instead of seed.

Returns:

A fully configured System instance ready for simulation.

Return type:

System

Raises:
  • KeyError – If a specified *_type is not registered in its respective factory, or if the mat_table is missing properties required by the force_model.

  • TypeError – If constructor keyword arguments are invalid for any component.

  • AssertionError – If the shapes of ‘box_size’ or ‘anchor’ in domain_kw do not match the spatial dimension dim.

Example

Creating a 3D system with reflective boundaries and a custom dt:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> N = 8  # number of particles
>>> system_reflect = jdem.System.create(
...     state_shape=(N, 3),
...     dt=0.0005,
...     domain_type="reflect",
...     domain_kw=dict(box_size=jnp.array([20.0, 20.0, 20.0]), anchor=jnp.array([-10.0, -10.0, -10.0])),
...     force_model_type="spring",
... )
>>> print(f"System dt: {system_reflect.dt}")
>>> print(f"Domain type: {system_reflect.domain.__class__.__name__}")

Creating a system with a pre-defined MaterialTable:

>>> custom_mat_kw = dict(young=2.0e5, poisson=0.25)
>>> custom_material = jdem.Material.create("custom_mat", **custom_mat_kw)
>>> custom_mat_table = jdem.MaterialTable.from_materials(
...     [custom_material], matcher=jdem.MaterialMatchmaker.create("linear")
... )
>>>
>>> system_custom_mat = jdem.System.create(
...     state_shape=(N, 2),
...     mat_table=custom_mat_table,
...     force_model_type="spring"
... )

static trajectory_rollout(state: State, system: System, *, n: Optional[int] = None, stride: int = 1, strides: Optional[jax.Array] = None, save_fn: Callable[[State, System], Any] = <function _save_state_system>, unroll: int = 2) → Tuple[State, System, Any]

Roll the system forward while collecting saved outputs at each frame.

The rollout always stores one output per frame, computed by save_fn(state, system). The output of save_fn must be a pytree. Frame spacing can be either:

  • constant (stride), or

  • variable (a strides jax.Array).
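One way to picture this frame structure is the sketch below. It is a hypothetical re-implementation on a toy scalar state, not jaxdem's own code: an outer jax.lax.scan runs once per saved frame, and an inner jax.lax.fori_loop advances that frame's stride. A constant stride is treated as the special case strides = jnp.full((n,), stride). The local names rollout, step_fn, and save_fn are illustrative only.

```python
import jax
import jax.numpy as jnp


def rollout(step_fn, carry, strides, save_fn, unroll=2):
    """Hypothetical sketch: one scan iteration per saved frame."""

    def frame(carry, stride):
        # Apply the integration step `stride` times (stride is traced here).
        carry = jax.lax.fori_loop(0, stride, lambda _, c: step_fn(c), carry)
        # Save after advancing; scan stacks these outputs along axis 0.
        return carry, save_fn(carry)

    return jax.lax.scan(frame, carry, strides, unroll=unroll)


step_fn = lambda x: x + 1.0  # toy "integrator": one step adds 1
save_fn = lambda x: x        # save the whole carry at each frame

# Variable strides: the number of frames is len(strides).
final, traj = rollout(step_fn, jnp.float32(0.0), jnp.array([1, 2, 4, 8]), save_fn)

# Constant stride: n = 10 frames of 5 steps each.
final_c, traj_c = rollout(step_fn, jnp.float32(0.0), jnp.full((10,), 5), save_fn)
```

With variable strides the saved values are the cumulative step counts 1, 3, 7, 15, which makes the per-frame spacing easy to check.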

Parameters:
  • state (State) – Initial state.

  • system (System) – Initial system configuration.

  • n (int, optional) – Number of saved frames. Required when strides is None. Ignored when strides is provided.

  • stride (int, optional) – Constant number of integration steps between consecutive saves. Used only when strides is None. Defaults to 1.

  • strides (jax.Array, optional) – Integer 1D array of per-frame integration strides. When provided, this overrides stride, and n is inferred from len(strides).

  • save_fn (Callable[[State, System], Any], optional) – Function called after each saved frame. Its return pytree is stacked along axis 0 across frames. Defaults to returning (state, system).

  • unroll (int, optional) – Unroll factor passed to the outer jax.lax.scan. Defaults to 2.

Returns:

(final_state, final_system, trajectory_like) where trajectory_like is the stacked output of save_fn.

Return type:

Tuple[State, System, Any]

Raises:

ValueError – If n is missing while strides is None, or if strides is not 1D.

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(state_shape=state.shape, dt=0.01)
>>>
>>> # Constant stride: n is required
>>> final_state, final_system, traj = jdem.System.trajectory_rollout(
...     state, system, n=10, stride=5
... )
>>>
>>> # Variable strides: n inferred from len(strides)
>>> deltas = jnp.array([1, 2, 4, 8])
>>> final_state, final_system, traj = jdem.System.trajectory_rollout(
...     state, system, strides=deltas
... )

static step(state: State, system: System, *, n: int | jax.Array = 1) → Tuple[State, System]

Advance the simulation by n integration steps.

Parameters:
  • state (State) – Current state.

  • system (System) – Current system configuration.

  • n (int or jax.Array, optional) – Number of integration steps. May be a Python int or a scalar JAX array. Defaults to 1.

Returns:

(final_state, final_system) after n steps.

Return type:

Tuple[State, System]

Example

>>> # Advance by 10 steps
>>> state_after_10_steps, system_after_10_steps = jdem.System.step(state, system, n=10)
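A sketch, assuming nothing about jaxdem's internals, of how a step count can be either a Python int or a scalar JAX array: jax.lax.fori_loop accepts a traced upper bound, so a jitted function can take n as an ordinary argument. advance_n is a hypothetical name and the loop body is a toy update.

```python
import jax
import jax.numpy as jnp


@jax.jit
def advance_n(x, n):
    # n is traced, so the loop bound is a runtime value (lowered to a while_loop).
    return jax.lax.fori_loop(0, n, lambda _, c: c + 1.0, x)


x0 = jnp.float32(0.0)
x10 = advance_n(x0, 10)             # Python int
x3 = advance_n(x0, jnp.asarray(3))  # scalar JAX array
```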

static stack(systems: Sequence[System]) → System

Stacks a sequence of System snapshots along a new leading axis 0, producing a trajectory or batch.

This is useful for collecting snapshots taken over time into a single System whose leading dimension indexes time, or for preparing a batched System.

Parameters:

systems (Sequence[System]) – A sequence (e.g., list, tuple) of System instances to be stacked.

Returns:

A new System instance in which every array leaf gains an additional leading dimension of size T = len(systems). For example, if domain.box_size has shape (dim,) in each input, the stacked domain.box_size has shape (T, dim).

Return type:

System
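The stacking described above can be sketched generically with jax.tree_util.tree_map. This is an illustration on plain dict pytrees standing in for System snapshots, not jaxdem's implementation, and stack_pytrees is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def stack_pytrees(trees):
    """Stack structurally identical pytrees leaf by leaf along a new axis 0."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


# Three hypothetical snapshots: a scalar dt and a (dim,)-shaped box_size.
snapshots = [
    {"dt": jnp.float32(0.01), "box_size": jnp.array([5.0, 5.0])}
    for _ in range(3)
]
stacked = stack_pytrees(snapshots)
```

Here stacked["dt"] has shape (3,) and stacked["box_size"] has shape (3, 2), matching the (T, ...) layout described in Returns.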

static unstack(system: System) → list[System]

Split a stacked/batched System along the leading axis into a Python list.

This is the inverse of System.stack():

  • If stacked = System.stack([sys0, sys1, …]), then System.unstack(stacked) returns [sys0, sys1, …].

Notes

  • The split is performed along axis 0 (the leading axis).

  • A single snapshot System has no leading axis to split, so it cannot be unstacked with this method.
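A generic sketch of the inverse operation, again on dict pytrees standing in for System rather than taken from jaxdem: frame i is rebuilt by indexing every leaf at i. unstack_pytree is a hypothetical helper; its size lookup also shows why a snapshot without a leading axis is not a valid input.

```python
import jax
import jax.numpy as jnp


def unstack_pytree(tree):
    """Split every leaf along axis 0 and rebuild one pytree per index."""
    leaves = jax.tree_util.tree_leaves(tree)
    size = leaves[0].shape[0]  # raises IndexError if the leaf is a scalar
    return [jax.tree_util.tree_map(lambda x: x[i], tree) for i in range(size)]


# Hypothetical stacked snapshots with T = 3 along the leading axis.
stacked = {
    "dt": jnp.full((3,), 0.01, dtype=jnp.float32),
    "time": jnp.array([0.0, 0.01, 0.02], dtype=jnp.float32),
}
frames = unstack_pytree(stacked)
```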