jaxdem.system

Defines the simulation configuration and the tooling for driving the simulation.

Classes

System(integrator, collider, domain, ...)

Encapsulates the entire simulation configuration.

final class jaxdem.system.System(integrator: jaxdem.integrator.Integrator, collider: jaxdem.collider.Collider, domain: jaxdem.domain.Domain, force_model: jaxdem.forces.ForceModel, mat_table: jaxdem.material.MaterialTable, dt: jax.Array, step_count: jax.Array = <factory>)

Bases: object

Encapsulates the entire simulation configuration.

Notes

  • The System object is designed to be used inside JIT-compiled (jax.jit) functions for efficient execution.

  • The System class is a frozen dataclass, meaning its attributes cannot be changed after instantiation.
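Since System is a frozen dataclass, updating a field means constructing a new instance rather than assigning to the old one. The plain-Python sketch below illustrates this, assuming System behaves like a standard-library frozen dataclass; `ToySystem` and its two fields are hypothetical stand-ins, not jaxdem's class.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToySystem:
    """Hypothetical stand-in for System, keeping only two of its fields."""
    dt: float
    step_count: int = 0


sys0 = ToySystem(dt=0.01)

# Assigning to a field of a frozen dataclass raises FrozenInstanceError.
try:
    sys0.dt = 0.001
except dataclasses.FrozenInstanceError:
    mutation_blocked = True
else:
    mutation_blocked = False

# dataclasses.replace builds a new instance; the original is left untouched.
sys1 = dataclasses.replace(sys0, dt=0.001)
print(sys0.dt, sys1.dt)  # → 0.01 0.001
```

The same pattern appears in step and trajectory_rollout, which return an updated System instead of modifying the one passed in.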

Example

Creating a basic 2D simulation system:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create a System instance
>>> sim_system = jdem.System.create(
...     dim=2,
...     dt=0.001,
...     integrator_type="euler",
...     collider_type="naive",
...     domain_type="free",
...     force_model_type="spring",
...     # Keyword arguments for component constructors are passed via the *_kw dicts
...     domain_kw=dict(box_size=jnp.array([5.0, 5.0]), anchor=jnp.array([0.0, 0.0])),
... )
>>>
>>> print(f"System integrator: {sim_system.integrator.__class__.__name__}")
>>> print(f"System force model: {sim_system.force_model.__class__.__name__}")
>>> print(f"Domain box size: {sim_system.domain.box_size}")

integrator: Integrator

Instance of Integrator that defines how the simulation state is advanced in time.

collider: Collider

Instance of Collider that performs contact detection and computes inter-particle forces and potential energies.

domain: Domain

Instance of Domain that defines the simulation boundaries, how displacement vectors are calculated, and how boundary conditions are applied.

force_model: ForceModel

Instance of ForceModel that defines the specific physical laws for inter-particle interactions.

mat_table: MaterialTable

Instance of MaterialTable holding material properties and their effective interaction parameters for pairs of materials.

dt: Array

The global simulation time step Δt.

step_count: Array

Counts the number of steps that have been performed.

static create(dim: int, *, dt: float = 0.01, integrator_type: str = 'euler', collider_type: str = 'naive', domain_type: str = 'free', force_model_type: str = 'spring', mat_table: MaterialTable | None = None, integrator_kw: Dict[str, Any] | None = None, collider_kw: Dict[str, Any] | None = None, domain_kw: Dict[str, Any] | None = None, force_model_kw: Dict[str, Any] | None = None) -> System

Factory method to create a System instance with specified components.

Parameters:
  • dim (int) – The spatial dimension of the simulation (2 or 3).

  • dt (float, optional) – The global simulation time step. Defaults to 0.01.

  • integrator_type (str, optional) – The registered type string for the Integrator to use. Defaults to 'euler'.

  • collider_type (str, optional) – The registered type string for the Collider to use. Defaults to 'naive'.

  • domain_type (str, optional) – The registered type string for the Domain to use. Defaults to 'free'.

  • force_model_type (str, optional) – The registered type string for the ForceModel to use. Defaults to 'spring'.

  • mat_table (MaterialTable or None, optional) – An optional pre-configured MaterialTable. If None, a default MaterialTable is created with a single generic elastic material and the “harmonic” MaterialMatchmaker.

  • integrator_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Integrator type.

  • collider_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Collider type.

  • domain_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Domain type.

  • force_model_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected ForceModel type.

Returns:

A fully configured System instance ready for simulation.

Return type:

System

Raises:
  • KeyError – If a specified *_type is not registered in its respective factory, or if the mat_table is missing properties required by the force_model.

  • TypeError – If constructor keyword arguments are invalid for any component.

  • AssertionError – If the shapes of ‘box_size’ or ‘anchor’ in domain_kw do not match dim.
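The `*_type` / `*_kw` pairs follow a registry-plus-constructor pattern, which is where the KeyError and TypeError listed above come from. The following is a minimal sketch of that pattern for the domain only; `DOMAIN_REGISTRY`, `register`, `FreeDomain` and `build_domain` are hypothetical names, and jaxdem's actual registries and constructor signatures may differ.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical registry mapping type strings to constructors.
DOMAIN_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register(name: str):
    """Decorator that records a class under a type string."""
    def wrap(cls):
        DOMAIN_REGISTRY[name] = cls
        return cls
    return wrap


@register("free")
class FreeDomain:
    """Hypothetical domain that stores a box size and anchor per axis."""

    def __init__(self, dim: int, box_size=None, anchor=None):
        box_size = box_size if box_size is not None else [1.0] * dim
        anchor = anchor if anchor is not None else [0.0] * dim
        # Mirrors the documented AssertionError for shapes that do not match dim.
        assert len(box_size) == dim and len(anchor) == dim
        self.box_size, self.anchor = box_size, anchor


def build_domain(dim: int, domain_type: str = "free",
                 domain_kw: Optional[Dict[str, Any]] = None):
    cls = DOMAIN_REGISTRY[domain_type]        # unregistered string -> KeyError
    return cls(dim=dim, **(domain_kw or {}))  # unknown keyword -> TypeError


domain = build_domain(2, "free", dict(box_size=[5.0, 5.0], anchor=[0.0, 0.0]))
print(domain.box_size)  # → [5.0, 5.0]
```

Keeping construction behind a string lookup is what lets create accept plain strings for every component while forwarding arbitrary keyword arguments unchanged.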

Example

Creating a 3D system with reflective boundaries and a custom dt:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> system_reflect = jdem.System.create(
...     dim=3,
...     dt=0.0005,
...     domain_type="reflect",
...     domain_kw=dict(box_size=jnp.array([20.0, 20.0, 20.0]), anchor=jnp.array([-10.0, -10.0, -10.0])),
...     force_model_type="spring",
... )
>>> print(f"System dt: {system_reflect.dt}")
>>> print(f"Domain type: {system_reflect.domain.__class__.__name__}")

Creating a system with a pre-defined MaterialTable:

>>> # Assumes Material and MaterialTable are imported from jaxdem and that
>>> # custom_mat_kw is a dict holding the material's properties.
>>> custom_material = Material.create("custom_mat", **custom_mat_kw)
>>> custom_mat_table = MaterialTable.from_materials(
...     [custom_material], matcher=jdem.MaterialMatchmaker.create("linear")
... )
>>>
>>> system_custom_mat = jdem.System.create(
...     dim=2,
...     mat_table=custom_mat_table,
...     force_model_type="spring",
... )

static trajectory_rollout(state: State, system: System, n: int, stride: int = 1) -> Tuple[State, System, Tuple[State, System]]

Rolls the system forward for a specified number of frames, collecting a trajectory.

This method performs n * stride simulation steps in total but saves the State and System only every stride steps, returning a trajectory of n snapshots. Because the loop is built on jax.lax.scan, its body is traced and compiled once rather than unrolled in Python, which makes it efficient for collecting data under JIT compilation.
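The bookkeeping of n frames and n * stride steps can be modelled as two nested scans: an inner one that advances stride steps without saving anything, and an outer one that records one snapshot per frame. The sketch below uses a plain-Python stand-in for jax.lax.scan (called with no xs) and a hypothetical constant-velocity `toy_step`. It assumes each saved frame is the state after its stride, which is consistent with the last frame matching the final state, but it is not jaxdem's implementation.

```python
from typing import Callable, List, Tuple, TypeVar

C = TypeVar("C")
Y = TypeVar("Y")


def scan(f: Callable[[C, None], Tuple[C, Y]], init: C, length: int) -> Tuple[C, List[Y]]:
    """Plain-Python model of jax.lax.scan with xs=None: apply f `length` times,
    threading the carry and collecting each per-iteration output."""
    carry, ys = init, []
    for _ in range(length):
        carry, y = f(carry, None)
        ys.append(y)
    return carry, ys


def toy_step(x: float, dt: float = 0.01, v: float = 1.0) -> float:
    """Hypothetical one-step integrator: move at constant velocity v."""
    return x + v * dt


def rollout(x0: float, n: int, stride: int) -> Tuple[float, List[float]]:
    def advance_stride(x, _):
        # Inner loop: `stride` integration steps, nothing saved.
        x, _ = scan(lambda c, _: (toy_step(c), None), x, stride)
        return x, x  # saved frame = state after this stride
    # Outer loop: n frames, one snapshot each.
    return scan(advance_stride, x0, n)


final, traj = rollout(0.0, n=10, stride=5)
print(len(traj))          # → 10
print(final == traj[-1])  # → True
```

Only n snapshots are kept, however large stride is. With the real jax.lax.scan the loop length must be a concrete integer when the function is traced, which is why n and stride have to be static.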

Parameters:
  • state (State) – The initial state of the simulation.

  • system (System) – The initial system configuration.

  • n (int) – The number of frames (snapshots) to collect in the trajectory. This argument must be static (a concrete Python int when the call is traced under jax.jit).

  • stride (int, optional) – The number of integration steps to advance between consecutive collected frames. Defaults to 1, meaning every step is collected. This argument must be static.

Returns:

A tuple containing:

  • final_state:

    The State object at the end of the rollout.

  • final_system:

    The System object at the end of the rollout.

  • trajectory:

    A tuple of State and System objects in which each leaf array has an additional leading axis of size n indexing the frames. The objects are laid out as if the n snapshots had been combined with State.stack() (and an analogous stack for System).

Return type:

Tuple[State, System, Tuple[State, System]]
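The "extra leading axis" layout can be pictured as converting a list of n snapshots into a single structure whose every field holds n stacked entries. A stdlib-only sketch, in which `ToyState` and `stack` are hypothetical and nested lists stand in for arrays:

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass(frozen=True)
class ToyState:
    """Hypothetical two-field state for N particles."""
    pos: List[List[float]]  # shape (N, dim)
    vel: List[List[float]]  # shape (N, dim)


def stack(snapshots: List[ToyState]) -> ToyState:
    """Turn n snapshots into one ToyState whose every field gains a leading
    axis of size n, so field[k] is the value at frame k."""
    return ToyState(**{f.name: [getattr(s, f.name) for s in snapshots]
                       for f in fields(ToyState)})


# Three frames of a single 2D particle moving along x.
frames = [ToyState(pos=[[float(t), 0.0]], vel=[[1.0, 0.0]]) for t in range(3)]
traj = stack(frames)
print(traj.pos[2])  # → [[2.0, 0.0]]
```

In JAX the same transformation on pytrees is commonly written as jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *snapshots); the source does not say whether State.stack() is implemented that way. jax.lax.scan produces this layout directly, so traj[0].pos[k] selects the positions at frame k.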

Raises:

ValueError – If n or stride is non-positive (raised implicitly by the jax.lax.scan length check).

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1,1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(dim=2, dt=0.01)
>>>
>>> # Roll out 10 frames, saving one every 5 steps
>>> final_state, final_system, traj = system.trajectory_rollout(
...     state, system, n=10, stride=5
... )
>>>
>>> print(f"Total simulation steps performed: {10 * 5}")
>>> print(f"Trajectory length (number of frames): {traj[0].pos.shape[0]}")  # traj[0] is the State part of the trajectory
>>> print(f"First frame position:\n{traj[0].pos[0]}")
>>> print(f"Last frame position:\n{traj[0].pos[-1]}")
>>> print(f"Final state position (should match last frame):\n{final_state.pos}")

static step(state: State, system: System, n: int = 1) -> Tuple[State, System]

Advances the simulation state by n time steps.

This method is a convenience for running several integration steps. For a single step (n=1) it calls the integrator’s step method directly; for n > 1 it runs the steps inside a jax.lax.scan loop, so the loop body is compiled once instead of being unrolled under JIT.
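The dispatch on n described above can be sketched in plain Python. Here `integrator_step` is a hypothetical toy integrator, the state is a single float, the system is a (dt, step_count) pair, and an ordinary for-loop stands in for jax.lax.scan; the explicit ValueError is this sketch's own check rather than the implicit one documented below.

```python
from typing import Tuple

# Hypothetical stand-ins: a "state" is one position, a "system" is (dt, step_count).
ToyState = float
ToySystem = Tuple[float, int]


def integrator_step(state: ToyState, system: ToySystem) -> Tuple[ToyState, ToySystem]:
    """Toy single step: move at unit velocity and increment the step counter."""
    dt, count = system
    return state + 1.0 * dt, (dt, count + 1)


def step(state: ToyState, system: ToySystem, n: int = 1) -> Tuple[ToyState, ToySystem]:
    if n <= 0:
        # Explicit check in this sketch only.
        raise ValueError("n must be positive")
    if n == 1:
        return integrator_step(state, system)  # direct call, no loop
    for _ in range(n):  # stands in for jax.lax.scan over n steps
        state, system = integrator_step(state, system)
    return state, system


print(step(0.0, (0.01, 0), n=10)[1])  # → (0.01, 10)
```

Both branches return a new (state, system) pair, matching the functional style in which the inputs are never modified.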

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The current system configuration.

  • n (int, optional) – The number of integration steps to perform. Defaults to 1. This argument must be static.

Returns:

A tuple containing the final State and System after n integration steps.

Return type:

Tuple[State, System]

Raises:

ValueError – If n is non-positive (raised implicitly by the jax.lax.scan length check).

Example

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1,1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(dim=2, dt=0.01)
>>>
>>> # Advance by 1 step
>>> state_after_1_step, system_after_1_step = system.step(state, system)
>>> print("Position after 1 step:", state_after_1_step.pos[0])
>>>
>>> # Advance by 10 steps, starting again from the initial state
>>> state_after_10_steps, system_after_10_steps = system.step(state, system, n=10)
>>> print("Position after 10 steps:", state_after_10_steps.pos[0])