jaxdem.system

Defines the simulation configuration and the tooling for driving the simulation.

Classes

System(integrator, collider, domain, ...)

Encapsulates the entire simulation configuration.

final class jaxdem.system.System(integrator: jaxdem.integrators.Integrator, collider: jaxdem.colliders.Collider, domain: jaxdem.domains.Domain, force_model: jaxdem.forces.ForceModel, mat_table: jaxdem.materials.materialTable.MaterialTable, dt: jax.Array, time: jax.Array, dim: jax.Array, step_count: jax.Array = <factory>)

Bases: object

Encapsulates the entire simulation configuration.

Notes

  • The System object is designed to be JIT-compiled for efficient execution.

  • The System class is a frozen dataclass, meaning its attributes cannot be changed after instantiation.
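The two notes above work together: a frozen dataclass whose fields are arrays can be registered as a JAX pytree, passed through jax.jit, and "updated" only by building a new instance. The sketch below illustrates this with a hypothetical ToySystem that keeps just dt, time and step_count. The pytree registration via jax.tree_util.register_pytree_node_class is an illustrative choice, not necessarily how jaxdem registers System.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class ToySystem:
    """Hypothetical stand-in for System holding only dt, time and step_count."""

    dt: jax.Array
    time: jax.Array
    step_count: jax.Array

    def tree_flatten(self):
        # Every field is an array, so all fields are pytree children (no static aux data).
        return (self.dt, self.time, self.step_count), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def advance_clock(system: ToySystem) -> ToySystem:
    # A frozen dataclass is "modified" by returning a new instance.
    return dataclasses.replace(
        system,
        time=system.time + system.dt,
        step_count=system.step_count + 1,
    )


sys0 = ToySystem(
    dt=jnp.array(0.01, dtype=jnp.float32),
    time=jnp.zeros(()),
    step_count=jnp.zeros((), dtype=jnp.int32),
)
sys1 = advance_clock(sys0)

try:
    sys0.time = jnp.ones(())  # direct attribute assignment is rejected
except dataclasses.FrozenInstanceError:
    print("ToySystem is immutable; use dataclasses.replace instead")
```

The original sys0 is left untouched, which is what makes such objects safe to carry through jitted loops.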

Example

Creating a basic 2D simulation system:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create a System instance
>>> sim_system = jdem.System.create(
...     dim=2,
...     dt=0.001,
...     integrator_type="euler",
...     collider_type="naive",
...     domain_type="free",
...     force_model_type="spring",
...     # You can pass keyword arguments to component constructors via '_kw' dicts
...     domain_kw=dict(box_size=jnp.array([5.0, 5.0]), anchor=jnp.array([0.0, 0.0])),
... )
>>>
>>> print(f"System integrator: {sim_system.integrator.__class__.__name__}")
>>> print(f"System force model: {sim_system.force_model.__class__.__name__}")
>>> print(f"Domain box size: {sim_system.domain.box_size}")

integrator: Integrator

Instance of jaxdem.Integrator that advances the simulation state in time.

collider: Collider

Instance of jaxdem.Collider that performs contact detection and computes inter-particle forces and potential energies.

domain: Domain

Instance of jaxdem.Domain that defines the simulation boundaries, displacement rules, and boundary conditions.

force_model: ForceModel

Instance of jaxdem.ForceModel that defines the physical laws for inter-particle interactions.
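To make the split between collider and force model concrete, here is a minimal sketch of what the default "naive" collider combined with the "spring" force model might compute. It assumes that "naive" means an all-pairs O(N²) contact search and that "spring" means a linear repulsive spring acting on the overlap δ_ij = R_i + R_j - |r_ij|. The function naive_spring_forces and the stiffness k are hypothetical; the real jaxdem components take their parameters from the MaterialTable and may differ in detail.

```python
import jax.numpy as jnp


def naive_spring_forces(pos, rad, k=1.0):
    """Hypothetical all-pairs contact search with linear spring forces."""
    n = pos.shape[0]
    # Pairwise displacement r_ij = x_i - x_j, shape (N, N, dim).
    rij = pos[:, None, :] - pos[None, :, :]
    eye = jnp.eye(n, dtype=bool)
    # Use a dummy distance of 1 on the diagonal so no self-pair divides by zero.
    dist = jnp.where(eye, 1.0, jnp.linalg.norm(rij, axis=-1))
    # Overlap delta_ij = R_i + R_j - |r_ij|; only positive overlaps are contacts.
    overlap = rad[:, None] + rad[None, :] - dist
    overlap = jnp.where(eye | (overlap < 0.0), 0.0, overlap)
    # Repulsive force on i from j: k * delta_ij along the unit vector r_ij / |r_ij|.
    fij = (k * overlap / dist)[..., None] * rij
    forces = fij.sum(axis=1)
    # Pair energy 0.5 * k * delta^2; the (N, N) sum counts every pair twice.
    energy = 0.5 * jnp.sum(0.5 * k * overlap**2)
    return forces, energy


pos = jnp.array([[0.0, 0.0], [0.15, 0.0], [1.0, 1.0]])
rad = jnp.array([0.1, 0.1, 0.1])
forces, energy = naive_spring_forces(pos, rad)
```

Only the first two particles overlap here, so they are pushed apart with equal and opposite forces while the third feels nothing.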

mat_table: MaterialTable

Instance of jaxdem.MaterialTable holding material properties and pairwise interaction parameters.

dt: Array

The global simulation time step \(\Delta t\).

time: Array

Elapsed simulation time.

dim: Array

Spatial dimension of the system.

step_count: Array

Number of integration steps that have been performed.
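Assuming the integrator advances time by exactly dt on every step and dt stays constant (an assumption; the attribute descriptions above do not state it), time, dt and step_count are related by the following, where \(t_0\) is the time passed to System.create and \(n_{\text{step}}\) is step_count:

```latex
t = t_0 + n_{\text{step}}\,\Delta t
```

For example, 10 frames with a stride of 5 at dt = 0.01 starting from \(t_0 = 0\) amount to 50 steps and t = 0.5.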

static create(dim: int, *, dt: float = 0.005, time: float = 0.0, integrator_type: str = 'euler', collider_type: str = 'naive', domain_type: str = 'free', force_model_type: str = 'spring', mat_table: MaterialTable | None = None, integrator_kw: Dict[str, Any] | None = None, collider_kw: Dict[str, Any] | None = None, domain_kw: Dict[str, Any] | None = None, force_model_kw: Dict[str, Any] | None = None) → System

Factory method to create a System instance with specified components.

Parameters:
  • dim (int) – The spatial dimension of the simulation (2 or 3).

  • dt (float, optional) – The global simulation time step. Defaults to 0.005.

  • time (float, optional) – The initial simulation time. Defaults to 0.0.

  • integrator_type (str, optional) – The registered type string for the jaxdem.Integrator to use. Defaults to “euler”.

  • collider_type (str, optional) – The registered type string for the jaxdem.Collider to use. Defaults to “naive”.

  • domain_type (str, optional) – The registered type string for the jaxdem.Domain to use. Defaults to “free”.

  • force_model_type (str, optional) – The registered type string for the jaxdem.ForceModel to use. Defaults to “spring”.

  • mat_table (MaterialTable or None, optional) – An optional pre-configured jaxdem.MaterialTable. If None, a default jaxdem.MaterialTable is created with one generic elastic material and a “harmonic” jaxdem.MaterialMatchmaker.

  • integrator_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Integrator type.

  • collider_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Collider type.

  • domain_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Domain type.

  • force_model_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected ForceModel type.

Returns:

A fully configured System instance ready for simulation.

Return type:

System

Raises:
  • KeyError – If a specified *_type is not registered in its respective factory, or if the mat_table is missing properties required by the force_model.

  • TypeError – If constructor keyword arguments are invalid for any component.

  • AssertionError – If the domain_kw ‘box_size’ or ‘anchor’ shapes do not match the dim.
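The *_type / *_kw pairs and the KeyError / TypeError behaviour listed above follow a common registry-based factory pattern. The sketch below shows that pattern for domains only; DOMAIN_REGISTRY, register, FreeDomain, ReflectDomain and create_domain are hypothetical names, and jaxdem's own factories may be organised differently.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical registry mapping type strings to component classes.
DOMAIN_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register(name: str):
    """Class decorator that records a component under a type string."""

    def decorator(cls):
        DOMAIN_REGISTRY[name] = cls
        return cls

    return decorator


@register("free")
class FreeDomain:
    def __init__(self, box_size=None, anchor=None):
        self.box_size = box_size
        self.anchor = anchor


@register("reflect")
class ReflectDomain(FreeDomain):
    pass


def create_domain(domain_type: str = "free", domain_kw: Optional[Dict[str, Any]] = None):
    # Unregistered type string -> KeyError from the dictionary lookup.
    cls = DOMAIN_REGISTRY[domain_type]
    # Unexpected keyword argument -> TypeError from the constructor call.
    return cls(**(domain_kw or {}))


domain = create_domain("reflect", dict(box_size=[20.0] * 3, anchor=[-10.0] * 3))
```

Keeping the type string and the constructor keywords separate lets new components be added by registering a class, without touching the factory signature.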

Example

Creating a 3D system with reflective boundaries and a custom dt:

>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> system_reflect = jdem.System.create(
...     dim=3,
...     dt=0.0005,
...     domain_type="reflect",
...     domain_kw=dict(box_size=jnp.array([20.0, 20.0, 20.0]), anchor=jnp.array([-10.0, -10.0, -10.0])),
...     force_model_type="spring",
... )
>>> print(f"System dt: {system_reflect.dt}")
>>> print(f"Domain type: {system_reflect.domain.__class__.__name__}")
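As an illustration of what box_size and anchor describe in this example, the sketch below reflects positions back into the box spanning [anchor, anchor + box_size] and flips the velocity component normal to the wall that was crossed. This is an assumed model of a "reflect" domain written for illustration (the function reflect is hypothetical), not jaxdem's implementation.

```python
import jax.numpy as jnp


def reflect(pos, vel, anchor, box_size):
    """Hypothetical reflective box spanning [anchor, anchor + box_size]."""
    lo = anchor
    hi = anchor + box_size
    below = pos < lo
    above = pos > hi
    # Mirror any coordinate that left the box back across the wall it crossed
    # (assumes a particle overshoots by less than one box length per step).
    pos = jnp.where(below, 2.0 * lo - pos, pos)
    pos = jnp.where(above, 2.0 * hi - pos, pos)
    # Flip the velocity component normal to the wall that was hit.
    vel = jnp.where(below | above, -vel, vel)
    return pos, vel


anchor = jnp.array([-10.0, -10.0, -10.0])
box_size = jnp.array([20.0, 20.0, 20.0])
new_pos, new_vel = reflect(
    jnp.array([10.5, 0.0, -10.25]), jnp.array([1.0, 1.0, -1.0]), anchor, box_size
)
```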

Creating a system with a pre-defined MaterialTable:

>>> custom_mat_kw = dict(young=2.0e5, poisson=0.25)
>>> custom_material = jdem.Material.create("custom_mat", **custom_mat_kw)
>>> custom_mat_table = jdem.MaterialTable.from_materials(
...     [custom_material], matcher=jdem.MaterialMatchmaker.create("linear")
... )
>>>
>>> system_custom_mat = jdem.System.create(
...     dim=2,
...     mat_table=custom_mat_table,
...     force_model_type="spring"
... )
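The matchmaker decides how per-material properties are combined into pairwise interaction parameters. Assuming that “harmonic” and “linear” refer to the harmonic-mean and arithmetic-mean mixing rules (an assumption; jaxdem's matchmakers may be defined differently), the sketch below builds the resulting pairwise table for two materials. harmonic_match, linear_match and pair_table are hypothetical helpers.

```python
import jax.numpy as jnp


def harmonic_match(a, b):
    # Harmonic mean 2ab / (a + b): the softer material dominates the pair value.
    return 2.0 * a * b / (a + b)


def linear_match(a, b):
    # Arithmetic mean (a + b) / 2.
    return 0.5 * (a + b)


def pair_table(values, match):
    # Broadcast a per-material property vector to an (M, M) table of pair values.
    v = jnp.asarray(values)
    return match(v[:, None], v[None, :])


young = jnp.array([1.0e4, 3.0e4])
harmonic_table = pair_table(young, harmonic_match)
linear_table = pair_table(young, linear_match)
```

Both rules reproduce the material's own value on the diagonal; they differ only in how unlike materials are combined.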

static trajectory_rollout(state: State, system: System, *, n: int, stride: int = 1, batched: bool = False) → Tuple['State', 'System', Tuple['State', 'System']]

Rolls the system forward for a specified number of frames, collecting a trajectory.

This method performs n * stride simulation steps in total but saves the State and System only every stride steps, returning a trajectory of n snapshots. Because the loop is built on jax.lax.scan, the step is traced and compiled once rather than unrolled in Python, which makes this an efficient way to collect data within JAX.
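One way to express "n frames, stride steps per frame" with jax.lax.scan is a nested scan: an inner scan advances stride steps without storing them, and an outer scan emits the carry once per frame. The sketch below uses a hypothetical constant-velocity toy_step on a tuple carry; it illustrates the pattern and is not necessarily how jaxdem implements trajectory_rollout.

```python
import jax
import jax.numpy as jnp

DT = 0.01  # hypothetical fixed time step


def toy_step(carry, _):
    # Hypothetical stand-in for one integration step: free flight at constant velocity.
    pos, vel, time, step_count = carry
    return (pos + DT * vel, vel, time + DT, step_count + 1), None


def rollout(carry, n: int, stride: int):
    def frame(c, _):
        # Advance `stride` steps without storing the intermediate states.
        c, _ = jax.lax.scan(toy_step, c, xs=None, length=stride)
        # The carry after each block of `stride` steps becomes one trajectory frame.
        return c, c

    # Outer scan over n frames: every leaf of the stacked output gains a leading axis of size n.
    return jax.lax.scan(frame, carry, xs=None, length=n)


# n and stride set the scan lengths, so they must be static under jit.
rollout_jit = jax.jit(rollout, static_argnames=("n", "stride"))

init = (jnp.zeros(2), jnp.ones(2), jnp.zeros(()), jnp.zeros((), dtype=jnp.int32))
final, traj = rollout_jit(init, n=10, stride=5)
```

Because the outer body returns the same carry it emits, the last trajectory frame always equals the final carry.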

Parameters:
  • state (State) – The initial state of the simulation.

  • system (System) – The initial system configuration.

  • n (int) – The number of frames (snapshots) to collect in the trajectory. This argument must be static.

  • stride (int, optional) – The number of integration steps to advance between collected frames. Defaults to 1, meaning every step is collected. This argument must be static.

Returns:

A tuple containing:

  • final_state: The State object at the end of the rollout.

  • final_system: The System object at the end of the rollout.

  • trajectory: A tuple of State and System objects in which every leaf array has an additional leading axis of size n indexing the frames. The State and System objects in trajectory are structured as if stacked with jaxdem.State.stack() (and an analogous stacking of System objects).

Return type:

Tuple[State, System, Tuple[State, System]]

Raises:

ValueError – If n or stride is non-positive (raised implicitly by the jax.lax.scan length check).

Example

>>> import jaxdem as jdem
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(dim=2, dt=0.01)
>>>
>>> # Roll out for 10 frames, saving every 5 steps
>>> final_state, final_system, traj = jdem.System.trajectory_rollout(
...     state, system, n=10, stride=5
... )
>>>
>>> print(f"Total simulation steps performed: {10 * 5}")
>>> print(f"Trajectory length (number of frames): {traj[0].pos.shape[0]}")  # traj[0] is the state part of the trajectory
>>> print(f"First frame position:\n{traj[0].pos[0]}")
>>> print(f"Last frame position:\n{traj[0].pos[-1]}")
>>> print(f"Final state position (should match last frame):\n{final_state.pos}")
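The trajectory layout described above (one extra leading axis on every leaf) is the same one obtained by stacking matching leaves of several pytrees. The sketch below shows the idea with plain dictionaries standing in for State objects; stack_snapshots and get_frame are hypothetical helpers, whereas jaxdem provides jaxdem.State.stack() for its own State.

```python
import jax
import jax.numpy as jnp


def stack_snapshots(snapshots):
    # Hypothetical helper: stack matching leaves of several pytrees along a new axis 0.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *snapshots)


def get_frame(trajectory, i):
    # Index every leaf at position i to recover a single snapshot.
    return jax.tree_util.tree_map(lambda leaf: leaf[i], trajectory)


snapshots = [
    {"pos": jnp.full((3, 2), float(k)), "time": jnp.asarray(0.1 * k)}
    for k in range(4)
]
stacked = stack_snapshots(snapshots)
```

This is why traj[0].pos[0] and traj[0].pos[-1] in the example above select the first and last frames.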

static step(state: State, system: System, *, n: int = 1, batched: bool = False) → Tuple['State', 'System']

Advances the simulation state by n time steps.

This method provides a convenient way to run multiple integration steps. For a single step (n=1), it calls the integrator’s step method directly. For multiple steps (n > 1), it runs the step inside a jax.lax.scan loop, so the step is traced once instead of being unrolled n times, which keeps JIT compilation efficient.
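The dispatch described above can be sketched as follows, assuming a hypothetical toy_step that updates a tuple carry. Because n is static, the Python `if` is resolved at trace time, and for n > 1 the step body is traced once inside jax.lax.scan. multi_step and toy_step are illustrative names, not jaxdem functions.

```python
import functools

import jax
import jax.numpy as jnp


def toy_step(carry):
    # Hypothetical single integration step: constant-velocity update plus bookkeeping.
    pos, vel, time, step_count, dt = carry
    return pos + dt * vel, vel, time + dt, step_count + 1, dt


@functools.partial(jax.jit, static_argnames="n")
def multi_step(carry, n: int = 1):
    if n == 1:
        # Single step: call the step function directly.
        return toy_step(carry)

    def body(c, _):
        return toy_step(c), None

    # n > 1: scan traces toy_step once and runs it as a compiled loop of length n.
    final, _ = jax.lax.scan(body, carry, xs=None, length=n)
    return final


init = (
    jnp.zeros(2),
    jnp.ones(2),
    jnp.zeros(()),
    jnp.zeros((), dtype=jnp.int32),
    jnp.array(0.01, dtype=jnp.float32),
)
after_10 = multi_step(init, n=10)
```

Each distinct value of n triggers its own compilation, which is the practical reason n must be static.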

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The current system configuration.

  • n (int, optional) – The number of integration steps to perform. Defaults to 1. This argument must be static.

Returns:

A tuple containing the final State and System after n integration steps.

Return type:

Tuple[State, System]

Raises:

ValueError – If n is non-positive (raised implicitly by the jax.lax.scan length check).

Example

>>> import jaxdem as jdem
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(dim=2, dt=0.01)
>>>
>>> # Advance by 1 step
>>> state_after_1_step, system_after_1_step = jdem.System.step(state, system)
>>> print("Position after 1 step:", state_after_1_step.pos[0])
>>>
>>> # Advance by 10 steps
>>> state_after_10_steps, system_after_10_steps = jdem.System.step(state, system, n=10)
>>> print("Position after 10 steps:", state_after_10_steps.pos[0])