# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Defines the simulation configuration and the tooling for driving the simulation.
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
from dataclasses import dataclass, field
from functools import partial
from typing import TYPE_CHECKING, final, Tuple, Optional, Dict, Any, Sequence
from .integrators import Integrator
from .colliders import Collider
from .domains import Domain
from .forces import ForceModel
from .materials import MaterialTable, Material
from .material_matchmakers import MaterialMatchmaker
if TYPE_CHECKING:
from .state import State
def _check_material_table(table, required: Sequence[str]):
    """
    Verify that a material table exposes every property a force model needs.

    Each name in ``required`` (in this module, a force model's
    :attr:`ForceModel.required_material_properties`) must be reachable as an
    attribute of ``table``. Presence is tested with :func:`hasattr`; the
    attribute values themselves are never inspected.

    Parameters
    ----------
    table : MaterialTable
        The material table instance to check.
    required : Sequence[str]
        Names of the material properties the force model relies on.

    Returns
    -------
    None
        Returns silently when no property is missing.

    Raises
    ------
    KeyError
        If one or more names in ``required`` are not attributes of ``table``.
        The message lists every missing name, in the order they were given.
    """
    absent = []
    for prop_name in required:
        if hasattr(table, prop_name):
            continue
        absent.append(prop_name)

    if not absent:
        return

    raise KeyError(
        f"MaterialTable lacks fields {absent}, required by the selected force model."
    )
