The Simulation System
Now that we know how to use and manipulate the simulation state
(State), it’s time to look at the simulation configuration stored in
System.
A System holds the “static” configuration of a
simulation, such as the domain, integrator settings, and force model. Although
we call it “static”, many fields (e.g., the time step \(\Delta t\),
domain dimensions, boundary conditions) can be changed at runtime, even inside a
JIT-compiled function, because both State and
System are JAX pytrees.
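The pytree point can be checked with a minimal sketch. It uses a hypothetical ToySystem NamedTuple (not jaxdem's System) whose dt field is an ordinary array leaf. A jitted function returns a copy with a modified time step. Calling it again with a different dt reuses the compiled code, because only leaf values changed and the pytree structure stayed the same. The TRACE_COUNT counter is also an illustration device added for this sketch.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

TRACE_COUNT = 0  # incremented only while JAX traces the function


class ToySystem(NamedTuple):
    """Hypothetical stand-in for a pytree configuration object."""
    dt: jax.Array
    box_size: jax.Array


@jax.jit
def halve_dt(system: ToySystem) -> ToySystem:
    global TRACE_COUNT
    TRACE_COUNT += 1  # Python side effect: runs at trace time, not on every call
    # dt is an ordinary array leaf, so it can be updated inside jit
    return system._replace(dt=system.dt / 2)


sys_a = halve_dt(ToySystem(dt=jnp.asarray(0.005), box_size=jnp.ones(2)))
sys_b = halve_dt(ToySystem(dt=jnp.asarray(0.01), box_size=jnp.ones(2)))
print(sys_a.dt, sys_b.dt, TRACE_COUNT)
```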
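System Creation
By default, create() initializes unspecified
attributes (e.g., domain, force_model, \(\Delta t\)) with defaults, for example a
free domain, a velocity Verlet linear integrator, and \(\Delta t = 0.005\)
(see the printed stacked system further below).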
import jax
import jax.numpy as jnp
import jaxdem as jdem
The system’s dimension must match the state’s dimension. Some components (e.g., domains) transform arrays of shape \((N, d)\) and require \(d\) to agree with the system.
state = jdem.State.create(pos=jnp.zeros((1, 2)))
system = jdem.System.create(state.shape)
state, system = system.step(state, system) # one step
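The shape requirement comes from broadcasting. Per-dimension domain parameters have shape \((d,)\) and are combined with positions of shape \((N, d)\). The sketch below uses a hypothetical wrap_periodic helper (not a jaxdem function) to wrap positions into a box. It then shows that a box with the wrong \(d\) fails to broadcast.

```python
import jax.numpy as jnp


def wrap_periodic(pos, box_size, anchor):
    """Hypothetical helper: wrap positions (N, d) into [anchor, anchor + box_size)."""
    # (N, d) broadcasts against (d,) only when the trailing dimension d agrees
    return anchor + jnp.mod(pos - anchor, box_size)


pos = jnp.array([[1.5, -0.25], [0.5, 0.5]])  # N = 2 particles in d = 2
wrapped = wrap_periodic(pos, jnp.ones(2), jnp.zeros(2))
print(wrapped)

try:
    wrap_periodic(pos, jnp.ones(3), jnp.zeros(3))  # d = 3 box, d = 2 positions
    mismatch_rejected = False
except (TypeError, ValueError):
    mismatch_rejected = True
print("mismatched d rejected:", mismatch_rejected)
```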
A note on static methods
Every operation on State and
System (step, trajectory_rollout,
merge, stack, etc.) is a static method. That means
system.step(state, system) and jdem.System.step(state, system) are
equivalent. Static methods make it straightforward to use these operations
inside jax.jit(), jax.vmap(), and other JAX transforms.
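The calling convention can be mimicked with hypothetical toy pytrees (ToyState and ToySystem are invented here, and the explicit Euler update is a stand-in for a real integrator). Because step is a static method, it takes no self. Calling it through an instance and calling it through the class run the same function. That plain function can also be passed directly to jax.jit.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical particle state: positions and velocities of shape (N, d)."""
    pos: jax.Array
    vel: jax.Array


class ToySystem(NamedTuple):
    """Hypothetical configuration with a static step, mimicking the calling convention."""
    dt: jax.Array

    @staticmethod
    def step(state: ToyState, system: "ToySystem"):
        # Explicit Euler position update; returns new pytrees instead of mutating
        new_state = state._replace(pos=state.pos + system.dt * state.vel)
        return new_state, system


state = ToyState(pos=jnp.zeros((1, 2)), vel=jnp.ones((1, 2)))
system = ToySystem(dt=jnp.asarray(0.1))

via_instance, _ = system.step(state, system)  # looked up on the instance, no self passed
via_class, _ = ToySystem.step(state, system)  # identical call through the class
via_jit, _ = jax.jit(ToySystem.step)(state, system)  # a plain function, easy to jit
```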
Configuring the System
You can configure submodules via keyword arguments when creating the system.
system = jdem.System.create(state.shape, domain_type="periodic")
print("periodic domain:", system.domain)
periodic domain: PeriodicDomain(box_size=Array([1., 1.], dtype=float32), anchor=Array([0., 0.], dtype=float32))
You can also pass constructor arguments to a submodule through the matching *_kw dictionary (here, domain_kw for the domain).
system = jdem.System.create(
state.shape,
domain_type="periodic",
domain_kw={"box_size": 10.0 * jnp.ones(2), "anchor": jnp.zeros(2)},
)
print("periodic domain (10x10):", system.domain)
periodic domain (10x10): PeriodicDomain(box_size=Array([10., 10.], dtype=float32), anchor=Array([0., 0.], dtype=float32))
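One way to picture the *_type / *_kw pattern is as a string-keyed factory. The sketch below is a hypothetical illustration: register_domain, create_domain and the Toy* classes are invented here, and jaxdem's internal mechanism may differ. The type string selects a class, and the keyword dictionary is forwarded to its constructor.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional

DOMAIN_REGISTRY: Dict[str, type] = {}


def register_domain(name: str) -> Callable[[type], type]:
    """Hypothetical decorator that maps a type string to a class."""
    def decorator(cls: type) -> type:
        DOMAIN_REGISTRY[name] = cls
        return cls
    return decorator


@register_domain("free")
@dataclass
class ToyFreeDomain:
    box_size: float = 1.0


@register_domain("periodic")
@dataclass
class ToyPeriodicDomain:
    box_size: float = 1.0


def create_domain(domain_type: str = "free", domain_kw: Optional[dict] = None):
    # Resolve the class from its name, then forward the keyword dictionary
    cls = DOMAIN_REGISTRY[domain_type]
    return cls(**(domain_kw or {}))


default_domain = create_domain()
big_periodic = create_domain("periodic", {"box_size": 10.0})
print(default_domain, big_periodic)
```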
Manually swapping a submodule
Internally, create() builds each submodule and
performs sanity checks. You can also create a submodule manually and assign it
to the corresponding attribute:
domain = jdem.Domain.create("free", dim=2)
system.domain = domain
print("free default domain:", system.domain)
free default domain: FreeDomain(box_size=Array([1., 1.], dtype=float32), anchor=Array([0., 0.], dtype=float32))
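Attribute assignment works in the example above. If you prefer a purely functional style, you can build a new container instead. The sketch below uses hypothetical frozen dataclasses (not jaxdem classes) together with the standard-library dataclasses.replace. That function returns a copy with one field swapped and leaves the original untouched.

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToyDomain:
    box_size: jnp.ndarray


@dataclasses.dataclass(frozen=True)
class ToySystem:
    domain: ToyDomain
    dt: float


system = ToySystem(domain=ToyDomain(box_size=jnp.ones(2)), dt=0.005)

# Frozen dataclasses reject attribute assignment...
try:
    system.domain = ToyDomain(box_size=10.0 * jnp.ones(2))
    assignment_blocked = False
except dataclasses.FrozenInstanceError:
    assignment_blocked = True

# ...so build a new system with the swapped submodule instead
new_system = dataclasses.replace(system, domain=ToyDomain(box_size=10.0 * jnp.ones(2)))
```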
Time stepping
The system controls how the simulation advances in time.
You can take a single step or several steps in one call. Multi-step calls
use jax.lax.fori_loop() internally, so the whole loop runs as compiled code
rather than as one Python call per step.
state = jdem.State.create(jnp.zeros((1, 2)))
state, system = system.step(state, system) # 1 step
# Multiple steps in a single call:
state, system = system.step(state, system, n=10) # 10 steps
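A minimal sketch of an n-step update built on jax.lax.fori_loop follows. The euler_steps function and its explicit Euler update are hypothetical stand-ins, not jaxdem's integrator. The loop index is passed to the body but unused, and the positions are the loop carry. Marking n as static lets jit compile the loop with a fixed trip count.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="n")
def euler_steps(pos, vel, dt, n):
    """Hypothetical n-step explicit Euler update run inside one compiled loop."""
    def body(i, pos):
        # i is the loop index (unused); pos is the loop carry
        return pos + dt * vel
    return jax.lax.fori_loop(0, n, body, pos)


pos = euler_steps(jnp.zeros((1, 2)), jnp.ones((1, 2)), 0.1, n=10)
print(pos)
```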
Trajectory rollout
If you want to store snapshots along the way, use
trajectory_rollout(). It records n
snapshots, with stride integration steps between consecutive snapshots, for a total
of \(n \times \text{stride}\) steps.
state = jdem.State.create(jnp.zeros((1, 2)))
state, system, trajectory = system.trajectory_rollout(
state, system, n=10, stride=2 # total steps = 20
)
The trajectory is a Tuple[State, System] with an extra leading
axis of length n.
traj_state, traj_system = trajectory
print("trajectory pos shape:", traj_state.pos.shape) # (n, N, d)
trajectory pos shape: (10, 1, 2)
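The extra leading axis of length n is what jax.lax.scan produces when it stacks per-iteration outputs. The sketch below is hypothetical (rollout and its Euler update are invented here). It assumes each snapshot is recorded after its block of stride steps, and jaxdem's exact convention is not stated in this example. Each scan iteration runs an inner fori_loop of stride steps and then emits the new positions.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("n", "stride"))
def rollout(pos, vel, dt, n, stride):
    """Hypothetical rollout: n snapshots, stride Euler steps before each one."""
    def advance(pos, _):
        # Take `stride` steps, then emit the new positions as a snapshot
        pos = jax.lax.fori_loop(0, stride, lambda i, p: p + dt * vel, pos)
        return pos, pos  # (carry, recorded output)

    # scan stacks the recorded outputs along a new leading axis of length n
    return jax.lax.scan(advance, pos, xs=None, length=n)


final_pos, traj = rollout(jnp.zeros((1, 2)), jnp.ones((1, 2)), 0.1, n=10, stride=2)
print("trajectory shape:", traj.shape)  # (n, N, d)
```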
Batched simulations with vmap
You can run many independent simulations in parallel with jax.vmap().
Make sure the initialization returns per-simulation State/System pairs.
def initialize(i):
st = jdem.State.create(jnp.zeros((1, 2)))
sys = jdem.System.create(
st.shape,
domain_type="reflect",
domain_kw={"box_size": (2 + i) * jnp.ones(2), "anchor": jnp.zeros(2)},
)
return st, sys
# Create a batch of 5 simulations
state_b, system_b = jax.vmap(initialize)(jnp.arange(5))
print(system_b.domain) # batched variable domain
ReflectDomain(box_size=Array([[2., 2.],
[3., 3.],
[4., 4.],
[5., 5.],
[6., 6.]], dtype=float32), anchor=Array([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32), restitution_coefficient=Array([1., 1., 1., 1., 1.], dtype=float32))
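Advance each simulation by 10 steps. Inside the vmapped function, call the static method through the class (jdem.System.step) or a small wrapper rather than through an outer variable such as system_b, so the per-simulation sys argument is not confused with the batched system.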
state_b, system_b = jax.vmap(lambda st, sys: jdem.System.step(st, sys, n=10))(
state_b, system_b
)
print("batched pos shape:", state_b.pos.shape) # (batch, N, d)
batched pos shape: (5, 1, 2)
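The same mechanism works for any pytree whose leaves carry a batch axis. The sketch below uses a hypothetical ToySim NamedTuple and helper functions that are not part of jaxdem. jax.vmap builds five simulations with different box sizes. A second vmap then wraps each simulation's particles into its own box.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySim(NamedTuple):
    """Hypothetical per-simulation pytree: positions (N, d) and a box size (d,)."""
    pos: jax.Array
    box_size: jax.Array


def init(i):
    # Simulation i gets a box of side 2 + i, mirroring initialize() above
    return ToySim(pos=jnp.zeros((1, 2)), box_size=(2.0 + i) * jnp.ones(2))


def shift_and_wrap(sim):
    # Move every particle by 2.5 along x, then wrap into this simulation's own box
    new_pos = jnp.mod(sim.pos + jnp.array([2.5, 0.0]), sim.box_size)
    return sim._replace(pos=new_pos)


batch = jax.vmap(init)(jnp.arange(5))
batch = jax.vmap(shift_and_wrap)(batch)
print(batch.pos[:, 0, 0])
```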
Another way to create a batch of systems is the stack() method:
state = jdem.State.create(jnp.zeros((1, 2)))
system = jdem.System.create(
state.shape,
)
system = system.stack([system, system, system])
print("stacked system:", system)
stacked system: System(linear_integrator=VelocityVerlet(), rotation_integrator=VelocityVerletSpiral(), collider=NaiveSimulator(), domain=FreeDomain(box_size=Array([[1., 1.],
[1., 1.],
[1., 1.]], dtype=float32), anchor=Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32)), force_manager=ForceManager(gravity=Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32), external_force=Array([[[0., 0.]],
[[0., 0.]],
[[0., 0.]]], dtype=float32), external_force_com=Array([[[0., 0.]],
[[0., 0.]],
[[0., 0.]]], dtype=float32), external_torque=Array([[[0.]],
[[0.]],
[[0.]]], dtype=float32), is_com_force=(), force_functions=(), energy_functions=()), bonded_force_model=None, force_model=SpringForce(laws=()), mat_table=MaterialTable(props={'density': Array([[0.27],
[0.27],
[0.27]], dtype=float32), 'poisson': Array([[0.3],
[0.3],
[0.3]], dtype=float32), 'young': Array([[10000.],
[10000.],
[10000.]], dtype=float32)}, pair={'density_eff': Array([[[0.27]],
[[0.27]],
[[0.27]]], dtype=float32), 'poisson_eff': Array([[[0.3]],
[[0.3]],
[[0.3]]], dtype=float32), 'young_eff': Array([[[10000.]],
[[10000.]],
[[10000.]]], dtype=float32)}, matcher=HarmonicMaterialMatchmaker()), dt=Array([0.005, 0.005, 0.005], dtype=float32), time=Array([0., 0., 0.], dtype=float32), dim=Array([2, 2, 2], dtype=int32), step_count=Array([0, 0, 0], dtype=int32), key=Array([[0, 0],
[0, 0],
[0, 0]], dtype=uint32), interact_same_bond_id=Array([False, False, False], dtype=bool), user_pre_step_actions=<function _save_state_system at 0x7f9330ef1a80>, user_pos_step_actions=<function _save_state_system at 0x7f9330ef1a80>)
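The printed output shows every leaf gaining a leading axis of length 3. For example, dt has shape (3,) and box_size has shape (3, 2). A generic JAX way to get that effect is to stack corresponding leaves with jax.tree_util.tree_map. The sketch below applies it to a hypothetical ToySystem. It illustrates the technique and is not necessarily how jaxdem's stack() is implemented.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySystem(NamedTuple):
    """Hypothetical configuration pytree with a scalar and a per-dimension leaf."""
    dt: jax.Array
    box_size: jax.Array


def stack_pytrees(trees):
    # Stack corresponding leaves along a new leading batch axis
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


single = ToySystem(dt=jnp.asarray(0.005), box_size=jnp.ones(2))
batched = stack_pytrees([single, single, single])
print(batched.dt.shape, batched.box_size.shape)
```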
Deactivating Components
Some modules can be deactivated by passing an empty string ""
(or None, depending on the field) when creating the system. The
base class is then used, which provides no-op behaviour.
| Component | Deactivation value | Effect |
|---|---|---|
| collider_type | "" | No pairwise force computation; forces/torques are zeroed. |
| linear_integrator_type | "" | No position/velocity updates. |
| rotation_integrator_type | "" | No orientation/angular-velocity updates. |
| | | No bonded forces. |
| | | No gravitational acceleration. |
Note: the domain (domain_type) and force model
(force_model_type) cannot be deactivated; a valid type must
always be provided.
# No collisions, no integration — a "frozen" system:
system_frozen = jdem.System.create(
state.shape,
collider_type="",
linear_integrator_type="",
rotation_integrator_type="",
)
print("Collider:", type(system_frozen.collider).__name__)
print("Integrator:", type(system_frozen.linear_integrator).__name__)
Collider: Collider
Integrator: LinearIntegrator
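The output shows the base classes (Collider, LinearIntegrator) being selected. The "base class as no-op" idea can be sketched with hypothetical classes invented here, not jaxdem's. The base step returns the state unchanged, a subclass overrides it, and an empty string in a lookup table selects the base.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    pos: jax.Array
    vel: jax.Array


class ToyLinearIntegrator:
    """Hypothetical base class: step returns the state unchanged (no-op)."""

    @staticmethod
    def step(state: ToyState, dt: float) -> ToyState:
        return state


class ToyEuler(ToyLinearIntegrator):
    """Hypothetical subclass that actually moves particles."""

    @staticmethod
    def step(state: ToyState, dt: float) -> ToyState:
        return state._replace(pos=state.pos + dt * state.vel)


# An empty string selects the base class, mirroring the deactivation convention
INTEGRATORS = {"": ToyLinearIntegrator, "euler": ToyEuler}

state = ToyState(pos=jnp.zeros((1, 2)), vel=jnp.ones((1, 2)))
frozen = INTEGRATORS[""].step(state, 0.1)
moving = INTEGRATORS["euler"].step(state, 0.1)
print(frozen.pos, moving.pos)
```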
Random Number Generation
create() accepts a seed (integer) or
a key (jax.random.PRNGKey()) that initialises the system’s
JAX PRNG state. The key is stored in system.key and is available for
stochastic integrators or custom force functions.
system_rng = jdem.System.create(state.shape, seed=42)
print("PRNG key:", system_rng.key)
PRNG key: [ 0 42]
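A common JAX pattern for consuming a stored key is to split it at every use and store the fresh half back, so no random draw is ever repeated. The sketch below is hypothetical: ToySystem, noisy_step and the Brownian-like update are invented here. How jaxdem's stochastic components actually use system.key is not shown in this example.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySystem(NamedTuple):
    """Hypothetical configuration carrying a PRNG key and a time step."""
    key: jax.Array
    dt: jax.Array


@jax.jit
def noisy_step(pos, system):
    # Split the stored key so the same randomness is never drawn twice
    key, subkey = jax.random.split(system.key)
    noise = jax.random.normal(subkey, pos.shape)
    new_pos = pos + jnp.sqrt(system.dt) * noise  # Brownian-like displacement
    return new_pos, system._replace(key=key)


system = ToySystem(key=jax.random.PRNGKey(42), dt=jnp.asarray(0.005))
pos1, system1 = noisy_step(jnp.zeros((1, 2)), system)
pos2, system2 = noisy_step(jnp.zeros((1, 2)), system1)
```

Re-running from the same seed reproduces the same draws, while each stored key produces new ones.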