The Simulation System

Now that we know how to use and manipulate the simulation state (State), it’s time to delve into the simulation configuration in System.

A System holds the “static” configuration of a simulation, such as the domain, integrator settings, and force model. Despite the name, many fields (e.g., the time step \(\Delta t\), domain dimensions, boundary conditions) can be changed at runtime, even inside a JIT-compiled function, because both State and System are JAX pytrees.
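The following is a minimal sketch of what being a pytree allows. It uses a hypothetical `ToySystem` NamedTuple in place of `jdem.System` (NamedTuples are JAX pytrees without extra registration). Because `dt` is an array leaf rather than a compile-time constant, a jitted function can return a system with a new time step:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySystem(NamedTuple):
    """Hypothetical stand-in for jdem.System; NamedTuples are JAX pytrees."""

    dt: jax.Array
    box_size: jax.Array


@jax.jit
def halve_dt(system: ToySystem) -> ToySystem:
    # Inside jit, dt is a traced array leaf, so it can be replaced freely.
    return system._replace(dt=system.dt * 0.5)


system = ToySystem(dt=jnp.asarray(0.005), box_size=jnp.ones(2))
system = halve_dt(system)  # a new ToySystem whose dt is 0.0025
```

The real System carries many more fields, but the mechanism is the same: array-valued fields are leaves that flow through `jax.jit` like any other input.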

System Creation

By default, create() initializes unspecified attributes (e.g., domain, force_model, \(\Delta t\)) with sensible defaults.

import jax
import jax.numpy as jnp
import jaxdem as jdem

The system’s dimension must match the state’s dimension. Some components (e.g., domains) transform arrays of shape \((N, d)\) and require \(d\) to agree with the system.

state = jdem.State.create(pos=jnp.zeros((1, 2)))
system = jdem.System.create(state.shape)
state, system = system.step(state, system)  # one step
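To see why the dimensions must agree, consider a hypothetical periodic wrap that maps positions of shape \((N, d)\) back into a box described by `anchor` and `box_size` arrays of shape \((d,)\). The formula is an assumption for illustration, not jaxdem's `PeriodicDomain` code; the point is that broadcasting only works when the trailing axes all equal \(d\).

```python
import jax.numpy as jnp


def periodic_wrap(pos, anchor, box_size):
    """Hypothetical wrap of (N, d) positions into [anchor, anchor + box_size)."""
    # anchor and box_size have shape (d,) and broadcast against pos's last axis.
    return anchor + jnp.mod(pos - anchor, box_size)


pos = jnp.array([[1.5, -0.25]])  # N = 1 particle in d = 2
wrapped = periodic_wrap(pos, jnp.zeros(2), jnp.ones(2))  # [[0.5, 0.75]]

# A 3-D box cannot broadcast against 2-D positions, so this call fails.
try:
    periodic_wrap(pos, jnp.zeros(3), jnp.ones(3))
    mismatch_detected = False
except (TypeError, ValueError):
    mismatch_detected = True
```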

A note on static methods

Every operation on State and System (step, trajectory_rollout, merge, stack, etc.) is a static method. That means system.step(state, system) and jdem.System.step(state, system) are equivalent. Static methods make it straightforward to use these operations inside jax.jit(), jax.vmap(), and other JAX transforms.
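Because a static method has no bound `self`, it is an ordinary function that JAX transforms can wrap directly. Below is a minimal sketch with hypothetical `ToyState`/`ToySystem` NamedTuples (not jaxdem classes) and an explicit-Euler drift chosen only for illustration:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    pos: jax.Array
    vel: jax.Array


class ToySystem(NamedTuple):
    dt: jax.Array

    @staticmethod
    def step(state, system):
        # Explicit Euler drift: x <- x + v * dt
        return state._replace(pos=state.pos + state.vel * system.dt), system


state = ToyState(pos=jnp.zeros((1, 2)), vel=jnp.ones((1, 2)))
system = ToySystem(dt=jnp.asarray(0.1))

# Accessing the static method via an instance or via the class yields the same function.
same_function = system.step is ToySystem.step

# It can be handed straight to jax.jit ...
state, system = jax.jit(ToySystem.step)(state, system)

# ... or to jax.vmap, mapping over a leading batch axis of every leaf.
batched_state = ToyState(pos=jnp.zeros((4, 1, 2)), vel=jnp.ones((4, 1, 2)))
batched_system = ToySystem(dt=jnp.full((4,), 0.1))
batched_state, _ = jax.vmap(ToySystem.step)(batched_state, batched_system)
```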

Configuring the System

You can configure submodules when creating the system via keyword arguments.

system = jdem.System.create(state.shape, domain_type="periodic")
print("periodic domain:", system.domain)
periodic domain: PeriodicDomain(box_size=Array([1., 1.], dtype=float32), anchor=Array([0., 0.], dtype=float32))

You can also pass constructor arguments to submodules via *_kw dictionaries.

system = jdem.System.create(
    state.shape,
    domain_type="periodic",
    domain_kw={"box_size": 10.0 * jnp.ones(2), "anchor": jnp.zeros(2)},
)
print("periodic domain (10x10):", system.domain)
periodic domain (10x10): PeriodicDomain(box_size=Array([10., 10.], dtype=float32), anchor=Array([0., 0.], dtype=float32))

Manually swapping a submodule

Internally, create() builds each submodule and performs sanity checks. You can also create a submodule manually and replace it using:

domain = jdem.Domain.create("free", dim=2)
system.domain = domain
print("free default domain:", system.domain)
free default domain: FreeDomain(box_size=Array([1., 1.], dtype=float32), anchor=Array([0., 0.], dtype=float32))

Time stepping

The system controls how the simulation advances in time. You can take a single step or multiple steps at once. Multi-step calls use jax.lax.fori_loop() internally, so the loop runs inside compiled code instead of issuing one Python-level call per step.

state = jdem.State.create(jnp.zeros((1, 2)))
state, system = system.step(state, system)  # 1 step

# Multiple steps in a single call:
state, system = system.step(state, system, n=10)  # 10 steps
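The multi-step pattern can be sketched with plain JAX. This is an assumed illustration, not jaxdem's source: `euler_step` and `step_n` are hypothetical names, and the real `System.step` also runs collider, domain, and force logic that this toy omits.

```python
import jax
import jax.numpy as jnp


def euler_step(carry):
    """Hypothetical single step: x <- x + v * dt (velocity held constant)."""
    pos, vel, dt = carry
    return pos + vel * dt, vel, dt


def step_n(pos, vel, dt, n):
    # fori_loop traces the body once and repeats it n times in compiled code,
    # rather than dispatching n separate Python-level calls.
    return jax.lax.fori_loop(0, n, lambda i, carry: euler_step(carry), (pos, vel, dt))


pos, vel, dt = step_n(jnp.zeros((1, 2)), jnp.ones((1, 2)), jnp.asarray(0.1), n=10)
```

The loop carry must keep the same shapes and dtypes on every iteration, which is why the step returns the full `(pos, vel, dt)` tuple even though only `pos` changes.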

Trajectory rollout

If you want to store snapshots along the way, use trajectory_rollout(). It records n snapshots, advancing stride integration steps between consecutive snapshots, for a total of \(n \times \text{stride}\) steps.

state = jdem.State.create(jnp.zeros((1, 2)))

state, system, trajectory = system.trajectory_rollout(
    state, system, n=10, stride=2  # total steps = 20
)

The trajectory is a Tuple[State, System] whose arrays carry an extra leading axis of length n.

traj_state, traj_system = trajectory
print("trajectory pos shape:", traj_state.pos.shape)  # (n, N, d)
trajectory pos shape: (10, 1, 2)
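The "n snapshots of stride steps" structure maps naturally onto jax.lax.scan wrapped around jax.lax.fori_loop. The sketch below is an assumption about how such a rollout can be built, not jaxdem's implementation; `drift` and `rollout` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def drift(carry):
    """Hypothetical integration step: x <- x + v * dt."""
    pos, vel, dt = carry
    return pos + vel * dt, vel, dt


def rollout(carry, n, stride):
    def advance(carry, _):
        # Take `stride` steps, then emit the resulting carry as one snapshot.
        carry = jax.lax.fori_loop(0, stride, lambda i, c: drift(c), carry)
        return carry, carry

    # scan stacks the n emitted snapshots along a new leading axis.
    return jax.lax.scan(advance, carry, xs=None, length=n)


init = (jnp.zeros((1, 2)), jnp.ones((1, 2)), jnp.asarray(0.1))
final, (traj_pos, traj_vel, traj_dt) = rollout(init, n=10, stride=2)
print("trajectory pos shape:", traj_pos.shape)  # (10, 1, 2)
```

Each leaf of the emitted carry gains a leading axis of length n, which mirrors the (n, N, d) shape of traj_state.pos above.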

Batched simulations with vmap

You can run many independent simulations in parallel with jax.vmap(). Make sure the initialization returns per-simulation State/System pairs.

def initialize(i):
    st = jdem.State.create(jnp.zeros((1, 2)))
    sys = jdem.System.create(
        st.shape,
        domain_type="reflect",
        domain_kw={"box_size": (2 + i) * jnp.ones(2), "anchor": jnp.zeros(2)},
    )
    return st, sys


# Create a batch of 5 simulations
state_b, system_b = jax.vmap(initialize)(jnp.arange(5))
print(system_b.domain)  # batched variable domain
ReflectDomain(box_size=Array([[2., 2.],
       [3., 3.],
       [4., 4.],
       [5., 5.],
       [6., 6.]], dtype=float32), anchor=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), restitution_coefficient=Array([1., 1., 1., 1., 1.], dtype=float32))

Advance each simulation by 10 steps. The lambda fixes n=10, and calling the static method through the class (jdem.System.step) avoids relying on an instance name that the lambda's arguments could shadow.

state_b, system_b = jax.vmap(lambda st, sys: jdem.System.step(st, sys, n=10))(
    state_b, system_b
)
print("batched pos shape:", state_b.pos.shape)  # (batch, N, d)
batched pos shape: (5, 1, 2)

Another way to create a batched system is the stack() method, which stacks several systems along a new leading axis:

state = jdem.State.create(jnp.zeros((1, 2)))
system = jdem.System.create(
    state.shape,
)

system = system.stack([system, system, system])
print("stacked system:", system)
stacked system: System(linear_integrator=VelocityVerlet(), rotation_integrator=VelocityVerletSpiral(), collider=NaiveSimulator(), domain=FreeDomain(box_size=Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32), anchor=Array([[0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)), force_manager=ForceManager(gravity=Array([[0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), external_force=Array([[[0., 0.]],

       [[0., 0.]],

       [[0., 0.]]], dtype=float32), external_force_com=Array([[[0., 0.]],

       [[0., 0.]],

       [[0., 0.]]], dtype=float32), external_torque=Array([[[0.]],

       [[0.]],

       [[0.]]], dtype=float32), is_com_force=(), force_functions=(), energy_functions=()), bonded_force_model=None, force_model=SpringForce(laws=()), mat_table=MaterialTable(props={'density': Array([[0.27],
       [0.27],
       [0.27]], dtype=float32), 'poisson': Array([[0.3],
       [0.3],
       [0.3]], dtype=float32), 'young': Array([[10000.],
       [10000.],
       [10000.]], dtype=float32)}, pair={'density_eff': Array([[[0.27]],

       [[0.27]],

       [[0.27]]], dtype=float32), 'poisson_eff': Array([[[0.3]],

       [[0.3]],

       [[0.3]]], dtype=float32), 'young_eff': Array([[[10000.]],

       [[10000.]],

       [[10000.]]], dtype=float32)}, matcher=HarmonicMaterialMatchmaker()), dt=Array([0.005, 0.005, 0.005], dtype=float32), time=Array([0., 0., 0.], dtype=float32), dim=Array([2, 2, 2], dtype=int32), step_count=Array([0, 0, 0], dtype=int32), key=Array([[0, 0],
       [0, 0],
       [0, 0]], dtype=uint32), interact_same_bond_id=Array([False, False, False], dtype=bool), user_pre_step_actions=<function _save_state_system at 0x7f9330ef1a80>, user_pos_step_actions=<function _save_state_system at 0x7f9330ef1a80>)
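Conceptually, stacking pytrees amounts to stacking each array leaf along a new leading axis. A minimal sketch with jax.tree_util.tree_map, assuming all inputs share the same tree structure; `ToySystem` and `stack` are hypothetical, and jaxdem's own stack() may perform checks this omits.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySystem(NamedTuple):
    """Hypothetical stand-in for jdem.System."""

    dt: jax.Array
    box_size: jax.Array


def stack(systems):
    # Stack every leaf across the list, adding a leading batch axis.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *systems)


s = ToySystem(dt=jnp.asarray(0.005), box_size=jnp.ones(2))
batched = stack([s, s, s])
print(batched.dt.shape, batched.box_size.shape)  # (3,) (3, 2)
```

The result has the same layout that jax.vmap expects, so it can be passed directly to a vmapped step function.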

Deactivating Components

Some modules can be deactivated by passing an empty string "" (or None, depending on the field) when creating the system. The base class is then used, which provides no-op behaviour.

Component                        Deactivation value   Effect
collider_type                    ""                   No pairwise force computation; forces/torques are zeroed.
linear_integrator_type           ""                   No position/velocity updates.
rotation_integrator_type         ""                   No orientation/angular-velocity updates.
bonded_force_model_type          None (default)       No bonded forces.
force_manager_kw["gravity"]      None (default)       No gravitational acceleration.

Note: the domain (domain_type) and force model (force_model_type) cannot be deactivated; a valid type must always be provided.

# No collisions, no integration — a "frozen" system:
system_frozen = jdem.System.create(
    state.shape,
    collider_type="",
    linear_integrator_type="",
    rotation_integrator_type="",
)
print("Collider:", type(system_frozen.collider).__name__)
print("Integrator:", type(system_frozen.linear_integrator).__name__)
Collider: Collider
Integrator: LinearIntegrator
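The pattern behind this table, where an empty type string selects a base class whose methods do nothing, can be sketched with a small registry. Everything here (`BaseIntegrator`, `EulerIntegrator`, `REGISTRY`, `create_integrator`) is hypothetical and only illustrates the idea; it is not jaxdem's factory code.

```python
import jax.numpy as jnp


class BaseIntegrator:
    """Hypothetical base class: its step is a no-op."""

    @staticmethod
    def step(pos, vel, dt):
        return pos, vel


class EulerIntegrator(BaseIntegrator):
    """Hypothetical concrete integrator that actually moves particles."""

    @staticmethod
    def step(pos, vel, dt):
        return pos + vel * dt, vel


# Hypothetical registry: the empty string selects the no-op base class.
REGISTRY = {"": BaseIntegrator, "euler": EulerIntegrator}


def create_integrator(type_name):
    return REGISTRY[type_name]()


pos, vel = jnp.zeros((1, 2)), jnp.ones((1, 2))
frozen_pos, _ = create_integrator("").step(pos, vel, 0.1)  # unchanged
moved_pos, _ = create_integrator("euler").step(pos, vel, 0.1)
```

Keeping a no-op base class, rather than skipping the call, means the step function has the same structure whether a component is active or not.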

Random Number Generation

create() accepts a seed (integer) or a key (jax.random.PRNGKey()) that initialises the system’s JAX PRNG state. The key is stored in system.key and is available for stochastic integrators or custom force functions.

system_rng = jdem.System.create(state.shape, seed=42)
print("PRNG key:", system_rng.key)
PRNG key: [ 0 42]
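How a stochastic component consumes system.key is not described here, so the following is an assumed pattern based on the standard JAX key-splitting idiom: split the stored key, use one half for sampling, and store the other half so that the next call draws fresh numbers. `ToySystem` and `random_kick` are hypothetical names.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySystem(NamedTuple):
    """Hypothetical stand-in holding a PRNG key, like system.key."""

    key: jax.Array


def random_kick(vel, system, scale=0.1):
    # Split the stored key: subkey is consumed now, key is saved for next time.
    key, subkey = jax.random.split(system.key)
    noise = scale * jax.random.normal(subkey, vel.shape)
    return vel + noise, system._replace(key=key)


system = ToySystem(key=jax.random.PRNGKey(42))
vel = jnp.zeros((1, 2))
vel1, system = random_kick(vel, system)
vel2, system = random_kick(vel, system)  # different noise: the key was advanced
```

Returning the updated key, instead of reusing the original one, is what keeps successive calls statistically independent.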
