jaxdem.system

Defines the simulation configuration and the tooling for driving the simulation.

Classes

System(linear_integrator, ...)

Encapsulates the entire simulation configuration.

final class jaxdem.system.System(linear_integrator: LinearIntegrator, rotation_integrator: RotationIntegrator, collider: Collider, domain: Domain, force_manager: ForceManager, force_model: ForceModel, mat_table: MaterialTable, dt: Array, time: Array, dim: Array, step_count: Array)

Bases: object

Encapsulates the entire simulation configuration.

Notes

  • The System object is designed to be used inside JIT-compiled code for efficient execution.

  • The System dataclass is compatible with jax.jit(), so every leaf value it holds (for example dt, time, dim, and step_count) should remain a JAX array for best performance.
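The JIT note above can be illustrated without jaxdem. This minimal sketch uses a hypothetical frozen dataclass `ToySystem` as a stand-in for System, registered as a pytree whose fields are all JAX arrays. Only the array values differ between the two calls, so JAX traces and compiles `advance` once and then reuses it. A change in shape, dtype, or pytree structure would trigger a retrace. `TRACE_COUNT` is a tracing probe for the demonstration, not part of any jaxdem API.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp

TRACE_COUNT = 0  # incremented only when JAX (re)traces `advance`


@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class ToySystem:
    """Hypothetical stand-in for jaxdem.System holding only array fields."""

    dt: jax.Array
    time: jax.Array
    step_count: jax.Array

    def tree_flatten(self):
        # Every field is an array leaf; there is no static metadata.
        return (self.dt, self.time, self.step_count), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def advance(sys: ToySystem) -> ToySystem:
    global TRACE_COUNT
    TRACE_COUNT += 1  # Python side effect: runs at trace time only
    return ToySystem(sys.dt, sys.time + sys.dt, sys.step_count + 1)


# Same shapes and dtypes, different dt values: the compiled function is reused.
s1 = advance(ToySystem(jnp.asarray(0.001), jnp.asarray(0.0), jnp.asarray(0)))
s2 = advance(ToySystem(jnp.asarray(0.005), jnp.asarray(0.0), jnp.asarray(0)))
```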

Example

Creating a basic 2D simulation system:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Build a single-particle 2D state
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>>
>>> # Create a System instance
>>> sim_system = jdem.System.create(
...     state_shape=state.shape,
...     dt=0.001,
...     linear_integrator_type="euler",
...     rotation_integrator_type="spiral",
...     collider_type="naive",
...     domain_type="free",
...     force_model_type="spring",
...     # You can pass keyword arguments to component constructors via '_kw' dicts
...     domain_kw=dict(box_size=jnp.array([5.0, 5.0]), anchor=jnp.array([0.0, 0.0])),
... )
>>>
>>> print(f"Linear integrator: {sim_system.linear_integrator.__class__.__name__}")
>>> print(f"Force model: {sim_system.force_model.__class__.__name__}")
>>> print(f"Domain box size: {sim_system.domain.box_size}")

linear_integrator: LinearIntegrator

Instance of jaxdem.LinearIntegrator that advances the translational (linear) state of the simulation in time.

rotation_integrator: RotationIntegrator

Instance of jaxdem.RotationIntegrator that advances the rotational (angular) state of the simulation in time.

collider: Collider

Instance of jaxdem.Collider that performs contact detection and computes inter-particle forces and potential energies.

domain: Domain

Instance of jaxdem.Domain that defines the simulation boundaries, displacement rules, and boundary conditions.

force_manager: ForceManager

Instance of jaxdem.ForceManager that handles per-particle forces, such as external forces, and resets forces.

force_model: ForceModel

Instance of jaxdem.ForceModel that defines the physical laws for inter-particle interactions.

mat_table: MaterialTable

Instance of jaxdem.MaterialTable holding material properties and pairwise interaction parameters.

dt: Array

The global simulation time step \(\Delta t\).

time: Array

Elapsed simulation time.

dim: Array

Spatial dimension of the system.

step_count: Array

Number of integration steps that have been performed.
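The page does not spell out how time and step_count evolve together. Assume that each integration step adds \(\Delta t\) to time and adds one to step_count; this is an assumption, not a documented guarantee. Under that assumption, advancing a system by n steps with a fixed time step gives the update below, where t denotes time and s denotes step_count:

```latex
t \leftarrow t + n\,\Delta t, \qquad s \leftarrow s + n
```

If both counters start at zero, this implies \(t = s\,\Delta t\) at every point of the run.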

static create(state_shape: Tuple[int], *, dt: float = 0.005, time: float = 0.0, linear_integrator_type: str = 'verlet', rotation_integrator_type: str = 'verletspiral', collider_type: str = 'naive', domain_type: str = 'free', force_model_type: str = 'spring', force_manager_kw: Dict[str, Any] | None = None, mat_table: MaterialTable | None = None, linear_integrator_kw: Dict[str, Any] | None = None, rotation_integrator_kw: Dict[str, Any] | None = None, collider_kw: Dict[str, Any] | None = None, domain_kw: Dict[str, Any] | None = None, force_model_kw: Dict[str, Any] | None = None) → System

Factory method to create a System instance with specified components.

Parameters:
  • state_shape (Tuple) – Shape of the state tensors handled by the simulation. The penultimate dimension corresponds to the number of particles N and the last dimension corresponds to the spatial dimension dim.

  • dt (float, optional) – The global simulation time step. Defaults to 0.005.

  • time (float, optional) – The initial simulation time. Defaults to 0.0.

  • linear_integrator_type (str, optional) – The registered type string for the jaxdem.integrators.LinearIntegrator used to evolve translational degrees of freedom. Defaults to “verlet”.

  • rotation_integrator_type (str, optional) – The registered type string for the jaxdem.integrators.RotationIntegrator used to evolve angular degrees of freedom. Defaults to “verletspiral”.

  • collider_type (str, optional) – The registered type string for the jaxdem.Collider to use. Defaults to “naive”.

  • domain_type (str, optional) – The registered type string for the jaxdem.Domain to use. Defaults to “free”.

  • force_model_type (str, optional) – The registered type string for the jaxdem.ForceModel to use. Defaults to “spring”.

  • force_manager_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of ForceManager.

  • mat_table (MaterialTable or None, optional) – An optional pre-configured jaxdem.MaterialTable. If None, a default jaxdem.MaterialTable is created with a single generic elastic material and a “harmonic” jaxdem.MaterialMatchmaker.

  • linear_integrator_kw (Dict[str, Any] or None, optional) – Keyword arguments forwarded to the constructor of the selected LinearIntegrator type.

  • rotation_integrator_kw (Dict[str, Any] or None, optional) – Keyword arguments forwarded to the constructor of the selected RotationIntegrator type.

  • collider_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Collider type.

  • domain_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Domain type.

  • force_model_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected ForceModel type.

Returns:

A fully configured System instance ready for simulation.

Return type:

System

Raises:
  • KeyError – If a specified *_type is not registered in its respective factory, or if the mat_table is missing properties required by the force_model.

  • TypeError – If constructor keyword arguments are invalid for any component.

  • AssertionError – If the domain_kw ‘box_size’ or ‘anchor’ shapes do not match the dim.
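The *_type / *_kw convention follows a common string-keyed factory pattern, and so do the KeyError, TypeError, and AssertionError cases listed above. The plain-Python sketch below reproduces that pattern for a domain. The names `register_domain`, `create_domain`, `FreeDomain`, and `ReflectDomain` are hypothetical and are not jaxdem's internals. The default unit box anchored at the origin is also an assumption, included only so the example runs.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple

# Hypothetical registry mapping a type string to a component class.
_DOMAIN_REGISTRY: Dict[str, type] = {}


def register_domain(key: str) -> Callable[[type], type]:
    """Class decorator that records the decorated class under `key`."""
    def decorator(cls: type) -> type:
        _DOMAIN_REGISTRY[key] = cls
        return cls
    return decorator


@register_domain("free")
@dataclass
class FreeDomain:
    box_size: Tuple[float, ...]
    anchor: Tuple[float, ...]


@register_domain("reflect")
@dataclass
class ReflectDomain(FreeDomain):
    pass


def create_domain(
    state_shape: Tuple[int, ...],
    domain_type: str = "free",
    domain_kw: Optional[Dict[str, Any]] = None,
) -> FreeDomain:
    dim = state_shape[-1]  # the last axis of the state shape is the spatial dimension
    kw = {} if domain_kw is None else dict(domain_kw)
    # Assumed defaults: a unit box anchored at the origin in `dim` dimensions.
    kw.setdefault("box_size", (1.0,) * dim)
    kw.setdefault("anchor", (0.0,) * dim)
    cls = _DOMAIN_REGISTRY[domain_type]  # KeyError if the type string is not registered
    domain = cls(**kw)  # TypeError if the constructor rejects a keyword
    assert len(domain.box_size) == dim and len(domain.anchor) == dim, "box_size/anchor must match dim"
    return domain
```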

Example

Creating a 3D system with reflective boundaries and a custom dt:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> N = 10  # number of particles
>>> system_reflect = jdem.System.create(
...     state_shape=(N, 3),
...     dt=0.0005,
...     domain_type="reflect",
...     domain_kw=dict(box_size=jnp.array([20.0, 20.0, 20.0]), anchor=jnp.array([-10.0, -10.0, -10.0])),
...     force_model_type="spring",
... )
>>> print(f"System dt: {system_reflect.dt}")
>>> print(f"Domain type: {system_reflect.domain.__class__.__name__}")

Creating a system with a pre-defined MaterialTable:

>>> custom_mat_kw = dict(young=2.0e5, poisson=0.25)
>>> custom_material = jdem.Material.create("elastic", **custom_mat_kw)
>>> custom_mat_table = jdem.MaterialTable.from_materials(
...     [custom_material], matcher=jdem.MaterialMatchmaker.create("linear")
... )
>>>
>>> system_custom_mat = jdem.System.create(
...     state_shape=(N, 2),
...     mat_table=custom_mat_table,
...     force_model_type="spring"
... )
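A MaterialMatchmaker turns per-material properties into pairwise interaction parameters. As a hedged illustration only, the sketch below assumes that “harmonic” means the harmonic mean 2ab/(a+b) of two materials' values and that “linear” means the arithmetic mean (a+b)/2. `pairwise_property` is a hypothetical helper, and jaxdem's actual combining rules may differ.

```python
import jax.numpy as jnp


def pairwise_property(values: jnp.ndarray, rule: str = "harmonic") -> jnp.ndarray:
    """Hypothetical combining rule: build an (M, M) table from M per-material values."""
    a = values[:, None]  # material i along rows
    b = values[None, :]  # material j along columns
    if rule == "harmonic":
        return 2.0 * a * b / (a + b)  # harmonic mean (assumed meaning of "harmonic")
    if rule == "linear":
        return 0.5 * (a + b)  # arithmetic mean (assumed meaning of "linear")
    raise KeyError(f"unknown matchmaker rule: {rule!r}")


young = jnp.array([1.0, 3.0])  # Young's moduli of two hypothetical materials
harm = pairwise_property(young, "harmonic")
lin = pairwise_property(young, "linear")
```

Both rules produce a symmetric table whose diagonal reproduces each material's own value, so like-material contacts keep their original properties.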
static trajectory_rollout(state: State, system: System, *, n: int, unroll: int = 2, stride: int = 1) → Tuple['State', 'System', Tuple['State', 'System']]

Rolls the system forward for a specified number of frames, collecting a trajectory.

This method performs n * stride simulation steps in total but saves the State and System only every stride steps, returning a trajectory of n snapshots. Because the loop is built on jax.lax.scan, the step is traced and compiled once, which makes this an efficient way to collect data in JAX.
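The n-frames-of-stride-steps structure can be sketched with two nested jax.lax.scan calls. The inner scan advances stride steps and discards intermediate states. The outer scan emits one snapshot per frame, and scan stacks these along a new leading axis of size n. The toy free-particle update and the name `toy_rollout` are assumptions for illustration; jaxdem's integrators and exact snapshot timing may differ.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_step(carry, _):
    # One explicit Euler step for free particles: x <- x + v * dt.
    pos, vel, dt = carry
    return (pos + vel * dt, vel, dt), None


@partial(jax.jit, static_argnames=("n", "stride"))
def toy_rollout(pos, vel, dt, *, n: int, stride: int):
    def frame(carry, _):
        # Advance `stride` steps without storing the intermediate states...
        carry, _ = jax.lax.scan(toy_step, carry, None, length=stride)
        # ...then emit one snapshot; scan stacks these along axis 0.
        return carry, carry[0]

    carry, traj = jax.lax.scan(frame, (pos, vel, dt), None, length=n)
    return carry[0], traj


# 10 frames, 5 steps apart: 50 steps in total.
final_pos, traj_pos = toy_rollout(jnp.zeros(2), jnp.ones(2), jnp.asarray(0.01), n=10, stride=5)
```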

Parameters:
  • state (State) – The initial state of the simulation.

  • system (System) – The initial system configuration.

  • n (int) – The number of frames (snapshots) to collect in the trajectory. This argument must be static.

  • stride (int, optional) – The number of integration steps to advance between consecutive collected frames. Defaults to 1, meaning every step is collected. This argument must be static.

Returns:

A tuple containing:

  • final_state: The State object at the end of the rollout.

  • final_system: The System object at the end of the rollout.

  • trajectory: A tuple of State and System objects, where each leaf array has an additional leading axis of size n representing the trajectory. The State and System objects within trajectory are structured as if created by jaxdem.State.stack() and jaxdem.System.stack().

Return type:

Tuple[State, System, Tuple[State, System]]

Raises:

ValueError – If n or stride is non-positive (raised implicitly by the jax.lax.scan length check).

Example

>>> import jaxdem as jdem
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(state_shape=state.shape, dt=0.01)
>>>
>>> # Roll out for 10 frames, saving every 5 steps
>>> final_state, final_system, traj = jdem.System.trajectory_rollout(
...     state, system, n=10, stride=5
... )
>>>
>>> traj_state, traj_system = traj  # stacked State and System snapshots
>>> print(f"Total simulation steps performed: {10 * 5}")
>>> print(f"Trajectory length (number of frames): {traj_state.pos.shape[0]}")
>>> print(f"First frame position:\n{traj_state.pos[0]}")
>>> print(f"Last frame position:\n{traj_state.pos[-1]}")
>>> print(f"Final state position (should match last frame):\n{final_state.pos}")

static step(state: State, system: System, *, n: int = 1, unroll: int = 2) → Tuple['State', 'System']

Advances the simulation state by n time steps.

This method provides a convenient way to run multiple integration steps. For a single step (n=1), it calls the integrator’s step method directly. For multiple steps (n > 1), it runs an internal loop built on jax.lax.scan, so the step is traced and compiled once regardless of n.
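Because n must be static, the choice between a single direct call and a scan-based loop can be made with an ordinary Python `if` at trace time. The sketch below shows this with a hypothetical `toy_step_n` and a toy free-particle update. Forwarding `unroll` to jax.lax.scan is an assumption about what the unroll argument controls; the parameter list on this page does not describe it.

```python
from functools import partial

import jax
import jax.numpy as jnp


def single_step(pos, vel, dt):
    # Toy update standing in for one integrator step.
    return pos + vel * dt, vel


@partial(jax.jit, static_argnames=("n", "unroll"))
def toy_step_n(pos, vel, dt, *, n: int = 1, unroll: int = 2):
    if n == 1:
        # Python-level branch on a static int, resolved while tracing.
        return single_step(pos, vel, dt)

    def body(carry, _):
        return single_step(*carry, dt), None

    # Assumed: `unroll` is passed straight through to jax.lax.scan.
    (pos, vel), _ = jax.lax.scan(body, (pos, vel), None, length=n, unroll=unroll)
    return pos, vel


pos1, _ = toy_step_n(jnp.zeros(1), jnp.ones(1), 0.1)
pos10, _ = toy_step_n(jnp.zeros(1), jnp.ones(1), 0.1, n=10)
```

Each distinct static n compiles its own version of the function, which is why n cannot be a traced value.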

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The current system configuration.

  • n (int, optional) – The number of integration steps to perform. Defaults to 1. This argument must be static.

Returns:

A tuple containing the final State and System after n integration steps.

Return type:

Tuple[State, System]

Raises:

ValueError – If n is non-positive (raised implicitly by the jax.lax.scan length check).

Example

>>> import jaxdem as jdem
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(state_shape=state.shape, dt=0.01)
>>>
>>> # Advance by 1 step
>>> state_after_1_step, system_after_1_step = jdem.System.step(state, system)
>>> print("Position after 1 step:", state_after_1_step.pos[0])
>>>
>>> # Advance by 10 steps
>>> state_after_10_steps, system_after_10_steps = jdem.System.step(state, system, n=10)
>>> print("Position after 10 steps:", state_after_10_steps.pos[0])
static stack(systems: Sequence[System]) → System

Stacks a sequence of System snapshots along a new leading axis (axis 0) to form a trajectory or a batch.

This is useful for collecting simulation snapshots over time into a single System object whose leading dimension represents time, or for preparing a batched system.

Parameters:

systems (Sequence[System]) – A sequence (e.g., list, tuple) of System instances to be stacked.

Returns:

A new System instance in which every leaf array gains an additional leading dimension of size T, the number of stacked snapshots. For example, a scalar field such as dt becomes an array of shape (T,).

Return type:

System
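A sequence of pytrees is commonly stacked with jax.tree_util.tree_map and jnp.stack. tree_map pairs up matching leaves across the snapshots, and jnp.stack stacks each group along a new axis 0. The sketch uses a hypothetical `ToySystem` NamedTuple in place of System; NamedTuples are pytrees by default. Whether System.stack is implemented exactly this way is an assumption.

```python
from typing import NamedTuple, Sequence

import jax
import jax.numpy as jnp


class ToySystem(NamedTuple):
    """Hypothetical stand-in for jaxdem.System."""

    dt: jax.Array
    time: jax.Array
    box_size: jax.Array


def stack_systems(systems: Sequence[ToySystem]) -> ToySystem:
    # tree_map walks all snapshots in lockstep and stacks matching leaves on a new axis 0.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *systems)


snapshots = [
    ToySystem(jnp.asarray(0.01), jnp.asarray(0.01 * k), jnp.array([5.0, 5.0]))
    for k in range(3)
]
batched = stack_systems(snapshots)  # dt: (3,), time: (3,), box_size: (3, 2)
```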