The Simulation System
Now that we know how to use and manipulate the simulation state (jaxdem.state.State), it’s time to delve into the simulation configuration in jaxdem.system.System.
A System holds the “static” configuration of a simulation, such as the domain, integrator settings, and force model. Although we call it “static”, many fields (e.g., the time step \(\Delta t\), domain dimensions, boundary conditions) can be changed at runtime, even inside a JIT-compiled function, because both State and System are JAX pytrees: their array-valued fields are leaves that JAX traces as inputs rather than baking them in as compile-time constants.
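The effect of the pytree design can be sketched with a toy class. ToySystem below is a hypothetical stand-in, not jaxdem's System: it is a frozen dataclass registered as a pytree, with dt as a traced array leaf and dim as static metadata. The sketch assumes that changing a leaf's value (but not its shape or dtype) does not force JAX to retrace the jitted function, and it counts traces with a Python side effect to show this.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a System: a frozen dataclass registered as a pytree.
# `dt` is an array leaf (traced under jit); `dim` is static metadata.
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class ToySystem:
    dt: jax.Array
    dim: int

    def tree_flatten(self):
        # (children traced by JAX, static aux data that is part of the cache key)
        return (self.dt,), (self.dim,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(dt=children[0], dim=aux_data[0])


trace_count = 0


@jax.jit
def advance(pos, vel, sys):
    global trace_count
    trace_count += 1  # Python side effect: only runs while JAX traces the function
    return pos + sys.dt * vel


@jax.jit
def halve_dt(sys):
    # Builds a new ToySystem with a modified leaf, entirely inside jit.
    return dataclasses.replace(sys, dt=sys.dt / 2)


pos, vel = jnp.zeros(2), jnp.ones(2)
sys = ToySystem(dt=jnp.asarray(0.01, dtype=jnp.float32), dim=2)

p1 = advance(pos, vel, sys)            # first call: traces and compiles
p2 = advance(pos, vel, halve_dt(sys))  # new dt value, same shape/dtype: reuses the trace
print(p1, p2, "traces:", trace_count)
```

Giving dt an explicit float32 dtype keeps its abstract type identical across calls; a change in shape or dtype of a leaf, or in the static dim, would trigger a new trace.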
System Creation
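By default, jaxdem.system.System.create() initializes unspecified attributes (e.g., domain, force_model, \(\Delta t\)) with sensible defaults. For dim=2, the printout below shows them: a DirectEuler integrator, a NaiveSimulator collider, a FreeDomain with a unit box, a SpringForce model, and \(\Delta t = 0.005\).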
import dataclasses as _dc
import jax
import jax.numpy as jnp
import jaxdem as jdem
system = jdem.System.create(dim=2)
print(system)
System(integrator=DirectEuler(), collider=NaiveSimulator(), domain=FreeDomain(box_size=Array([1., 1.], dtype=float32), anchor=Array([0., 0.], dtype=float32)), force_model=SpringForce(required_material_properties=('young_eff',), laws=()), mat_table=MaterialTable(props={'young': Array([10000.], dtype=float32), 'poisson': Array([0.3], dtype=float32)}, pair={'young_eff': Array([[10000.]], dtype=float32), 'poisson_eff': Array([[0.3]], dtype=float32)}, matcher=HarmonicMaterialMatchmaker()), dt=Array(0.005, dtype=float32), time=Array(0., dtype=float32), dim=Array(2, dtype=int32), step_count=Array(0, dtype=int32))
It is essential that the system’s dimension matches the state’s dimension. Some components, such as those in jaxdem.domain, require matching dimensions because they transform the state’s arrays of shape \((N, d)\), where \(d\) must agree with the system’s dimension.
state = jdem.State.create(pos=jnp.zeros((1, 2)))
state, system = system.step(state, system) # one step
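To see why the dimensions must agree, here is a minimal sketch of a periodic wrap of the kind a domain might apply. The function wrap_periodic is hypothetical, not jaxdem's implementation; it assumes box_size and anchor have shape \((d,)\) and broadcast against positions of shape \((N, d)\).

```python
import jax.numpy as jnp


def wrap_periodic(pos, box_size, anchor):
    """Hypothetical periodic wrap: maps (N, d) positions into [anchor, anchor + box_size).

    box_size and anchor have shape (d,), so they line up with pos only if the
    trailing dimension d of pos equals the system dimension.
    """
    if pos.shape[-1] != box_size.shape[-1]:
        raise ValueError(
            f"state dimension {pos.shape[-1]} != system dimension {box_size.shape[-1]}"
        )
    return anchor + jnp.mod(pos - anchor, box_size)


box_2d, anchor_2d = jnp.ones(2), jnp.zeros(2)  # a 2D system
pos_2d = jnp.array([[1.25, -0.5]])             # N=1, d=2
wrapped = wrap_periodic(pos_2d, box_2d, anchor_2d)
print(wrapped)  # values 0.25 and 0.5: both coordinates wrapped into [0, 1)

pos_3d = jnp.zeros((1, 3))                     # N=1, d=3: does not match
try:
    wrap_periodic(pos_3d, box_2d, anchor_2d)
    mismatch_detected = False
except ValueError as err:
    mismatch_detected = True
    print("error:", err)
```

The explicit check matters: a state with \(d = 1\) would not raise a broadcasting error against a 2D box, it would silently broadcast to shape \((N, 2)\).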
Configuring the System
You can select submodules through keyword arguments to System.create(); for example, domain_type chooses the domain.
system = jdem.System.create(dim=2, domain_type="periodic")
print("periodic domain:", system.domain)
periodic domain: PeriodicDomain(box_size=Array([1., 1.], dtype=float32), anchor=Array([0., 0.], dtype=float32))
You can also pass constructor arguments to submodules through the corresponding *_kw dictionaries (here, domain_kw).
system = jdem.System.create(
dim=2,
domain_type="periodic",
domain_kw=dict(box_size=10.0 * jnp.ones(2), anchor=jnp.zeros(2)),
)
print("periodic domain (10x10):", system.domain)
periodic domain (10x10): PeriodicDomain(box_size=Array([10., 10.], dtype=float32), anchor=Array([0., 0.], dtype=float32))
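The combination of a type string and a *_kw dictionary is a common factory pattern. The sketch below is a hypothetical illustration of that pattern, not jaxdem's internals: the registry DOMAINS, the toy domain classes, and create_domain are all invented for this example. It looks up a class by name, fills in defaults, lets domain_kw override them, and runs a simple shape sanity check.

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToyFreeDomain:  # hypothetical
    box_size: jnp.ndarray
    anchor: jnp.ndarray


@dataclasses.dataclass(frozen=True)
class ToyPeriodicDomain:  # hypothetical
    box_size: jnp.ndarray
    anchor: jnp.ndarray


# Hypothetical registry mapping a type string to a domain class.
DOMAINS = {"free": ToyFreeDomain, "periodic": ToyPeriodicDomain}


def create_domain(domain_type="free", dim=2, domain_kw=None):
    """Look up the class by name, fill in defaults, let domain_kw override them."""
    kw = dict(box_size=jnp.ones(dim), anchor=jnp.zeros(dim))  # defaults
    kw.update(domain_kw or {})
    domain = DOMAINS[domain_type](**kw)
    # Sanity check: every per-dimension field must have shape (dim,).
    for name in ("box_size", "anchor"):
        if getattr(domain, name).shape != (dim,):
            raise ValueError(f"{name} must have shape ({dim},)")
    return domain


d = create_domain("periodic", dim=2, domain_kw=dict(box_size=10.0 * jnp.ones(2)))
print(d)  # anchor falls back to the default zeros
```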
Manually swapping a submodule
Internally, jaxdem.system.System.create() builds each submodule and performs sanity checks. You can also create a submodule manually and replace it using dataclasses.replace():
domain = jdem.Domain.create("free", dim=2)
system = _dc.replace(system, domain=domain)
print("free default domain:", system.domain)
free default domain: FreeDomain(box_size=Array([1., 1.], dtype=float32), anchor=Array([0., 0.], dtype=float32))
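dataclasses.replace() does not modify its argument; it returns a new instance with the given fields swapped and every other field copied. This matches the functional style used throughout, where calls such as system.step() return new objects. The ToySystem below is a hypothetical stand-in; it is frozen so that in-place assignment fails, and whether jaxdem's System is frozen is not stated here.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToySystem:  # hypothetical stand-in, not jaxdem's System
    domain: str
    dt: float


old = ToySystem(domain="periodic", dt=0.005)
new = dataclasses.replace(old, domain="free")  # copies dt, swaps domain

try:
    old.domain = "free"  # in-place mutation is rejected on a frozen dataclass
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

print(old)  # unchanged
print(new)
```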
Time stepping
The system controls how the simulation advances in time. You can take a single step or many steps in one call. When taking many steps, batched stepping uses jax.lax.scan() under the hood, so the whole loop runs as one compiled computation instead of one Python-level call per step.
state = jdem.State.create(jnp.zeros((1, 2)))
state, system = system.step(state, system) # 1 step
# Multiple steps in a single call:
state, system = system.step(state, system, n=10) # 10 steps
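The scan-based approach can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not jaxdem's step(): euler_step is a hypothetical constant-velocity update, and step_n applies it n times with jax.lax.scan. The same result from a Python loop is computed for comparison.

```python
import jax
import jax.numpy as jnp


def euler_step(pos, vel, dt):
    # Hypothetical explicit Euler update for constant velocity.
    return pos + dt * vel


def step_n(pos, vel, dt, n):
    """Apply `n` steps with jax.lax.scan; `n` must be a static Python int."""
    def body(carry, _):
        return euler_step(carry, vel, dt), None  # (new carry, no per-step output)

    pos, _ = jax.lax.scan(body, pos, xs=None, length=n)
    return pos


# n sets the loop length, so it is marked static for jit.
step_n_jit = jax.jit(step_n, static_argnums=3)

pos0 = jnp.zeros((1, 2))
vel = jnp.ones((1, 2))
dt = 0.005

scanned = step_n_jit(pos0, vel, dt, 10)

looped = pos0
for _ in range(10):  # same update, but each step is dispatched separately from Python
    looped = euler_step(looped, vel, dt)

print(scanned, looped)
```

Because the scan length is part of the compiled program, calling step_n_jit with a different n compiles a new version.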
Trajectory rollout
If you want to store snapshots along the way, use jaxdem.system.System.trajectory_rollout(). This records n snapshots, taking stride internal steps between snapshots, i.e., a total of \(n \cdot \text{stride}\) integration steps.
state = jdem.State.create(jnp.zeros((1, 2)))
state, system, trajectory = system.trajectory_rollout(
state, system, n=10, stride=2 # total steps = 20
)
The trajectory is a Tuple[State, System] with an extra leading snapshot axis of length n.
traj_state, traj_system = trajectory
print("trajectory pos shape:", traj_state.pos.shape) # (n, N, d)
trajectory pos shape: (10, 1, 2)
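The snapshot bookkeeping can be sketched with two nested loops: an outer jax.lax.scan that emits one snapshot per iteration, and an inner jax.lax.fori_loop that takes stride steps in between. The rollout function below is hypothetical, not jaxdem's trajectory_rollout(); it assumes a constant-velocity Euler update and records the state after each block of stride steps (whether jaxdem's own output includes the initial state is not stated here).

```python
import jax
import jax.numpy as jnp


def rollout(pos, vel, dt, n, stride):
    """Hypothetical rollout: record `n` snapshots, `stride` Euler steps apart.

    Total integration steps = n * stride. Snapshots have shape (n, N, d).
    """
    def advance_stride(p):
        # Inner loop: `stride` steps between consecutive snapshots.
        return jax.lax.fori_loop(0, stride, lambda _, q: q + dt * vel, p)

    def outer(p, _):
        p = advance_stride(p)
        return p, p  # carry the state forward and emit it as a snapshot

    final, snapshots = jax.lax.scan(outer, pos, xs=None, length=n)
    return final, snapshots


final, snaps = rollout(jnp.zeros((1, 2)), jnp.ones((1, 2)), 0.5, n=10, stride=2)
print(snaps.shape)  # (10, 1, 2): a leading snapshot axis of length n
```

With dt = 0.5 and stride = 2, each snapshot is one unit further along, and the final position after \(10 \cdot 2 = 20\) steps is 10.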
Batched simulations with vmap
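You can run many independent simulations in parallel with jax.vmap(). Make sure the initialization function returns the State/System pair of a single simulation; vmap then stacks the pairs along a new leading batch axis, as the printed box sizes show.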
def initialize(i):
st = jdem.State.create(jnp.zeros((1, 2)))
sys = jdem.System.create(
dim=2,
domain_type="reflect",
domain_kw=dict(box_size=(2 + i) * jnp.ones(2), anchor=jnp.zeros(2)),
)
return st, sys
# Create a batch of 5 simulations
state_b, system_b = jax.vmap(initialize)(jnp.arange(5))
print(system_b.domain) # batched variable domain
ReflectDomain(box_size=Array([[2., 2.],
[3., 3.],
[4., 4.],
[5., 5.],
[6., 6.]], dtype=float32), anchor=Array([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32))
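The stacking behavior does not depend on jaxdem: jax.vmap adds a leading batch axis to every leaf of whatever pytree the mapped function returns. In the minimal sketch below, a plain dict stands in for the State/System pair, and its keys are hypothetical. Leaves that do not depend on the mapped index, such as pos here, are broadcast along the batch axis as well.

```python
import jax
import jax.numpy as jnp


def initialize(i):
    # Hypothetical per-simulation init returning a pytree (a dict here):
    # positions of shape (N, d) and a box that grows with the index i.
    return {
        "pos": jnp.zeros((1, 2)),
        "box_size": (2 + i) * jnp.ones(2),
    }


batch = jax.vmap(initialize)(jnp.arange(5))
# Every leaf gains a leading batch axis of length 5.
print(jax.tree_util.tree_map(lambda x: x.shape, batch))
print(batch["box_size"])  # rows 2, 3, 4, 5, 6, as in the ReflectDomain above
```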
Advance each simulation by 10 steps. Use the class method (or a small wrapper) to avoid variable shadowing.
state_b, system_b = jax.vmap(lambda st, sys: jdem.System.step(st, sys, n=10))(
state_b, system_b
)
print("batched pos shape:", state_b.pos.shape) # (batch, N, d)
batched pos shape: (5, 1, 2)
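Batched stepping composes vmap (across simulations) with scan (across time). The sketch below assumes a hypothetical reflecting-wall rule (mirror a position that leaves the box \([0, L]\) and flip that velocity component); it is an illustration, not jaxdem's ReflectDomain. As in the code above, a small lambda wrapper fixes the shared arguments (dt, n) so that vmap maps only over the per-simulation arrays.

```python
import jax
import jax.numpy as jnp


def reflect_step(pos, vel, box_size, dt):
    """Hypothetical reflecting-wall Euler step in the box [0, box_size]."""
    pos = pos + dt * vel
    below = pos < 0.0
    above = pos > box_size
    # Mirror escaped coordinates back inside and flip the matching velocities.
    pos = jnp.where(below, -pos, jnp.where(above, 2.0 * box_size - pos, pos))
    vel = jnp.where(below | above, -vel, vel)
    return pos, vel


def run(pos, vel, box_size, dt, n):
    def body(carry, _):
        return reflect_step(*carry, box_size, dt), None

    (pos, vel), _ = jax.lax.scan(body, (pos, vel), xs=None, length=n)
    return pos, vel


batch = 5
box = (2.0 + jnp.arange(batch, dtype=jnp.float32))[:, None] * jnp.ones(2)  # (5, 2)
pos_b = jnp.full((batch, 1, 2), 1.0)  # (batch, N, d)
vel_b = jnp.ones((batch, 1, 2))

# Map over the batch axis of pos, vel and box; dt and n are shared Python values.
run_b = jax.vmap(lambda p, v, b: run(p, v, b, dt=0.1, n=20))
pos_f, vel_f = run_b(pos_b, vel_b, box)
print(pos_f.shape)  # (batch, N, d)
```

Only the smallest box (side 2) is reached within 20 steps, so its particle ends up moving backward while the others keep their original velocity.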