jaxdem.system
Defines the simulation configuration and the tooling for driving the simulation.
Classes
System | Encapsulates the entire simulation configuration.
- final class jaxdem.system.System(linear_integrator: LinearIntegrator | LinearMinimizer, rotation_integrator: RotationIntegrator | RotationMinimizer, collider: Collider, domain: Domain, force_manager: ForceManager, force_model: ForceModel, mat_table: MaterialTable, dt: jax.Array, time: jax.Array, dim: jax.Array, step_count: jax.Array, key: jax.Array)
Bases: object
Encapsulates the entire simulation configuration.
Notes
The System dataclass is designed to be JIT-compiled with jax.jit() for efficient execution. For best performance, every field should remain a JAX array.
Example
Creating a basic 2D simulation system:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # A small 2D state to size the system (same helper as in the trajectory_rollout example)
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>>
>>> # Create a System instance
>>> sim_system = jdem.System.create(
...     state_shape=state.shape,
...     dt=0.001,
...     linear_integrator_type="euler",
...     rotation_integrator_type="spiral",
...     collider_type="naive",
...     domain_type="free",
...     force_model_type="spring",
...     # You can pass keyword arguments to component constructors via '_kw' dicts
...     domain_kw=dict(box_size=jnp.array([5.0, 5.0]), anchor=jnp.array([0.0, 0.0])),
... )
>>>
>>> print(f"System integrator: {sim_system.linear_integrator.__class__.__name__}")
>>> print(f"System force model: {sim_system.force_model.__class__.__name__}")
>>> print(f"Domain box size: {sim_system.domain.box_size}")
- linear_integrator: LinearIntegrator | LinearMinimizer
Instance of jaxdem.LinearIntegrator that advances the simulation linear state in time.
- rotation_integrator: RotationIntegrator | RotationMinimizer
Instance of jaxdem.RotationIntegrator that advances the simulation angular state in time.
- collider: Collider
Instance of jaxdem.Collider that performs contact detection and computes inter-particle forces and potential energies.
- domain: Domain
Instance of jaxdem.Domain that defines the simulation boundaries, displacement rules, and boundary conditions.
- force_manager: ForceManager
Instance of jaxdem.ForceManager that handles per-particle forces, such as external forces, and resets forces.
- force_model: ForceModel
Instance of jaxdem.ForceModel that defines the physical laws for inter-particle interactions.
- mat_table: MaterialTable
Instance of jaxdem.MaterialTable holding material properties and pairwise interaction parameters.
- dt: jax.Array
The global simulation time step \(\Delta t\).
- time: jax.Array
Elapsed simulation time.
- dim: jax.Array
Spatial dimension of the system.
- step_count: jax.Array
Number of integration steps that have been performed.
- key: jax.Array
PRNG key supporting stochastic functionality. Always update it with jax.random.split() before drawing, so that new random numbers are generated.
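The split-before-use pattern recommended for key can be sketched with plain jax.random. This is a minimal illustration, not jaxdem code: the local variable key stands in for system.key, and the names subkey_a, subkey_b, and the noise_* arrays are illustrative only.

```python
import jax
import jax.numpy as jnp

# Stand-in for system.key; a real System carries its own key field.
key = jax.random.PRNGKey(0)

# Split before every draw: keep one half as the new key, consume the other.
key, subkey_a = jax.random.split(key)
noise_a = jax.random.normal(subkey_a, (3,))

key, subkey_b = jax.random.split(key)
noise_b = jax.random.normal(subkey_b, (3,))

# Reusing a consumed subkey reproduces exactly the same numbers,
# which is why the key must be split again before each new draw.
repeat_a = jax.random.normal(subkey_a, (3,))
```

Because JAX keys are plain values with no hidden state, forgetting to split silently repeats the same random stream.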
- static create(state_shape: Tuple[int, ...], *, dt: float = 0.005, time: float = 0.0, linear_integrator_type: str = 'verlet', rotation_integrator_type: str = 'verletspiral', collider_type: str = 'naive', domain_type: str = 'free', force_model_type: str = 'spring', force_manager_kw: Dict[str, Any] | None = None, mat_table: MaterialTable | None = None, linear_integrator_kw: Dict[str, Any] | None = None, rotation_integrator_kw: Dict[str, Any] | None = None, collider_kw: Dict[str, Any] | None = None, domain_kw: Dict[str, Any] | None = None, force_model_kw: Dict[str, Any] | None = None, seed: int = 0, key: Array | None = None) → System
Factory method to create a System instance with specified components.
- Parameters:
state_shape (Tuple) – Shape of the state tensors handled by the simulation. The penultimate dimension corresponds to the number of particles N and the last dimension corresponds to the spatial dimension dim.
dt (float, optional) – The global simulation time step.
linear_integrator_type (str, optional) – The registered type string for the jaxdem.integrators.LinearIntegrator used to evolve translational degrees of freedom.
rotation_integrator_type (str, optional) – The registered type string for the jaxdem.integrators.RotationIntegrator used to evolve angular degrees of freedom.
collider_type (str, optional) – The registered type string for the jaxdem.Collider to use.
domain_type (str, optional) – The registered type string for the jaxdem.Domain to use.
force_model_type (str, optional) – The registered type string for the jaxdem.ForceModel to use.
force_manager_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of ForceManager.
mat_table (MaterialTable or None, optional) – An optional pre-configured jaxdem.MaterialTable. If None, a default jaxdem.MaterialTable will be created with one generic elastic material and “harmonic” jaxdem.MaterialMatchmaker.
linear_integrator_kw (Dict[str, Any] or None, optional) – Keyword arguments forwarded to the constructor of the selected LinearIntegrator type.
rotation_integrator_kw (Dict[str, Any] or None, optional) – Keyword arguments forwarded to the constructor of the selected RotationIntegrator type.
collider_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Collider type.
domain_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Domain type.
force_model_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected ForceModel type.
seed (int, optional) – Integer seed used for random number generation. Defaults to 0.
key (jax.Array, optional) – Key used for JAX random number generation. Defaults to None, shadowed by seed.
- Returns:
A fully configured System instance ready for simulation.
- Return type:
System
- Raises:
KeyError – If a specified *_type is not registered in its respective factory, or if the mat_table is missing properties required by the force_model.
TypeError – If constructor keyword arguments are invalid for any component.
AssertionError – If the domain_kw ‘box_size’ or ‘anchor’ shapes do not match the dim.
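The state_shape convention and the box_size/anchor shape requirement described above can be sketched in plain Python. This is only an illustration of the documented rules, not jaxdem's implementation: the helper names infer_n_and_dim and check_box_shapes are hypothetical.

```python
from typing import Sequence, Tuple


def infer_n_and_dim(state_shape: Tuple[int, ...]) -> Tuple[int, int]:
    """Hypothetical helper: read N and dim off the last two axes."""
    if len(state_shape) < 2:
        raise ValueError("state_shape needs at least two dimensions")
    return state_shape[-2], state_shape[-1]


def check_box_shapes(
    box_size: Sequence[float], anchor: Sequence[float], dim: int
) -> None:
    """Hypothetical check mirroring the documented AssertionError."""
    assert len(box_size) == dim, "box_size must have one entry per dimension"
    assert len(anchor) == dim, "anchor must have one entry per dimension"


# Single snapshot: 100 particles in 3D.
n, dim = infer_n_and_dim((100, 3))

# Leading batch axes are ignored: 8 batches of 100 particles in 2D.
n_batched, dim_batched = infer_n_and_dim((8, 100, 2))

# A 3D box passes the check for dim == 3.
check_box_shapes([20.0, 20.0, 20.0], [-10.0, -10.0, -10.0], dim)
```

Under this reading, passing a two-element box_size together with a state_shape ending in 3 is the kind of mismatch that would trigger the AssertionError.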
Example
Creating a 3D system with reflective boundaries and a custom dt:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> N = 100  # number of particles (illustrative value)
>>> system_reflect = jdem.System.create(
...     state_shape=(N, 3),
...     dt=0.0005,
...     domain_type="reflect",
...     domain_kw=dict(box_size=jnp.array([20.0, 20.0, 20.0]), anchor=jnp.array([-10.0, -10.0, -10.0])),
...     force_model_type="spring",
... )
>>> print(f"System dt: {system_reflect.dt}")
>>> print(f"Domain type: {system_reflect.domain.__class__.__name__}")
Creating a system with a pre-defined MaterialTable:
>>> custom_mat_kw = dict(young=2.0e5, poisson=0.25)
>>> custom_material = jdem.Material.create("custom_mat", **custom_mat_kw)
>>> custom_mat_table = jdem.MaterialTable.from_materials(
...     [custom_material], matcher=jdem.MaterialMatchmaker.create("linear")
... )
>>>
>>> system_custom_mat = jdem.System.create(
...     state_shape=(N, 2),
...     mat_table=custom_mat_table,
...     force_model_type="spring"
... )
- static trajectory_rollout(state: State, system: System, *, n: Optional[int] = None, stride: int = 1, strides: Optional[jax.Array] = None, save_fn: Callable[[State, System], Any] = <function _save_state_system>, unroll: int = 2) → Tuple[State, System, Any]
Roll the system forward while collecting saved outputs at each frame.
The rollout always stores one output per frame via save_fn(state, system). The output of save_fn must be a pytree. Frame spacing can be either:
- constant (stride), or
- variable (strides jax.Array).
- Parameters:
state (State) – Initial state.
system (System) – Initial system configuration.
n (int, optional) – Number of saved frames. Required when strides is None. Ignored when strides is provided.
stride (int, optional) – Constant number of integration steps between consecutive saves. Used only when strides is None. Defaults to 1.
strides (jax.Array, optional) – Integer 1D array of per-frame integration strides. When provided, this overrides stride, and n is inferred from len(strides).
save_fn (Callable[[State, System], Any], optional) – Function called after each saved frame. Its return pytree is stacked along axis 0 across frames. Defaults to returning (state, system).
unroll (int, optional) – Unroll factor passed to the outer jax.lax.scan. Defaults to 2.
- Returns:
(final_state, final_system, trajectory_like) where trajectory_like is the stacked output of save_fn.
- Return type:
Tuple[State, System, Any]
- Raises:
ValueError – If n is missing while strides is None, or if strides is not 1D.
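How n, stride, and strides interact can be sketched in plain Python. The helper frame_step_counts below is hypothetical and is not part of jaxdem. It assumes each frame is saved after its stride's worth of integration steps has been taken, which the text above suggests but does not state outright.

```python
from typing import List, Optional, Sequence


def frame_step_counts(
    n: Optional[int] = None,
    stride: int = 1,
    strides: Optional[Sequence[int]] = None,
) -> List[int]:
    """Hypothetical helper: cumulative integration steps at each saved frame."""
    if strides is None:
        if n is None:
            raise ValueError("n is required when strides is None")
        # Constant spacing: n frames, each `stride` steps apart.
        strides = [stride] * n
    # When strides is given, n and stride are ignored and the
    # number of frames is len(strides).
    counts: List[int] = []
    total = 0
    for s in strides:
        total += s
        counts.append(total)
    return counts


constant = frame_step_counts(n=4, stride=5)
variable = frame_step_counts(strides=[1, 2, 4, 8])
```

Under that assumption, constant is [5, 10, 15, 20] and variable is [1, 3, 7, 15], so geometric strides give dense sampling early and sparse sampling late.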
Example
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(state_shape=state.shape, dt=0.01)
>>>
>>> # Constant stride: n is required
>>> final_state, final_system, traj = jdem.System.trajectory_rollout(
...     state, system, n=10, stride=5
... )
>>>
>>> # Variable strides: n inferred from len(strides)
>>> deltas = jnp.array([1, 2, 4, 8])
>>> final_state, final_system, traj = jdem.System.trajectory_rollout(
...     state, system, strides=deltas
... )
- static step(state: State, system: System, *, n: int | jax.Array = 1) → Tuple[State, System]
Advance the simulation by n integration steps.
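- Parameters:
state (State) – Current state.
system (System) – Current system configuration.
n (int or jax.Array, optional) – Number of integration steps to perform. Defaults to 1.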
- Returns:
(final_state, final_system) after n steps.
- Return type:
Tuple[State, System]
Example
>>> # Advance by 10 steps
>>> state_after_10_steps, system_after_10_steps = jdem.System.step(state, system, n=10)
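The functional pattern of returning updated values instead of mutating in place can be sketched with a toy integrator in plain JAX. This is not how jaxdem implements step. The names toy_step and advance are hypothetical, and the carry tuple (pos, vel, time, step_count) only loosely mirrors the State and System fields.

```python
import jax
import jax.numpy as jnp


def toy_step(carry, dt):
    """One explicit-Euler update of a toy (pos, vel, time, step_count) carry."""
    pos, vel, time, step_count = carry
    return pos + dt * vel, vel, time + dt, step_count + 1


def advance(carry, dt, n):
    """Apply toy_step n times and return the updated carry."""
    return jax.lax.fori_loop(0, n, lambda _, c: toy_step(c, dt), carry)


carry0 = (
    jnp.zeros((2, 2), dtype=jnp.float32),  # positions of 2 particles in 2D
    jnp.ones((2, 2), dtype=jnp.float32),   # constant unit velocities
    jnp.zeros((), dtype=jnp.float32),      # elapsed time
    jnp.zeros((), dtype=jnp.int32),        # step counter
)

# The inputs are left untouched; the results are new arrays.
pos, vel, time, step_count = advance(carry0, 0.5, 10)
```

As with System.step, the caller rebinds the returned values; the original carry0 still holds the initial condition afterwards.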
- static stack(systems: Sequence[System]) → System
Stacks a sequence of System snapshots into a trajectory or batch along a new leading axis (axis 0).
This method is useful for collecting simulation snapshots over time into a single System object whose leading dimension represents time, or for preparing a batched system.
- Parameters:
systems (Sequence[System]) – A sequence (e.g., list, tuple) of System instances to be stacked.
- Returns:
A new System instance where each attribute is a JAX array with an additional leading dimension representing the stacked trajectory. For example, if input pos was (N, dim), output pos will be (T, N, dim).
- Return type:
System
- static unstack(system: System) → list[System]
Split a stacked/batched System along the leading axis into a Python list.
This is the convenient inverse of System.stack(): if stacked = System.stack([sys0, sys1, …]), then System.unstack(stacked) returns [sys0, sys1, …].
Notes
The split is performed along axis 0 (the leading axis).
A single snapshot System cannot be unstacked with this method.
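The stack/unstack round trip can be sketched on a generic pytree with jax.tree_util. This is a sketch of the idea only, not jaxdem's implementation: plain dicts of arrays stand in for System snapshots, and the field names time and pos are illustrative.

```python
import jax
import jax.numpy as jnp

# Stand-in snapshots: dicts of arrays instead of System objects.
snapshots = [
    {"time": jnp.asarray(0.01 * i), "pos": jnp.full((4, 2), float(i))}
    for i in range(3)
]

# Stack: give every leaf a new leading axis of length len(snapshots).
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *snapshots)

# Unstack: index every leaf along axis 0 to recover the list.
n_frames = stacked["pos"].shape[0]
unstacked = [
    jax.tree_util.tree_map(lambda leaf: leaf[i], stacked)
    for i in range(n_frames)
]
```

Here stacked["pos"] has shape (3, 4, 2) and each entry of unstacked has the original (4, 2) shape. A leaf without a leading axis, such as the scalar time of a single snapshot, has nothing to split along, which is consistent with the note that a single snapshot cannot be unstacked.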