jaxdem.system
Defines the simulation configuration and the tooling for driving the simulation.
Classes
System: Encapsulates the entire simulation configuration.
- final class jaxdem.system.System(linear_integrator: LinearIntegrator, rotation_integrator: RotationIntegrator, collider: Collider, domain: Domain, force_manager: ForceManager, force_model: ForceModel, mat_table: MaterialTable, dt: Array, time: Array, dim: Array, step_count: Array)
Bases: object
Encapsulates the entire simulation configuration.
Notes
The System dataclass is compatible with jax.jit() and is designed to be used inside JIT-compiled functions. For best performance, keep its array fields (such as dt, time, dim and step_count) as JAX arrays.
Example
Creating a basic 2D simulation system:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> # Create an initial state: a small 2D grid of particles
>>> state = jdem.utils.grid_state(n_per_axis=(2, 2), spacing=1.0, radius=0.1)
>>>
>>> # Create a System instance
>>> sim_system = jdem.System.create(
...     state_shape=state.shape,
...     dt=0.001,
...     linear_integrator_type="euler",
...     rotation_integrator_type="spiral",
...     collider_type="naive",
...     domain_type="free",
...     force_model_type="spring",
...     # Keyword arguments are forwarded to component constructors via the '*_kw' dicts
...     domain_kw=dict(box_size=jnp.array([5.0, 5.0]), anchor=jnp.array([0.0, 0.0])),
... )
>>>
>>> print(f"System integrator: {sim_system.linear_integrator.__class__.__name__}")
>>> print(f"System force model: {sim_system.force_model.__class__.__name__}")
>>> print(f"Domain box size: {sim_system.domain.box_size}")
- linear_integrator: LinearIntegrator
Instance of jaxdem.LinearIntegrator that advances the simulation linear state in time.
- rotation_integrator: RotationIntegrator
Instance of jaxdem.RotationIntegrator that advances the simulation angular state in time.
- collider: Collider
Instance of jaxdem.Collider that performs contact detection and computes inter-particle forces and potential energies.
- domain: Domain
Instance of jaxdem.Domain that defines the simulation boundaries, displacement rules, and boundary conditions.
- force_manager: ForceManager
Instance of jaxdem.ForceManager that handles per-particle forces, such as external forces, and resets forces.
- force_model: ForceModel
Instance of jaxdem.ForceModel that defines the physical laws for inter-particle interactions.
- mat_table: MaterialTable
Instance of jaxdem.MaterialTable holding material properties and pairwise interaction parameters.
- dt: Array
The global simulation time step \(\Delta t\).
- time: Array
Elapsed simulation time.
- dim: Array
Spatial dimension of the system.
- step_count: Array
Number of integration steps that have been performed.
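The dt, time and step_count fields are related by simple bookkeeping. The following framework-free sketch assumes that each integration step adds dt to time and increments step_count by one. The section above does not state this rule explicitly. MiniClock and advance are hypothetical names used only for illustration and are not part of jaxdem.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class MiniClock:
    """Hypothetical stand-in for the time-keeping fields of System."""

    dt: np.ndarray
    time: np.ndarray
    step_count: np.ndarray


def advance(clock: MiniClock) -> MiniClock:
    """Assumed bookkeeping for one integration step (not jaxdem's code)."""
    # Return a new object instead of mutating the old one, in the
    # functional style used with jax.jit.
    return replace(
        clock,
        time=clock.time + clock.dt,
        step_count=clock.step_count + 1,
    )


clock = MiniClock(dt=np.asarray(0.01), time=np.asarray(0.0), step_count=np.asarray(0))
for _ in range(5):
    clock = advance(clock)

# Under this assumption, time tracks step_count * dt (up to rounding).
print(int(clock.step_count), float(clock.time))
```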
- static create(state_shape: Tuple[int], *, dt: float = 0.005, time: float = 0.0, linear_integrator_type: str = 'verlet', rotation_integrator_type: str = 'verletspiral', collider_type: str = 'naive', domain_type: str = 'free', force_model_type: str = 'spring', force_manager_kw: Dict[str, Any] | None = None, mat_table: MaterialTable | None = None, linear_integrator_kw: Dict[str, Any] | None = None, rotation_integrator_kw: Dict[str, Any] | None = None, collider_kw: Dict[str, Any] | None = None, domain_kw: Dict[str, Any] | None = None, force_model_kw: Dict[str, Any] | None = None) → System
Factory method to create a System instance with specified components.
- Parameters:
state_shape (Tuple) – Shape of the state tensors handled by the simulation. The penultimate dimension corresponds to the number of particles N and the last dimension corresponds to the spatial dimension dim.
dt (float, optional) – The global simulation time step.
time (float, optional) – The initial simulation time.
linear_integrator_type (str, optional) – The registered type string for the jaxdem.integrators.LinearIntegrator used to evolve translational degrees of freedom.
rotation_integrator_type (str, optional) – The registered type string for the jaxdem.integrators.RotationIntegrator used to evolve angular degrees of freedom.
collider_type (str, optional) – The registered type string for the jaxdem.Collider to use.
domain_type (str, optional) – The registered type string for the jaxdem.Domain to use.
force_model_type (str, optional) – The registered type string for the jaxdem.ForceModel to use.
force_manager_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of ForceManager.
mat_table (MaterialTable or None, optional) – An optional pre-configured jaxdem.MaterialTable. If None, a default jaxdem.MaterialTable is created with one generic elastic material and the “harmonic” jaxdem.MaterialMatchmaker.
linear_integrator_kw (Dict[str, Any] or None, optional) – Keyword arguments forwarded to the constructor of the selected LinearIntegrator type.
rotation_integrator_kw (Dict[str, Any] or None, optional) – Keyword arguments forwarded to the constructor of the selected RotationIntegrator type.
collider_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Collider type.
domain_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected Domain type.
force_model_kw (Dict[str, Any] or None, optional) – Keyword arguments to pass to the constructor of the selected ForceModel type.
- Returns:
A fully configured System instance ready for simulation.
- Return type:
System
- Raises:
KeyError – If a specified *_type is not registered in its respective factory, or if the mat_table is missing properties required by the force_model.
TypeError – If constructor keyword arguments are invalid for any component.
AssertionError – If the shapes of the ‘box_size’ or ‘anchor’ entries of domain_kw do not match dim.
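The factory behaviour described above can be sketched in a few lines. A type string is looked up in a registry, and an unknown string raises KeyError. The matching *_kw dict is forwarded to the constructor, and an invalid keyword raises TypeError. The spatial dimension is read from the last axis of state_shape and checked against the box_size and anchor shapes. DOMAIN_REGISTRY, FreeDomain, ReflectDomain and build_domain are hypothetical names chosen for illustration. They are not jaxdem's internals, which may be organised differently.

```python
from typing import Any, Dict, Optional, Tuple

import numpy as np


class FreeDomain:
    """Hypothetical unbounded domain component."""

    def __init__(self, box_size=None, anchor=None):
        self.box_size = box_size
        self.anchor = anchor


class ReflectDomain(FreeDomain):
    """Hypothetical domain with reflective walls."""


# Hypothetical registry mapping type strings to component classes.
DOMAIN_REGISTRY: Dict[str, type] = {"free": FreeDomain, "reflect": ReflectDomain}


def build_domain(
    state_shape: Tuple[int, ...],
    domain_type: str = "free",
    domain_kw: Optional[Dict[str, Any]] = None,
) -> FreeDomain:
    dim = state_shape[-1]  # last axis of state_shape is the spatial dimension
    kw = dict(domain_kw or {})
    # Shape check in the spirit of the documented AssertionError.
    for key in ("box_size", "anchor"):
        if key in kw:
            assert np.shape(kw[key]) == (dim,), f"{key} must have shape ({dim},)"
    cls = DOMAIN_REGISTRY[domain_type]  # unregistered type string -> KeyError
    return cls(**kw)  # unexpected keyword argument -> TypeError


domain = build_domain(
    (100, 3),
    "reflect",
    dict(box_size=np.full(3, 20.0), anchor=np.full(3, -10.0)),
)
print(type(domain).__name__)
```

Keeping construction behind string keys lets a configuration be expressed as plain data (type strings plus keyword dicts). This makes it easy to serialise or sweep over.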
Example
Creating a 3D system with reflective boundaries and a custom dt:
>>> import jaxdem as jdem
>>> import jax.numpy as jnp
>>>
>>> N = 100  # number of particles
>>> system_reflect = jdem.System.create(
...     state_shape=(N, 3),
...     dt=0.0005,
...     domain_type="reflect",
...     domain_kw=dict(box_size=jnp.array([20.0, 20.0, 20.0]), anchor=jnp.array([-10.0, -10.0, -10.0])),
...     force_model_type="spring",
... )
>>> print(f"System dt: {system_reflect.dt}")
>>> print(f"Domain type: {system_reflect.domain.__class__.__name__}")
Creating a system with a pre-defined MaterialTable:
>>> custom_mat_kw = dict(young=2.0e5, poisson=0.25)
>>> custom_material = jdem.Material.create("custom_mat", **custom_mat_kw)
>>> custom_mat_table = jdem.MaterialTable.from_materials(
...     [custom_material], matcher=jdem.MaterialMatchmaker.create("linear")
... )
>>>
>>> system_custom_mat = jdem.System.create(
...     state_shape=(N, 2),
...     mat_table=custom_mat_table,
...     force_model_type="spring",
... )
- static trajectory_rollout(state: State, system: System, *, n: int, unroll: int = 2, stride: int = 1) → Tuple['State', 'System', Tuple['State', 'System']]
Rolls the system forward for a specified number of frames, collecting a trajectory.
This method performs n * stride simulation steps in total but saves the State and System only every stride steps, returning a trajectory of n snapshots. Because it is built on jax.lax.scan, the loop body is traced once rather than once per step, which makes it efficient for collecting data within JAX.
- Parameters:
state (State) – The initial state of the simulation.
system (System) – The initial system configuration.
n (int) – The number of frames (snapshots) to collect in the trajectory. This argument must be static.
stride (int, optional) – The number of integration steps to advance between each collected frame. Defaults to 1, meaning every step is collected. This argument must be static.
- Returns:
A tuple containing:
- final_state: The State object at the end of the rollout.
- final_system: The System object at the end of the rollout.
- trajectory: A tuple of State and System objects in which each leaf array has an additional leading axis of size n indexing the frames. The State and System objects within trajectory are structured as if created by jaxdem.State.stack() and jaxdem.System.stack().
- Return type:
Tuple[State, System, Tuple[State, System]]
- Raises:
ValueError – If n or stride is non-positive (implicitly, through the jax.lax.scan length).
Example
>>> import jaxdem as jdem
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(state_shape=state.shape, dt=0.01)
>>>
>>> # Roll out for 10 frames, saving every 5 steps
>>> final_state, final_system, traj = jdem.System.trajectory_rollout(
...     state, system, n=10, stride=5
... )
>>>
>>> print(f"Total simulation steps performed: {10 * 5}")
>>> # traj[0] is the State part of the trajectory
>>> print(f"Trajectory length (number of frames): {traj[0].pos.shape[0]}")
>>> print(f"First frame position: {traj[0].pos[0]}")
>>> print(f"Last frame position: {traj[0].pos[-1]}")
>>> print(f"Final state position (should match last frame): {final_state.pos}")
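The frame/stride structure described above can be pictured as two nested scans. An outer scan of length n records one snapshot per iteration. An inner scan of length stride only advances the state. This is a framework-free illustration, not jaxdem's implementation. scan is a pure-Python stand-in for jax.lax.scan called with xs=None and a fixed length. toy_step is hypothetical constant-velocity dynamics standing in for a real integration step.

```python
import numpy as np


def scan(f, init, length):
    """Pure-Python stand-in for jax.lax.scan(f, init, xs=None, length=length)."""
    carry, ys = init, []
    for _ in range(length):
        carry, y = f(carry, None)
        ys.append(y)
    if ys and ys[0] is None:
        return carry, None  # nothing was recorded
    # Like jax.lax.scan, stack per-iteration outputs along a new leading axis.
    return carry, np.stack(ys)


DT, VEL = 0.01, 2.0


def toy_step(pos: np.ndarray) -> np.ndarray:
    # Hypothetical dynamics: constant-velocity drift.
    return pos + VEL * DT


def rollout(pos0: np.ndarray, n: int, stride: int):
    def advance_stride(pos, _x):
        # Inner scan: `stride` integration steps, no output recorded.
        pos, _ = scan(lambda p, _y: (toy_step(p), None), pos, stride)
        return pos, pos  # outer scan records one frame per stride block

    return scan(advance_stride, pos0, n)


final_pos, traj = rollout(np.zeros(2), n=10, stride=5)
print(traj.shape)  # one row per saved frame
```

Only n frames are stored even though n * stride steps are taken. This keeps trajectory memory proportional to n rather than to the total number of steps.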
- static step(state: State, system: System, *, n: int = 1, unroll: int = 2) → Tuple['State', 'System']
Advances the simulation state by n time steps.
This method provides a convenient way to run multiple integration steps. For a single step (n=1), it directly calls the integrator’s step method. For multiple steps (n > 1), it uses an optimized internal loop based on jax.lax.scan to maintain JIT-compilation efficiency.
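- Parameters:
state (State) – The current state of the simulation.
system (System) – The current system configuration.
n (int, optional) – The number of integration steps to perform. Defaults to 1.
- Returns:
A tuple containing the final State and System after n integration steps.
- Return type:
Tuple[State, System]
- Raises:
ValueError – If n is non-positive (implicitly, through the jax.lax.scan length check).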
Example
>>> import jaxdem as jdem
>>>
>>> state = jdem.utils.grid_state(n_per_axis=(1, 1), spacing=1.0, radius=0.1)
>>> system = jdem.System.create(state.force.shape, dt=0.01)
>>>
>>> # Advance by 1 step
>>> state_after_1_step, system_after_1_step = jdem.System.step(state, system)
>>> print("Position after 1 step:", state_after_1_step.pos[0])
>>>
>>> # Advance by 10 steps
>>> state_after_10_steps, system_after_10_steps = jdem.System.step(state, system, n=10)
>>> print("Position after 10 steps:", state_after_10_steps.pos[0])
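The n == 1 versus n > 1 dispatch described for step can be sketched without JAX. integrator_step is a hypothetical scalar update (exponential decay) standing in for the integrator's step method. The Python for loop marks where the documented method uses jax.lax.scan. The distinction matters under jax.jit, because a Python loop is unrolled at trace time, so the traced program grows with n. A scan traces its body only once.

```python
from typing import Tuple


def integrator_step(state: float, time: float, dt: float) -> Tuple[float, float]:
    """Hypothetical single integration step: exponential decay of a scalar."""
    return state - 0.5 * state * dt, time + dt


def step(state: float, time: float, dt: float, n: int = 1) -> Tuple[float, float]:
    """Sketch of the n == 1 / n > 1 dispatch (not jaxdem's code)."""
    if n < 1:
        # Explicit check in this sketch; the documented method relies on
        # the jax.lax.scan length check instead.
        raise ValueError("n must be positive")
    if n == 1:
        # Single step: call the integrator directly, with no loop machinery.
        return integrator_step(state, time, dt)
    # n > 1: this Python loop stands in for jax.lax.scan.
    for _ in range(n):
        state, time = integrator_step(state, time, dt)
    return state, time


one = step(1.0, 0.0, 0.01)
ten = step(1.0, 0.0, 0.01, n=10)
print(one, ten)
```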
- static stack(systems: Sequence[System]) → System
Stacks a sequence of System snapshots along a new leading axis (axis 0), producing a trajectory or a batch.
This is useful for collecting simulation snapshots over time into a single System object whose leading dimension represents time, or for preparing a batched system.
- Parameters:
systems (Sequence[System]) – A sequence (e.g., list or tuple) of System instances to be stacked.
- Returns:
A new System instance in which every array leaf has an additional leading dimension of size T = len(systems). For example, a field of shape (N, dim) in each input has shape (T, N, dim) in the output.
- Return type:
System
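Stacking a sequence of snapshots applies the same operation to every array field: np.stack along a new axis 0. The following sketch uses a hypothetical MiniSystem dataclass that holds only the time-keeping arrays, so it is not jaxdem's implementation. For a registered JAX pytree, the same effect is commonly obtained with jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *systems).

```python
from dataclasses import dataclass, fields
from typing import Sequence

import numpy as np


@dataclass(frozen=True)
class MiniSystem:
    """Hypothetical stand-in holding only array fields of a System."""

    dt: np.ndarray
    time: np.ndarray
    step_count: np.ndarray


def stack(systems: Sequence[MiniSystem]) -> MiniSystem:
    # Stack each field across snapshots along a new leading axis (axis 0).
    return MiniSystem(
        **{
            f.name: np.stack([getattr(s, f.name) for s in systems])
            for f in fields(MiniSystem)
        }
    )


snapshots = [
    MiniSystem(dt=np.asarray(0.01), time=np.asarray(0.01 * k), step_count=np.asarray(k))
    for k in range(4)
]
batched = stack(snapshots)
print(batched.time.shape)  # scalar fields become shape (T,)
```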