Note
Go to the end to download the full example code.
Temperature and density control via the modern stack.
This script runs the same physical scenario twice to demonstrate the two orthogonal control axes available in JaxDEM today:
Temperature control is a choice of integrator. Picking
linear_integrator_type="verlet_rescaling"atSystem.createtime clamps the kinetic temperature to a fixed value on a configurable cadence; picking"verlet"lets temperature float.Density / box control is a protocol on top of whatever integrator you chose.
run_packing_fraction_protocol()integrates viasystem.stepin chunks and callsscale_to_packing_fraction()on a user-supplied per-frame schedule.
We impose a single-period sinusoidal modulation of the packing fraction and compare:
bare Verlet + phi modulation — temperature rises on compression, falls on expansion, as expected.
velocity-rescaling Verlet + phi modulation — temperature stays clamped while phi oscillates.
Imports
import jax
import jax.numpy as jnp
import numpy as np
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call]
import jaxdem as jd
from jaxdem.utils.randomSphereConfiguration import random_sphere_configuration
from jaxdem.utils.dynamicsRoutines import run_packing_fraction_protocol
from jaxdem.utils.packingUtils import compute_packing_fraction
from jaxdem.utils.thermal import (
compute_potential_energy,
compute_temperature,
set_temperature,
)
Parameters
N = 100
phi0 = 0.70
dim = 2
dt = 1e-2
initial_temperature = 1.0
phi_amplitude = -0.05 # phi(t) = phi0 + amp * sin(2*pi*t), one full period
n_steps = 10_000
save_stride = 100 # one saved frame per 100 integration steps
n_frames = n_steps // save_stride
can_rotate = False # smooth spheres, no rotation DOF
subtract_drift = True
seed = 0
Precompute the phi schedule#
One phi value per saved frame. Each frame covers save_stride steps,
after which the box is rescaled to the frame’s target phi.
t_frac = (1 + np.arange(n_frames)) / n_frames # (0, 1]
phi_at_frames = phi0 + phi_amplitude * np.sin(2.0 * np.pi * t_frac)
strides = np.full(n_frames, save_stride, dtype=int)
System builder — two variants differ only in the linear integrator choice.
def build_state_system(linear_integrator_type, linear_integrator_kw=None):
particle_radii = jd.utils.dispersity.get_polydisperse_radii(N, [1.0], [1.0])
pos, box_size = random_sphere_configuration(
particle_radii.tolist(), phi0, dim, seed
)
state = jd.State.create(pos=pos, rad=particle_radii, mass=jnp.ones(N))
state = set_temperature(
state,
initial_temperature,
can_rotate=can_rotate,
subtract_drift=subtract_drift,
seed=seed,
)
mats = [jd.Material.create("elastic", young=1.0, poisson=0.5, density=1.0)]
mat_table = jd.MaterialTable.from_materials(
mats, matcher=jd.MaterialMatchmaker.create("harmonic")
)
system = jd.System.create(
state_shape=state.shape,
dt=dt,
linear_integrator_type=linear_integrator_type,
linear_integrator_kw=linear_integrator_kw or {},
rotation_integrator_type="",
domain_type="periodic",
force_model_type="spring",
collider_type="naive",
mat_table=mat_table,
domain_kw={"box_size": box_size},
)
return state, system
def summarize(label, traj_state, traj_system):
temperature = jax.vmap(
lambda s: compute_temperature(
s, can_rotate=can_rotate, subtract_drift=subtract_drift
)
)(traj_state)
phi_series = jax.vmap(compute_packing_fraction)(traj_state, traj_system)
pe = jax.vmap(compute_potential_energy)(traj_state, traj_system)
T_arr = np.asarray(temperature)
phi_arr = np.asarray(phi_series)
pe_arr = np.asarray(pe)
print(
f"[{label}] T: min={T_arr.min():.3e} max={T_arr.max():.3e} mean={T_arr.mean():.3e} "
f"phi: min={phi_arr.min():.4f} max={phi_arr.max():.4f} "
f"PE mean={pe_arr.mean():.3e}"
)
Bare Verlet + phi modulation: temperature is free to fluctuate.
state, system = build_state_system("verlet")
state, system, (traj_state, traj_system) = run_packing_fraction_protocol(
state,
system,
strides=strides,
phi_at_frames=phi_at_frames,
)
summarize("bare verlet", traj_state, traj_system)
[bare verlet] T: min=8.732e-01 max=9.332e-01 mean=9.066e-01 phi: min=0.6500 max=0.7500 PE mean=9.345e+00
verlet_rescaling thermostat + phi modulation: temperature clamped.
state, system = build_state_system(
"verlet_rescaling",
linear_integrator_kw=dict(
temperature=jnp.asarray(initial_temperature, dtype=float),
rescale_every=jnp.asarray(50, dtype=int),
can_rotate=jnp.asarray(int(can_rotate), dtype=int),
subtract_drift=jnp.asarray(int(subtract_drift), dtype=int),
k_B=jnp.asarray(1.0, dtype=float),
),
)
state, system, (traj_state, traj_system) = run_packing_fraction_protocol(
state,
system,
strides=strides,
phi_at_frames=phi_at_frames,
)
summarize("verlet_rescaling", traj_state, traj_system)
[verlet_rescaling] T: min=1.000e+00 max=1.000e+00 mean=1.000e+00 phi: min=0.6500 max=0.7500 PE mean=9.767e+00
Total running time of the script: (0 minutes 8.194 seconds)