Note
Go to the end to download the full example code.
Custom Module Registration and System Usage.#
JaxDEM components are created through factory registries. This makes it easy
to add custom modules and use them with create().
To register a new module, the pattern is always the same:
1. Inherit from the corresponding base class (ForceModel, Domain, Collider, LinearIntegrator, etc.).
2. Implement the required abstract/interface methods for that base class.
3. Register the class with @<Base>.register("your_key").
This guide shows how to:
- define and register custom modules,
- instantiate them via *_type and *_kw arguments, and
- pass pre-built module objects directly to a System.
Setup#
We will create several custom modules:
- a custom ForceModel
- a custom Domain
- a custom Collider
- custom linear and rotation integrators
from dataclasses import dataclass
from functools import partial
from typing import cast
import jax
import jax.numpy as jnp
import jaxdem as jdem
Custom Force Model#
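This force law applies a simple linear attraction between particle pairs:
$\mathbf{F}_{ij} = -k\,\mathbf{r}_{ij}$, with pair energy $E_{ij} = \tfrac{1}{2} k\,\lVert\mathbf{r}_{ij}\rVert^{2}$,
where $\mathbf{r}_{ij}$ is the displacement between particles $i$ and $j$ as computed by the domain and $k$ is the stiffness.
It does not require any material-table properties.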
Registration reminder:
- inherit from ForceModel
- implement force and energy
@jdem.ForceModel.register("pairattractor")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class PairAttractor(jdem.ForceModel):
    """Linear pair attraction, registered under the key ``"pairattractor"``.

    For a pair ``(i, j)`` the force is ``-k * r_ij`` and the energy is
    ``0.5 * k * |r_ij|**2``, where ``r_ij`` is whatever
    ``system.domain.displacement(pos[i], pos[j], system)`` returns.
    Both quantities are multiplied by ``i != j`` so self-pairs contribute
    nothing.
    """

    # Stiffness of the attraction; scales both the force and the energy.
    k: float = 0.1

    @staticmethod
    @partial(jax.jit, inline=True)
    def force(
        i: int, j: int, pos: jax.Array, state: jdem.State, system: jdem.System
    ) -> tuple[jax.Array, jax.Array]:
        """Return ``(force, torque)`` for the pair ``(i, j)``.

        The torque is always zero and takes its shape from ``state.torque[j]``.
        """
        # The parameters live on the model stored in the system; the cast
        # only narrows the type for static checkers.
        params = cast(PairAttractor, system.force_model)
        separation = system.domain.displacement(pos[i], pos[j], system)
        # 1 for distinct particles, 0 for a self-pair, broadcast over the
        # trailing (spatial) axis of the separation.
        not_self = jnp.asarray(i != j, dtype=separation.dtype)[..., None]
        pair_force = -params.k * separation * not_self
        return pair_force, jnp.zeros_like(state.torque[j])

    @staticmethod
    @partial(jax.jit, inline=True)
    def energy(
        i: int, j: int, pos: jax.Array, state: jdem.State, system: jdem.System
    ) -> jax.Array:
        """Return the harmonic pair energy for ``(i, j)``; zero when ``i == j``."""
        params = cast(PairAttractor, system.force_model)
        separation = system.domain.displacement(pos[i], pos[j], system)
        squared_distance = jnp.sum(separation * separation, axis=-1)
        return 0.5 * params.k * squared_distance * (i != j)
# The register decorator should have added the key to the ForceModel registry.
force_registry = jdem.ForceModel._registry
print(
    "ForceModel registry contains pairattractor:",
    "pairattractor" in force_registry,
)
print("Registered ForceModels:", list(force_registry.keys()))

# Two particles two units apart along x.
state = jdem.State.create(
    rad=jnp.array([0.4, 0.4]),
    pos=jnp.array([[0.0, 0.0], [2.0, 0.0]]),
)

# Select the custom force model by key and override its stiffness.
system = jdem.System.create(
    state.shape,
    force_model_kw=dict(k=0.2),
    force_model_type="pairattractor",
)

state, system = system.step(state, system, n=5)
force_model_name = type(system.force_model).__name__
print("Custom force model:", force_model_name)
print("Positions after 5 steps:\n", state.pos)
ForceModel registry contains pairattractor: True
Registered ForceModels: ['lawcombiner', 'forcerouter', 'spring', 'wca', 'lennardjones', 'wca_shifted', 'hertz', 'cundallstrack', 'pairattractor']
Custom force model: PairAttractor
Positions after 5 steps:
[[9.99985e-05 0.00000e+00]
[1.99990e+00 0.00000e+00]]
Custom Domain#
This domain recenters particles every step by subtracting the mean particle
position (the center of mass when all masses are equal), keeping the system
centered on the origin. It reuses the default Create() for box_size/anchor.
Registration reminder:
- inherit from Domain
- implement the relevant interface methods (here: apply)
@jdem.Domain.register("centered")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class CenteredDomain(jdem.Domain):
    """Domain registered as ``"centered"`` that shifts particles by their mean position."""

    @staticmethod
    @partial(jax.jit, inline=True)
    def apply(state: jdem.State, system: jdem.System) -> tuple[jdem.State, jdem.System]:
        """Subtract the mean of ``state.pos`` from ``state.pos_c``.

        The mean is unweighted and taken over the second-to-last axis of
        ``state.pos``. The system is returned unchanged.
        """
        mean_position = jnp.mean(state.pos, axis=-2)
        state.pos_c = state.pos_c - mean_position
        return state, system
# The register decorator should have added the key to the Domain registry.
domain_registry = jdem.Domain._registry
print("Domain registry contains centered:", "centered" in domain_registry)
print("Registered Domains:", list(domain_registry.keys()))

# Two particles offset from the origin, both moving in +x.
state = jdem.State.create(
    vel=jnp.array([[1.0, 0.0], [1.0, 0.0]]),
    pos=jnp.array([[2.0, 0.0], [4.0, 0.0]]),
)

# Combine the custom domain with the custom force model defined earlier.
system = jdem.System.create(
    state.shape,
    force_model_type="pairattractor",
    domain_type="centered",
)

state, system = system.step(state, system, n=3)
mean_position = jnp.mean(state.pos, axis=0)
print("Custom domain:", type(system.domain).__name__)
print("Mean position after centering:", mean_position)
Domain registry contains centered: True
Registered Domains: ['free', 'periodic', 'reflect', 'reflectsphere', 'centered']
Custom domain: CenteredDomain
Mean position after centering: [0.005 0. ]
Custom Collider#
This collider disables all pair contacts by forcing zero force and torque. In reality, this is the same as passing no collider to the system object, but it serves as a simple example of a custom collider.
Registration reminder:
- inherit from Collider
- implement interface methods (at minimum compute_force, and for full compatibility also compute_potential_energy and create_neighbor_list)
@jdem.Collider.register("nocontact")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class NoContactCollider(jdem.Collider):
    """Collider registered as ``"nocontact"`` that never produces contacts.

    Forces and torques are multiplied by zero, the potential energy is zero
    for every particle, and the neighbor list contains only ``-1`` entries.
    """

    @staticmethod
    @partial(jax.jit, inline=True)
    def compute_force(
        state: jdem.State, system: jdem.System
    ) -> tuple[jdem.State, jdem.System]:
        """Zero ``state.force`` and ``state.torque`` by multiplying them by 0."""
        state.force *= 0
        state.torque *= 0
        return state, system

    @staticmethod
    @partial(jax.jit, inline=True)
    def compute_potential_energy(state: jdem.State, system: jdem.System) -> jax.Array:
        """Return an array of zeros shaped like ``state.mass``."""
        return jnp.zeros_like(state.mass)

    @staticmethod
    # partial(jax.jit, ...) matches the other jit decorators in this file.
    # It also avoids calling jax.jit with keyword arguments only, which
    # (to our knowledge) only recent JAX releases accept.
    @partial(jax.jit, static_argnames=("max_neighbors",))
    def create_neighbor_list(
        state: jdem.State,
        system: jdem.System,
        cutoff: float,
        max_neighbors: int,
    ) -> tuple[jdem.State, jdem.System, jax.Array, jax.Array]:
        """Return an all ``-1`` neighbor list and a ``False`` overflow flag.

        ``max_neighbors`` is a static jit argument because it sets the output
        shape ``(state.N, max_neighbors)``. ``cutoff`` is accepted to match
        the signature but is not used.
        """
        del cutoff  # intentionally unused: no particle ever has neighbors
        nl = -jnp.ones((state.N, max_neighbors), dtype=int)
        overflow = jnp.asarray(False)
        return state, system, nl, overflow
# The register decorator should have added the key to the Collider registry.
collider_registry = jdem.Collider._registry
print("Collider registry contains nocontact:", "nocontact" in collider_registry)
print("Registered Colliders:", list(collider_registry.keys()))

# Two particles approaching each other; radii 0.6 at unit separation.
state = jdem.State.create(
    rad=jnp.array([0.6, 0.6]),
    vel=jnp.array([[0.5, 0.0], [-0.5, 0.0]]),
    pos=jnp.array([[0.0, 0.0], [1.0, 0.0]]),
)

# Select the custom collider by key alongside the custom force model.
system = jdem.System.create(
    state.shape,
    force_model_type="pairattractor",
    collider_type="nocontact",
)

state, system = system.step(state, system, n=2)
collider_name = type(system.collider).__name__
print("Custom collider:", collider_name)
print("Forces with nocontact collider:\n", state.force)
Collider registry contains nocontact: True
Registered Colliders: ['', 'naive', 'staticcelllist', 'celllist', 'neighborlist', 'sap', 'nocontact']
Custom collider: NoContactCollider
Forces with nocontact collider:
[[0. 0.]
[0. 0.]]
Custom Integrators#
Here we register:
- DampedEuler for linear motion,
- FrozenRotation for angular motion.
Registration reminder:
- inherit from LinearIntegrator or RotationIntegrator
- implement the needed step methods (here: step_after_force)
@jdem.LinearIntegrator.register("dampedeuler")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class DampedEuler(jdem.LinearIntegrator):
    """Euler-style linear integrator with velocity damping, key ``"dampedeuler"``."""

    # Damping coefficient; each step scales the velocity by (1 - dt * damping).
    damping: float = 0.2

    @staticmethod
    @partial(jax.jit, inline=True)
    def step_after_force(
        state: jdem.State, system: jdem.System
    ) -> tuple[jdem.State, jdem.System]:
        """Update velocities from the current forces, then advance positions.

        Velocities are damped, accelerated by ``force / mass`` and masked by
        ``1 - state.fixed``. ``state.pos_c`` is then advanced with the new
        velocity.
        """
        dt = system.dt
        # The damping value lives on the integrator stored in the system;
        # the cast only narrows the type for static checkers.
        settings = cast(DampedEuler, system.linear_integrator)
        decay = 1.0 - dt * settings.damping

        # 1 where state.fixed is 0, 0 where it is 1; trailing axis added for
        # broadcasting against per-component arrays.
        mobile = (1 - state.fixed)[..., None]
        acceleration = state.force / state.mass[..., None]

        state.vel = decay * state.vel + dt * acceleration * mobile
        state.vel = state.vel * mobile
        state.pos_c = state.pos_c + dt * state.vel
        return state, system
@jdem.RotationIntegrator.register("frozenrotation")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class FrozenRotation(jdem.RotationIntegrator):
    """Rotation integrator registered as ``"frozenrotation"`` that zeroes angular velocity."""

    @staticmethod
    @partial(jax.jit, inline=True)
    def step_after_force(
        state: jdem.State, system: jdem.System
    ) -> tuple[jdem.State, jdem.System]:
        """Multiply ``state.ang_vel`` by zero and return the state and system."""
        state.ang_vel = state.ang_vel * 0
        return state, system
# Both integrator keys should now be present in their registries.
linear_registry = jdem.LinearIntegrator._registry
rotation_registry = jdem.RotationIntegrator._registry
print(
    "LinearIntegrator registry contains dampedeuler:",
    "dampedeuler" in linear_registry,
)
print("Registered LinearIntegrators:", list(linear_registry.keys()))
print(
    "RotationIntegrator registry contains frozenrotation:",
    "frozenrotation" in rotation_registry,
)
print("Registered RotationIntegrators:", list(rotation_registry.keys()))

# Two particles two units apart along x.
state = jdem.State.create(
    rad=jnp.array([0.4, 0.4]),
    pos=jnp.array([[0.0, 0.0], [2.0, 0.0]]),
)

# Select both custom integrators by key; override the linear damping.
system = jdem.System.create(
    state.shape,
    force_model_type="pairattractor",
    rotation_integrator_type="frozenrotation",
    linear_integrator_type="dampedeuler",
    linear_integrator_kw=dict(damping=0.5),
)

state, system = system.step(state, system, n=10)
for label, module in (
    ("Custom linear integrator:", system.linear_integrator),
    ("Custom rotation integrator:", system.rotation_integrator),
):
    print(label, type(module).__name__)
print("Velocity after damping:\n", state.vel)
LinearIntegrator registry contains dampedeuler: True
Registered LinearIntegrators: ['', 'euler', 'langevin', 'verlet', 'verlet_rescaling', 'vicsek_extrinsic', 'vicsek_intrinsic', 'lineargradientdescent', 'linearfire', 'optax', 'dampedeuler']
RotationIntegrator registry contains frozenrotation: True
Registered RotationIntegrators: ['', 'spiral', 'verletspiral', 'rotationgradientdescent', 'rotationfire', 'optax', 'frozenrotation']
Custom linear integrator: DampedEuler
Custom rotation integrator: FrozenRotation
Velocity after damping:
[[ 0.00988743 0. ]
[-0.00988743 0. ]]
Passing Module Objects Directly#
create() is convenient for factory-based construction. You can also
build a base system and swap modules directly with pre-built objects.
# Start from a default system, then replace each module with a pre-built
# instance of the custom classes defined above.
state = jdem.State.create(
    rad=jnp.array([0.4, 0.4]),
    pos=jnp.array([[0.0, 0.0], [1.5, 0.0]]),
)
system = jdem.System.create(state.shape)

replacements = (
    ("force_model", PairAttractor(k=0.05)),
    ("domain", CenteredDomain.Create(dim=state.dim)),
    ("collider", NoContactCollider()),
    ("linear_integrator", DampedEuler(damping=0.8)),
    ("rotation_integrator", FrozenRotation()),
)
for attribute, module in replacements:
    setattr(system, attribute, module)

state, system = system.step(state, system, n=3)

# Report the class actually attached to each slot after stepping.
for label, module in (
    ("Directly assigned force model:", system.force_model),
    ("Directly assigned domain:", system.domain),
    ("Directly assigned collider:", system.collider),
    ("Directly assigned linear integrator:", system.linear_integrator),
    ("Directly assigned rotation integrator:", system.rotation_integrator),
):
    print(label, type(module).__name__)
Directly assigned force model: PairAttractor
Directly assigned domain: CenteredDomain
Directly assigned collider: NoContactCollider
Directly assigned linear integrator: DampedEuler
Directly assigned rotation integrator: FrozenRotation
Notes on Registration#
- Registration keys are case-insensitive.
- Registrations are process-local: define/register your custom classes before calling create() with the corresponding *_type.
- All custom modules should be JAX pytrees; using @jax.tree_util.register_dataclass on dataclasses is the recommended path.
- The “proof” that registration worked is that the key appears in the relevant registry dictionary.
# Final check: every custom key must be present in its registry.
for label, key, registry in (
    (
        "Registered custom force model key exists:",
        "pairattractor",
        jdem.ForceModel._registry,
    ),
    ("Registered custom domain key exists:", "centered", jdem.Domain._registry),
    ("Registered custom collider key exists:", "nocontact", jdem.Collider._registry),
    (
        "Registered custom linear integrator key exists:",
        "dampedeuler",
        jdem.LinearIntegrator._registry,
    ),
    (
        "Registered custom rotation integrator key exists:",
        "frozenrotation",
        jdem.RotationIntegrator._registry,
    ),
):
    print(label, key in registry)
Registered custom force model key exists: True
Registered custom domain key exists: True
Registered custom collider key exists: True
Registered custom linear integrator key exists: True
Registered custom rotation integrator key exists: True
Total running time of the script: (0 minutes 1.864 seconds)