[docs]
@final
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class System:
"""
Encapsulates the entire simulation configuration.
Notes
-----
- The `System` object is designed to be JIT-compiled for efficient execution.
- The `System` class is a frozen dataclass, meaning its attributes cannot be changed after instantiation.
Example
-------
Creating a basic 2D simulation system:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create a System instance
>>> sim_system = jdem.System.create(
>>> dim=state.dim,
>>> dt=0.001,
>>> integrator_type="euler",
>>> collider_type="naive",
>>> domain_type="free",
>>> force_model_type="spring",
>>> # You can pass keyword arguments to component constructors via '_kw' dicts
>>> domain_kw=dict(box_size=jnp.array([5.0, 5.0]), anchor=jnp.array([0.0, 0.0]))
>>> )
>>>
>>> print(f"System integrator: {sim_system.integrator.__class__.__name__}")
>>> print(f"System force model: {sim_system.force_model.__class__.__name__}")
>>> print(f"Domain box size: {sim_system.domain.box_size}")
"""
integrator: "Integrator"
"""Instance of :class:`jaxdem.Integrator` that advances the simulation state in time."""
collider: "Collider"
"""Instance of :class:`jaxdem.Collider` that performs contact detection and computes inter-particle forces and potential energies."""
domain: "Domain"
"""Instance of :class:`jaxdem.Domain` that defines the simulation boundaries, displacement rules, and boundary conditions."""
force_model: "ForceModel"
"""Instance of :class:`jaxdem.ForceModel` that defines the physical laws for inter-particle interactions."""
mat_table: "MaterialTable"
"""Instance of :class:`jaxdem.MaterialTable` holding material properties and pairwise interaction parameters."""
dt: jax.Array
r"""The global simulation time step :math:`\Delta t`."""
time: jax.Array
"""Elapsed simulation time."""
dim: jax.Array
"""Spatial dimension of the system."""
step_count: jax.Array = field(default_factory=lambda: jnp.asarray(0, dtype=int))
"""Number of integration steps that have been performed."""
[docs]
@staticmethod
def create(
dim: int,
*,
dt: float = 0.005,
time: float = 0.0,
integrator_type: str = "euler",
collider_type: str = "naive",
domain_type: str = "free",
force_model_type: str = "spring",
mat_table: Optional["MaterialTable"] = None,
integrator_kw: Optional[Dict[str, Any]] = None,
collider_kw: Optional[Dict[str, Any]] = None,
domain_kw: Optional[Dict[str, Any]] = None,
force_model_kw: Optional[Dict[str, Any]] = None,
) -> "System":
"""
Factory method to create a :class:`System` instance with specified components.
Parameters
----------
dim : int
The spatial dimension of the simulation (2 or 3).
dt : float, optional
The global simulation time step.
integrator_type : str, optional
The registered type string for the :class:`jaxdem.Integrator` to use.
collider_type : str, optional
The registered type string for the :class:`jaxdem.Collider` to use.
domain_type : str, optional
The registered type string for the :class:`jaxdem.Domain` to use.
force_model_type : str, optional
The registered type string for the :class:`jaxdem.ForceModel` to use.
mat_table : MaterialTable or None, optional
An optional pre-configured :class:`jaxdem.MaterialTable`. If `None`, a
default `jaxdem.MaterialTable` will be created with one generic elastic material and "harmonic" `jaxdem.MaterialMatchmaker`.
integrator_kw : Dict[str, Any] or None, optional
Keyword arguments to pass to the constructor of the selected `Integrator` type.
collider_kw : Dict[str, Any] or None, optional
Keyword arguments to pass to the constructor of the selected `Collider` type.
domain_kw : Dict[str, Any] or None, optional
Keyword arguments to pass to the constructor of the selected `Domain` type.
force_model_kw : Dict[str, Any] or None, optional
Keyword arguments to pass to the constructor of the selected `ForceModel` type.
Returns
-------
System
A fully configured `System` instance ready for simulation.
Raises
------
KeyError
If a specified `*_type` is not registered in its respective factory,
or if the `mat_table` is missing properties required by the `force_model`.
TypeError
If constructor keyword arguments are invalid for any component.
AssertionError
If the `domain_kw` 'box_size' or 'anchor' shapes do not match the `dim`.
Example
-------
Creating a 3D system with reflective boundaries and a custom `dt`:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> system_reflect = jdem.System.create(
>>> dim=3,
>>> dt=0.0005,
>>> domain_type="reflect",
>>> domain_kw=dict(box_size=jnp.array([20.0, 20.0, 20.0]), anchor=jnp.array([-10.0, -10.0, -10.0])),
>>> force_model_type="spring",
>>> )
>>> print(f"System dt: {system_reflect.dt}")
>>> print(f"Domain type: {system_reflect.domain.__class__.__name__}")
Creating a system with a pre-defined MaterialTable:
>>> custom_mat_kw = dict(young=2.0e5, poisson=0.25)
>>> custom_material = jdem.Material.create("custom_mat", **custom_mat_kw)
>>> custom_mat_table = jdem.MaterialTable.from_materials(
... [custom_material], matcher=jdem.MaterialMatchmaker.create("linear")
... )
>>>
>>> system_custom_mat = jdem.System.create(
... dim=2,
... mat_table=custom_mat_table,
... force_model_type="spring"
... )
"""
integrator_kw = {} if integrator_kw is None else dict(integrator_kw)
collider_kw = {} if collider_kw is None else dict(collider_kw)
force_model_kw = {} if force_model_kw is None else dict(force_model_kw)
if domain_kw is None:
domain_kw = {
"box_size": jnp.ones(dim, dtype=float),
"anchor": jnp.zeros(dim, dtype=float),
}
domain_kw = dict(domain_kw)
if mat_table is None:
matcher = MaterialMatchmaker.create("harmonic")
mat_table = MaterialTable.from_materials(
[Material.create("elastic", young=1.0e4, poisson=0.3)], matcher=matcher
)
force_model = ForceModel.create(force_model_type, **force_model_kw)
_check_material_table(mat_table, force_model.required_material_properties)
return System(
integrator=Integrator.create(integrator_type, **integrator_kw),
collider=Collider.create(collider_type, **collider_kw),
domain=Domain.create(domain_type, dim=dim, **domain_kw),
force_model=force_model,
mat_table=mat_table,
dim=jnp.asarray(dim, dtype=int),
dt=jnp.asarray(dt, dtype=float),
time=jnp.asarray(time, dtype=float),
)
@staticmethod
@partial(jax.jit, static_argnames=("n"), donate_argnames=("state", "system"))
def _steps(state: "State", system: "System", n: int) -> Tuple["State", "System"]:
"""
Internal method to advance the simulation state by multiple steps using `jax.lax.scan`.
This function is an optimized JIT-compiled loop for performing `n` integration
steps without re-entering Python between steps.
Parameters
----------
state : State
The current state of the simulation.
system : System
The current system configuration.
n : int
The number of integration steps to perform. This argument must be static.
Returns
-------
Tuple[State, System]
A tuple containing the final `State` and `System` after `n` integration steps.
"""
@partial(jax.jit, donate_argnames=("carry"))
def body(carry, _):
st, sys = carry
return sys.integrator.step(st, sys), None
(state, system), _ = jax.lax.scan(body, (state, system), xs=None, length=n)
return state, system
[docs]
@staticmethod
@partial(
jax.jit,
static_argnames=("n", "stride", "batched"),
donate_argnames=("state", "system"),
)
def trajectory_rollout(
state: "State",
system: "System",
*,
n: int,
stride: int = 1,
batched: bool = False,
) -> Tuple["State", "System", Tuple["State", "System"]]:
"""
Rolls the system forward for a specified number of frames, collecting a trajectory.
This method performs `n * stride` total simulation steps, but it saves
the `State` and `System` every `stride` steps, returning a trajectory
of `n` snapshots. This is highly efficient for data collection within JAX
as it leverages `jax.lax.scan`.
Parameters
----------
state : State
The initial state of the simulation.
system : System
The initial system configuration.
n : int
The number of frames (snapshots) to collect in the trajectory. This argument must be static.
stride : int, optional
The number of integration steps to advance between each collected frame. Defaults to 1, meaning every step is collected. This argument must be static.
Returns
-------
Tuple[State, System, Tuple[State, System]]
A tuple containing:
- `final_state`:
The `State` object at the end of the rollout.
- `final_system`:
The `System` object at the end of the rollout.
- `trajectory`:
A tuple of `State` and `System` objects, where each leaf array has an additional leading axis of size `n` representing the trajectory. The `State` and `System` objects within `trajectory` are structured as if created by :meth:`jaxdem.State.stack` and a similar :class:`jaxdem.System` stack.
Raises
------
ValueError
If `n` or `stride` are non-positive. (Implicit by `jax.lax.scan` length)
Example
-------
>>> import jaxdem as jdem
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(dim=2, dt=0.01)
>>>
>>> # Roll out for 10 frames, saving every 5 steps
>>> final_state, final_system, traj = jdem.System.trajectory_rollout(
... state, system, n=10, stride=5
... )
>>>
>>> print(f"Total simulation steps performed: {10 * 5}")
>>> print(f"Trajectory length (number of frames): {traj[0].pos.shape[0]}") # traj[0] is the state part of the trajectory
>>> print(f"First frame position:\n{traj[0].pos[0]}")
>>> print(f"Last frame position:\n{traj[0].pos[-1]}")
>>> print(f"Final state position (should match last frame):\n{final_state.pos}")
"""
@partial(jax.jit, donate_argnames=("carry"))
def body(carry, _):
st, sys = carry
carry = sys._steps(st, sys, stride)
return carry, carry
if batched:
body = jax.vmap(body, in_axes=(0, None))
(state, system), traj = jax.lax.scan(body, (state, system), xs=None, length=n)
return state, system, traj
[docs]
@staticmethod
@partial(
jax.jit,
static_argnames=("n", "batched"),
donate_argnames=("state", "system"),
)
def step(
state: "State",
system: "System",
*,
n: int = 1,
batched: bool = False,
) -> Tuple["State", "System"]:
"""
Advances the simulation state by `n` time steps.
This method provides a convenient way to run multiple integration steps.
For a single step (`n=1`), it directly calls the integrator's step method.
For multiple steps (`n > 1`), it uses an optimized internal loop based on
`jax.lax.scan` to maintain JIT-compilation efficiency.
Parameters
----------
state : State
The current state of the simulation.
system : System
The current system configuration.
n : int, optional
The number of integration steps to perform. Defaults to 1. This argument must be static.
Returns
-------
Tuple[State, System]
A tuple containing the final `State` and `System` after `n` integration steps.
Raises
------
ValueError
If `n` is non-positive. (Implicit by `jax.lax.scan` length check).
Example
-------
>>> import jaxdem as jdem
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(dim=2, dt=0.01)
>>>
>>> # Advance by 1 step
>>> state_after_1_step, system_after_1_step = jdem.System.step(state, system)
>>> print("Position after 1 step:", state_after_1_step.pos[0])
>>>
>>> # Advance by 10 steps
>>> state_after_10_steps, system_after_10_steps = jdem.System.step(state, system, n=10)
>>> print("Position after 10 steps:", state_after_10_steps.pos[0])
"""
body = system._steps
if batched:
body = jax.vmap(body, in_axes=(0, 0, None))
return body(state, system, n)