# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Defines the simulation configuration and the tooling for driving the simulation.
"""
import jax
import jax.numpy as jnp
from dataclasses import dataclass, field, replace
from functools import partial
from typing import final, Tuple, Optional, Dict, Any, Sequence
from .integrator import Integrator
from .collider import Collider
from .domain import Domain
from .forces import ForceModel
from .material import MaterialTable, Material
from .materialMatchmaker import MaterialMatchmaker
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .state import State
def _check_material_table(table, required: Sequence[str]):
"""
Checks if the provided MaterialTable contains all required properties for a given force model.
This helper function ensures that all material properties specified by a
:class:`ForceModel` (via :attr:`ForceModel.required_material_properties`)
are present as attributes in the given :class:`MaterialTable`.
Parameters
----------
table : MaterialTable
The material table instance to check.
required : Sequence[str]
A sequence of strings representing the names of material properties that are required by a specific force model.
Raises
------
KeyError
If the `MaterialTable` instance is missing any of the `required` material properties.
"""
missing = [k for k in required if not hasattr(table, k)]
if missing:
raise KeyError(
f"MaterialTable lacks fields {missing}, required by the selected force model."
)
[docs]
@final
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class System:
"""
Encapsulates the entire simulation configuration.
Notes
-----
- The `System` object is designed to be JIT-compiled for efficient execution.
- The `System` class is a frozen dataclass, meaning its attributes cannot be changed after instantiation.
Example
-------
Creating a basic 2D simulation system:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create a System instance
>>> sim_system = jdem.System.create(
>>> dim=state.dim,
>>> dt=0.001,
>>> integrator_type="euler",
>>> collider_type="naive",
>>> domain_type="free",
>>> force_model_type="spring",
>>> # You can pass keyword arguments to component constructors via '_kw' dicts
>>> domain_kw=dict(box_size=jnp.array([5.0, 5.0]), anchor=jnp.array([0.0, 0.0]))
>>> )
>>>
>>> print(f"System integrator: {sim_system.integrator.__class__.__name__}")
>>> print(f"System force model: {sim_system.force_model.__class__.__name__}")
>>> print(f"Domain box size: {sim_system.domain.box_size}")
"""
integrator: "Integrator"
"""
Instance of :class:`Integrator` that defines how the simulation state is advanced in time
"""
collider: "Collider"
"""
Instance of :class:`Collider` that performs contact detection and computes inter-particle forces and potential energies.
"""
domain: "Domain"
"""
Instance of :class:`Domain` that defines the simulation boundaries, how displacement vectors are calculated, and how boundary conditions are applied
"""
force_model: "ForceModel"
"""
Instance of :class:`ForceModel` that defines the specific physical laws for inter-particle interactions.
"""
mat_table: "MaterialTable"
"""
Instance of :class:`MaterialTable` holding material properties and their effective interaction parameters for pairs of materials.
"""
dt: jax.Array
"""
The global simulation time step :math:`\\Delta t`.
"""
step_count: jax.Array = field(default_factory=lambda: jnp.asarray(0, dtype=int))
"""
Counts the number of steps that have been performed.
"""
[docs]
@staticmethod
def create(
dim: int,
*,
dt: float = 0.01,
integrator_type: str = "euler",
collider_type: str = "naive",
domain_type: str = "free",
force_model_type: str = "spring",
mat_table: Optional["MaterialTable"] = None,
integrator_kw: Optional[Dict[str, Any]] = None,
collider_kw: Optional[Dict[str, Any]] = None,
domain_kw: Optional[Dict[str, Any]] = None,
force_model_kw: Optional[Dict[str, Any]] = None,
) -> "System":
"""
Factory method to create a :class:`System` instance with specified components.
Parameters
----------
dim : int
The spatial dimension of the simulation (2 or 3).
dt : float, optional
The global simulation time step.
integrator_type : str, optional
The registered type string for the :class:`Integrator` to use.
collider_type : str, optional
The registered type string for the :class:`Collider` to use.
domain_type : str, optional
The registered type string for the :class:`Domain` to use.
force_model_type : str, optional
The registered type string for the :class:`ForceModel` to use.
mat_table : MaterialTable or None, optional
An optional pre-configured :class:`MaterialTable`. If `None`, a
default `MaterialTable` will be created with one generic elastic material and "harmonic" `MaterialMatchmaker`.
integrator_kw : Dict[str, Any] or None, optional
Keyword arguments to pass to the constructor of the selected `Integrator` type.
collider_kw : Dict[str, Any] or None, optional
Keyword arguments to pass to the constructor of the selected `Collider` type.
domain_kw : Dict[str, Any] or None, optional
Keyword arguments to pass to the constructor of the selected `Domain` type.
force_model_kw : Dict[str, Any] or None, optional
Keyword arguments to pass to the constructor of the selected `ForceModel` type.
Returns
-------
System
A fully configured `System` instance ready for simulation.
Raises
------
KeyError
If a specified `*_type` is not registered in its respective factory,
or if the `mat_table` is missing properties required by the `force_model`.
TypeError
If constructor keyword arguments are invalid for any component.
AssertionError
If the `domain_kw` 'box_size' or 'anchor' shapes do not match the `dim`.
Example
-------
Creating a 3D system with reflective boundaries and a custom `dt`:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> system_reflect = jdem.System.create(
>>> dim=3,
>>> dt=0.0005,
>>> domain_type="reflect",
>>> domain_kw=dict(box_size=jnp.array([20.0, 20.0, 20.0]), anchor=jnp.array([-10.0, -10.0, -10.0])),
>>> force_model_type="spring",
>>> )
>>> print(f"System dt: {system_reflect.dt}")
>>> print(f"Domain type: {system_reflect.domain.__class__.__name__}")
Creating a system with a pre-defined MaterialTable:
>>> custom_material = Material.create("custom_mat", **custom_mat_kw)
>>> custom_mat_table = MaterialTable.from_materials([custom_material], matcher=jdem.MaterialMatchmaker.create("linear"))
>>>
>>> system_custom_mat = jdem.System.create(
>>> dim=2,
>>> mat_table=custom_mat_table,
>>> force_model_type="spring"
>>> )
"""
integrator_kw = {} if integrator_kw is None else dict(integrator_kw)
collider_kw = {} if collider_kw is None else dict(collider_kw)
force_model_kw = {} if force_model_kw is None else dict(force_model_kw)
if domain_kw is None:
domain_kw = {
"box_size": jnp.ones(dim, dtype=float),
"anchor": jnp.zeros(dim, dtype=float),
}
else:
domain_kw = dict(domain_kw)
missing = [k for k in ("box_size", "anchor") if k not in domain_kw]
for miss in missing:
domain_kw[miss] = {
"box_size": jnp.ones(dim, dtype=float),
"anchor": jnp.zeros(dim, dtype=float),
}[miss]
if mat_table is None:
matcher = MaterialMatchmaker.create("linear")
mat_table = MaterialTable.from_materials(
[Material.create("elastic", young=1.0e4, poisson=0.3)], matcher=matcher
)
force_model = ForceModel.create(force_model_type, **force_model_kw)
domain_kw["box_size"] = jnp.asarray(domain_kw["box_size"], dtype=float)
domain_kw["anchor"] = jnp.asarray(domain_kw["anchor"], dtype=float)
assert domain_kw["box_size"].shape == (
dim,
), f"box_size={domain_kw['box_size'].shape} shape must match dimension={(dim,)}"
assert domain_kw["anchor"].shape == (
dim,
), f"anchor={domain_kw['anchor'].shape} shape must match dimension={(dim,)}"
_check_material_table(mat_table, force_model.required_material_properties)
return System(
integrator=Integrator.create(integrator_type, **integrator_kw),
collider=Collider.create(collider_type, **collider_kw),
domain=Domain.create(domain_type, **domain_kw),
force_model=force_model,
mat_table=mat_table,
dt=jnp.asarray(dt, dtype=float),
)
@staticmethod
@partial(jax.jit, static_argnames=("n"))
def _steps(state: "State", system: "System", n: int) -> Tuple["State", "System"]:
"""
Internal method to advance the simulation state by multiple steps using `jax.lax.scan`.
This function is an optimized JIT-compiled loop for performing `n` integration
steps without re-entering Python between steps.
Parameters
----------
state : State
The current state of the simulation.
system : System
The current system configuration.
n : int
The number of integration steps to perform. This argument must be static.
Returns
-------
Tuple[State, System]
A tuple containing the final `State` and `System` after `n` integration steps.
"""
def body(carry, _):
st, sys = carry
st, sys = sys.integrator.step(st, sys)
return (st, sys), None
(final_state, final_system), _ = jax.lax.scan(
body, (state, system), xs=None, length=n
)
return final_state, final_system
[docs]
@staticmethod
@partial(jax.jit, static_argnames=("n", "stride"))
def trajectory_rollout(
state: "State", system: "System", n: int, stride: int = 1
) -> Tuple["State", "System", Tuple["State", "System"]]:
"""
Rolls the system forward for a specified number of frames, collecting a trajectory.
This method performs `n * stride` total simulation steps, but it saves
the `State` and `System` every `stride` steps, returning a trajectory
of `n` snapshots. This is highly efficient for data collection within JAX
as it leverages `jax.lax.scan`.
Parameters
----------
state : State
The initial state of the simulation.
system : System
The initial system configuration.
n : int
The number of frames (snapshots) to collect in the trajectory. This argument must be static.
stride : int, optional
The number of integration steps to advance between each collected frame. Defaults to 1, meaning every step is collected. This argument must be static.
Returns
-------
Tuple[State, System, Tuple[State, System]]
A tuple containing:
- `final_state`:
The `State` object at the end of the rollout.
- `final_system`:
The `System` object at the end of the rollout.
- `trajectory`:
A tuple of `State` and `System` objects, where each leaf array has an additional leading axis of size `n` representing the trajectory. The `State` and `System` objects within `trajectory` are structured as if created by :meth:`State.stack` and a similar `System` stack.
Raises
------
ValueError
If `n` or `stride` are non-positive. (Implicit by `jax.lax.scan` length)
Example
-------
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1,1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(dim=2, dt=0.01, state=state)
>>>
>>> # Rollout for 10 frames, saving every 5 steps
>>> final_state, final_system, traj = system.trajectory_rollout(
>>> state, system, n=10, stride=5
>>> )
>>>
>>> print(f"Total simulation steps performed: {10 * 5}")
>>> print(f"Trajectory length (number of frames): {traj[0].pos.shape[0]}") # traj[0] is the state part of trajectory
>>> print(f"First frame position:\\n{traj[0].pos[0]}")
>>> print(f"Last frame position:\\n{traj[0].pos[-1]}")
>>> print(f"Final state position (should match last frame):\\n{final_state.pos}")
"""
def body(carry, _):
st, sys = carry
st, sys = sys.step(st, sys, stride)
return (st, sys), (st, sys)
(final_state, final_system), traj = jax.lax.scan(
body, (state, system), xs=None, length=n
)
return final_state, final_system, traj
[docs]
@staticmethod
@partial(jax.jit, static_argnames=("n"))
def step(state: "State", system: "System", n: int = 1) -> Tuple["State", "System"]:
"""
Advances the simulation state by `n` time steps.
This method provides a convenient way to run multiple integration steps.
For a single step (`n=1`), it directly calls the integrator's step method.
For multiple steps (`n > 1`), it uses an optimized internal loop based on
`jax.lax.scan` to maintain JIT-compilation efficiency.
Parameters
----------
state : State
The current state of the simulation.
system : System
The current system configuration.
n : int, optional
The number of integration steps to perform. Defaults to 1. This argument must be static.
Returns
-------
Tuple[State, System]
A tuple containing the final `State` and `System` after `n` integration steps.
Raises
------
ValueError
If `n` is non-positive. (Implicit by `jax.lax.scan` length check).
Example
-------
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1,1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(dim=2, dt=0.01, state=state)
>>>
>>> # Advance by 1 step
>>> state_after_1_step, system_after_1_step = system.step(system.state, system)
>>> print("Position after 1 step:", state_after_1_step.pos[0])
>>>
>>> # Advance by 10 steps
>>> state_after_10_steps, system_after_10_steps = system.step(system.state, system, n=10)
>>> print("Position after 10 steps:", state_after_10_steps.pos[0])
"""
system = replace(system, step_count=system.step_count + n)
return (
system.integrator.step(state, system)
if n == 1
else system._steps(state, system, n)
)