Colliders
A Collider is the component that detects interacting particle pairs and evaluates the ForceModel for each pair. Different colliders implement different spatial-search strategies, trading generality for speed.
This guide covers:
- The available collider implementations and when to use each one.
- How to configure a collider via collider_type/collider_kw.
- How the collider interacts with force models and the force manager.
- Computing potential energy through the collider.
- Neighbor-list creation for diagnostics and caching.
Selecting a Collider
The collider is chosen via collider_type when creating a System. The default is "naive".
import jax.numpy as jnp
import jaxdem as jdem
state = jdem.State.create(
pos=jnp.array([[0.0, 0.0], [1.5, 0.0], [3.0, 0.0]]),
rad=jnp.array([1.0, 1.0, 1.0]),
)
system = jdem.System.create(state.shape, collider_type="naive")
print("Collider:", type(system.collider).__name__)
Collider: NaiveSimulator
Available Colliders
JaxDEM provides several collider implementations registered in the Collider factory:

| Class | Complexity | Best for |
|---|---|---|
| NaiveSimulator | \(O(N^2)\) | Small systems (< 1k–4k particles) |
| DynamicCellList | \(O(N \log N)\) | Systems with low to moderate polydispersity, and clumps |
| DynamicMultiCellList | \(O(N \cdot \mathrm{max\_hashes} \log (N \cdot \mathrm{max\_hashes}))\) | Highly polydisperse systems (wide size distributions) |
| SweepAndPrune | \(O(2^{\mathrm{dim}-1} \cdot N \log N)\) | Systems with low to moderate polydispersity; cost grows linearly with density (number of overlaps). Fastest neighbor-list builder. |
| NeighborList | \(O(N)\) amortised | Large systems with infrequent neighbor-list rebuilds |
The registered colliders are:
print("Colliders:", list(jdem.Collider._registry.keys()))
Colliders: ['', 'celllist', 'multicelllist', 'naive', 'neighborlist', 'sweepandprune']
The Naive Collider
The NaiveSimulator evaluates the force model for every pair \((i, j)\), giving \(O(N^2)\) complexity. It requires no configuration and is the default.
For small systems it is usually the fastest option because it has no spatial-search overhead, but its cost becomes prohibitive as \(N\) grows.
system_naive = jdem.System.create(state.shape, collider_type="naive")
state_out, system_out = system_naive.step(state, system_naive)
print("Forces after one step:\n", state_out.force)
Forces after one step:
[[-5000. 0.]
[ 0. 0.]
[ 5000. 0.]]
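To make the \(O(N^2)\) cost concrete, here is a minimal pure-Python sketch of an all-pairs linear-spring contact loop. It is not JaxDEM's implementation, and the helper name naive_spring_forces is hypothetical. The stiffness k = 1e4 is an assumption, inferred from the printed forces above (a force of 5000 for an overlap of 0.5).

```python
import math


def naive_spring_forces(pos, rad, k):
    """All-pairs (O(N^2)) linear-spring contact forces for 2D discs."""
    n = len(pos)
    forces = [[0.0, 0.0] for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):  # every pair (i, j) with i < j is tested
            dx = pos[j][0] - pos[i][0]
            dy = pos[j][1] - pos[i][1]
            dist = math.hypot(dx, dy)
            overlap = rad[i] + rad[j] - dist
            if overlap > 0.0 and dist > 0.0:
                # Unit normal from i to j; the repulsion pushes i away from j.
                nx, ny = dx / dist, dy / dist
                f = k * overlap
                forces[i][0] -= f * nx
                forces[i][1] -= f * ny
                forces[j][0] += f * nx
                forces[j][1] += f * ny
    return forces


# Same configuration as above; k = 1e4 is an assumed stiffness.
forces = naive_spring_forces(
    [(0.0, 0.0), (1.5, 0.0), (3.0, 0.0)], [1.0, 1.0, 1.0], k=1.0e4
)
print(forces)  # [[-5000.0, 0.0], [0.0, 0.0], [5000.0, 0.0]]
```

The nested loop performs \(N(N-1)/2\) pair tests (3 here). The middle particle's two contacts cancel, which is why its net force is zero.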
The Cell List Collider
DynamicCellList (registered as "CellList") partitions space into a regular grid. Only particles in the same or neighboring cells are tested for contact. Because it uses an implicit infinite grid, it works for all domain types (periodic, free, etc.).
It probes each cell with a jax.lax.while_loop, which makes it robust to high or variable cell occupancy and therefore a good fit for polydisperse systems and clumps.
Key parameters (all have automatic defaults):
- cell_size: edge length of each grid cell.
- box_size: domain size. Optional; only needed when the box is small compared with the cell size, to ensure the periodic wrap stencil has the correct dimensions.
Important: cell-list colliders sort/reorder the state internally for traversal performance. The returned state follows that sorted ordering.
state_p = jdem.State.create(
pos=jnp.array([[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]]),
rad=jnp.array([0.5, 0.5, 0.5]),
)
system_cl = jdem.System.create(
state_p.shape,
collider_type="CellList",
collider_kw={"state": state_p},
)
print("Cell size:", getattr(system_cl.collider, "cell_size", "n/a"))
Cell size: 1.0
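The cell-list idea can be sketched in pure Python with a hash map from integer cell coordinates to particle indices. This is a swapped-in technique for illustration only: it does not reproduce the sorted, while_loop-based traversal described above, it ignores periodic wrapping, and it assumes cell_size is at least the largest particle diameter. The function names are hypothetical.

```python
import math
from collections import defaultdict
from itertools import product


def cell_of(p, cell_size):
    """Integer coordinates of the cell containing point p (unbounded grid)."""
    return tuple(math.floor(c / cell_size) for c in p)


def cell_list_pairs(pos, rad, cell_size):
    """Contact pairs found by checking only the home cell and its neighbors.

    Valid when cell_size >= the largest particle diameter, so touching
    particles always sit in the same or adjacent cells.
    """
    cells = defaultdict(list)  # cell coordinates -> particle indices
    for i, p in enumerate(pos):
        cells[cell_of(p, cell_size)].append(i)

    dim = len(pos[0])
    pairs = set()
    for i, p in enumerate(pos):
        home = cell_of(p, cell_size)
        for offset in product((-1, 0, 1), repeat=dim):  # 3**dim stencil
            neighbor_cell = tuple(h + o for h, o in zip(home, offset))
            for j in cells.get(neighbor_cell, ()):
                if j > i and math.dist(p, pos[j]) < rad[i] + rad[j]:
                    pairs.add((i, j))
    return pairs


# The three well-separated particles of state_p: no contacts.
print(cell_list_pairs([(1.0, 1.0), (3.0, 3.0), (5.0, 5.0)], [0.5, 0.5, 0.5], 1.0))  # set()
```

Each particle only looks at \(3^{\mathrm{dim}}\) cells, so the work per particle depends on local occupancy rather than on \(N\).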
The Multi-Cell List Collider
DynamicMultiCellList (registered as "MultiCellList") partitions space into a regular grid with cell edge length cell_size. In a standard cell list the cell size must be at least the largest particle diameter; the multi-cell list removes that bound by letting a particle register in several cells (up to max_hashes).
This makes it well suited to systems with extreme polydispersity: a few large particles no longer force a large cell size on all the small ones.
Key parameters (all have automatic defaults):
- cell_size: edge length of each grid cell. If None, it defaults to the median particle diameter.
- max_hashes: maximum number of cells a single particle is allowed to overlap.
Like the standard cell list, it sorts/reorders the state internally for performance.
system_mcl = jdem.System.create(
state_p.shape,
collider_type="MultiCellList",
collider_kw={"state": state_p},
)
print("Multi-Cell List cell size:", getattr(system_mcl.collider, "cell_size", "n/a"))
print("Multi-Cell List max hashes:", getattr(system_mcl.collider, "max_hashes", "n/a"))
Multi-Cell List cell size: 1.0
Multi-Cell List max hashes: 4
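Registering one particle in several cells can be sketched in pure Python: each particle is hashed into every cell its bounding box touches, and candidate pairs are particles that share a cell. This is an illustration, not JaxDEM's code; the function names are hypothetical, and raising an error when a particle exceeds max_hashes is a choice made for this sketch, not documented library behaviour.

```python
import math
from collections import defaultdict
from itertools import product


def overlapped_cells(p, r, cell_size):
    """Every cell touched by the axis-aligned bounding box of a disc/sphere."""
    lo = [math.floor((c - r) / cell_size) for c in p]
    hi = [math.floor((c + r) / cell_size) for c in p]
    return list(product(*(range(a, b + 1) for a, b in zip(lo, hi))))


def multi_cell_pairs(pos, rad, cell_size, max_hashes):
    cells = defaultdict(list)  # cell coordinates -> particles registered there
    for i, (p, r) in enumerate(zip(pos, rad)):
        hashes = overlapped_cells(p, r, cell_size)
        if len(hashes) > max_hashes:
            # Sketch-only policy: fail loudly instead of silently dropping cells.
            raise ValueError(f"particle {i} touches {len(hashes)} cells > {max_hashes}")
        for h in hashes:
            cells[h].append(i)

    # Overlapping particles have intersecting bounding boxes, so they share
    # at least one cell; the set removes pairs found in several shared cells.
    pairs = set()
    for members in cells.values():
        for a in range(len(members)):
            for b in range(a + 1, len(members)):
                i, j = sorted((members[a], members[b]))
                if math.dist(pos[i], pos[j]) < rad[i] + rad[j]:
                    pairs.add((i, j))
    return pairs


# One large particle among small ones; cell_size = 1.0 is chosen by hand.
pos = [(1.5, 1.5), (2.7, 1.5), (5.0, 5.0), (5.3, 5.0)]
rad = [1.0, 0.25, 0.25, 0.25]
print(len(overlapped_cells(pos[0], rad[0], 1.0)))  # 9
print(sorted(multi_cell_pairs(pos, rad, cell_size=1.0, max_hashes=9)))  # [(0, 1), (2, 3)]
```

The large particle occupies 9 cells while each small one occupies at most 4, so the grid stays fine enough for the small particles.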
The Sweep and Prune Collider
SweepAndPrune (registered as "SweepAndPrune") projects and sorts particles along one axis. It then performs a 1D sweep to find overlapping candidates, shifting the perpendicular coordinates on the fly to support periodic boundary conditions.
It is best suited to low-density systems, sheared flows, or systems with few overlaps, where sorting along an axis is highly efficient.
Key parameters:
- K: static search-window radius (number of sorted neighbors to check in each direction).
system_sap = jdem.System.create(
state_p.shape,
collider_type="SweepAndPrune",
collider_kw={"state": state_p},
)
print("Sweep and Prune K window size:", getattr(system_sap.collider, "K", "n/a"))
Sweep and Prune K window size: 6
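A pure-Python sketch of a 1D sweep with a static window of K sorted neighbors follows. It assumes a single sweep axis and no periodic shift (the collider described above also handles periodic boundaries), and the function name is hypothetical. Scanning only forward from each particle visits every in-window pair once.

```python
import math


def sweep_and_prune_pairs(pos, rad, K, axis=0):
    """Contact pairs from a 1D sweep with a static window of K sorted neighbors."""
    order = sorted(range(len(pos)), key=lambda i: pos[i][axis])
    pairs = set()
    for a, i in enumerate(order):
        # Only the next K particles in sorted order are examined.
        for j in order[a + 1 : a + 1 + K]:
            # Prune: the projections onto the sweep axis must overlap.
            if pos[j][axis] - pos[i][axis] >= rad[i] + rad[j]:
                continue
            if math.dist(pos[i], pos[j]) < rad[i] + rad[j]:
                pairs.add((min(i, j), max(i, j)))
    return pairs


# Four mutually overlapping discs in a row: 6 contacts in total.
cluster = [(0.0, 0.0), (0.1, 0.0), (0.2, 0.0), (0.3, 0.0)]
radii = [0.5] * 4
print(len(sweep_and_prune_pairs(cluster, radii, K=3)))  # 6
print(len(sweep_and_prune_pairs(cluster, radii, K=1)))  # 3, window too small
```

Because the window is static, contacts between particles that are more than K positions apart in sorted order are missed, which is why dense, heavily overlapping systems need a larger K.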
Neighbor-list creation for all colliders
Every collider implements create_neighbor_list(). This is useful both for diagnostics and for algorithms that need explicit neighbors.
The call returns a 4-tuple; the example below discards the first two entries. The last two entries are:
- neighbor_list: an array of shape (N, max_neighbors), padded with -1.
- overflow: a flag that is True if any particle had more than max_neighbors neighbors within the requested cutoff.
Note
Verifying Neighbor List Capacity with the Overflow Flag
max_neighbors is a static, user-provided buffer size, because JAX needs array shapes at compile time. The returned overflow flag is therefore how you verify that the buffer was large enough. If overflow is True, at least one particle has more neighbors than max_neighbors, so some interactions are truncated. In that case, increase max_neighbors to keep the simulation physically correct.
Example with a regular collider (here: Cell List):
_, _, nl_cl, overflow_cl = system_cl.collider.create_neighbor_list(
state_p, system_cl, cutoff=2.0, max_neighbors=8
)
print("Cell-list neighbor list shape:", nl_cl.shape)
print("Cell-list overflow:", bool(overflow_cl))
Cell-list neighbor list shape: (3, 8)
Cell-list overflow: False
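The padded layout and the overflow rule can be sketched in pure Python. The helper name is hypothetical, and recording each pair for both particles is an assumption of this sketch rather than documented behaviour; the point is that the true neighbor count keeps growing past the buffer, so overflow can be detected even though the extra neighbors are dropped.

```python
def padded_neighbor_list(n, pairs, max_neighbors):
    """Fixed-width neighbor table padded with -1, plus an overflow flag."""
    table = [[-1] * max_neighbors for _ in range(n)]
    counts = [0] * n  # true neighbor count, even beyond the buffer
    for i, j in sorted(pairs):
        for a, b in ((i, j), (j, i)):  # record the pair for both particles
            if counts[a] < max_neighbors:
                table[a][counts[a]] = b
            counts[a] += 1
    overflow = any(c > max_neighbors for c in counts)
    return table, overflow


contacts = {(0, 1), (1, 2)}  # particle 1 has two neighbors
print(padded_neighbor_list(3, contacts, max_neighbors=2))  # ([[1, -1], [0, 2], [1, -1]], False)
print(padded_neighbor_list(3, contacts, max_neighbors=1))  # ([[1], [0], [1]], True)
```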
The Neighbor List Collider
NeighborList caches a per-particle list of neighbors built with a secondary collider (by default, the cell list). The list is rebuilt only when some particle has moved more than skin / 2 since the last build. Between rebuilds, the cost is \(O(N)\).
Key parameters:
- cutoff: physical interaction radius.
- skin: buffer distance. Must be > 0 for performance; with skin = 0, any motion at all triggers a rebuild.
- max_neighbors: buffer size per particle (auto-estimated if omitted).
- secondary_collider_type: any registered collider except another "NeighborList".
- secondary_collider_kw: keyword arguments passed to the secondary collider.
This design works because every collider exposes create_neighbor_list. Wrapping a NeighborList in another NeighborList is not meaningful and should be avoided.
When a rebuild occurs, the particle ordering may change according to the secondary collider's sorting behaviour.
system_nl = jdem.System.create(
state_p.shape,
collider_type="NeighborList",
collider_kw={
"state": state_p,
"cutoff": 2.0,
"skin": 0.1,
"secondary_collider_type": "CellList",
"secondary_collider_kw": {"state": state_p},
"max_neighbors": 8,
},
)
print("Neighbor list collider:", type(system_nl.collider).__name__)
print("Cutoff:", float(getattr(system_nl.collider, "cutoff", jnp.nan)))
print("Skin:", float(getattr(system_nl.collider, "skin", jnp.nan)))
print("Max neighbors:", getattr(system_nl.collider, "max_neighbors", "n/a"))
print("Number of builds:", getattr(system_nl.collider, "n_build_times", "n/a"))
print("Last build overflow:", bool(getattr(system_nl.collider, "overflow", False)))
Neighbor list collider: NeighborList
Cutoff: 2.0
Skin: 0.2
Max neighbors: 3
Number of builds: 0
Last build overflow: False
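The skin / 2 rebuild rule and the delegation to a secondary collider can be sketched as a small Verlet-style cache in pure Python. Everything here is hypothetical: brute_force_build stands in for the secondary collider, cutoff is treated as a center-to-center distance, and building with radius cutoff + skin is the usual Verlet-list convention, assumed rather than taken from JaxDEM.

```python
import math


def needs_rebuild(pos, pos_at_build, skin):
    """True once any particle has moved more than skin / 2 since the last build.

    Two particles approaching each other by skin / 2 each can close the whole
    buffer, so skin / 2 is the largest displacement that is guaranteed safe.
    """
    max_disp = max(math.dist(p, q) for p, q in zip(pos, pos_at_build))
    return max_disp > 0.5 * skin


def brute_force_build(pos, radius):
    """Stand-in for the secondary collider: all pairs closer than radius."""
    n = len(pos)
    return {(i, j) for i in range(n) for j in range(i + 1, n)
            if math.dist(pos[i], pos[j]) < radius}


class CachedNeighbors:
    """Toy neighbor-list cache; build is called with radius cutoff + skin."""

    def __init__(self, build, cutoff, skin):
        self.build, self.cutoff, self.skin = build, cutoff, skin
        self.pos_at_build = None
        self.pairs = None
        self.n_builds = 0

    def get(self, pos):
        if self.pos_at_build is None or needs_rebuild(pos, self.pos_at_build, self.skin):
            self.pairs = self.build(pos, self.cutoff + self.skin)
            self.pos_at_build = list(pos)
            self.n_builds += 1
        return self.pairs


cache = CachedNeighbors(brute_force_build, cutoff=2.0, skin=0.1)
pos = [(1.0, 1.0), (3.0, 3.0), (5.0, 5.0)]
cache.get(pos)                               # first call builds
cache.get([(x + 0.03, y) for x, y in pos])  # 0.03 <= 0.05: reuse
cache.get([(x + 0.06, y) for x, y in pos])  # 0.06 > 0.05: rebuild
print(cache.n_builds)  # 2
```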
Computing Potential Energy
The collider exposes compute_potential_energy(), which sums the pairwise interaction energies defined by the force model. In the two-particle example below the result is a single scalar.
state_pe = jdem.State.create(
pos=jnp.array([[0.0, 0.0], [1.5, 0.0]]),
rad=jnp.array([1.0, 1.0]),
)
system_pe = jdem.System.create(state_pe.shape, force_model_type="spring")
pe = system_pe.collider.compute_potential_energy(state_pe, system_pe)
print("Potential energy:", pe)
Potential energy: 1250.0
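Assuming the "spring" model is the usual linear spring, the energy of one overlapping pair is \(\tfrac{1}{2} k \delta^2\) with overlap \(\delta = r_i + r_j - |\mathbf{x}_i - \mathbf{x}_j|\). With \(\delta = 0.5\) and the stiffness \(k = 10^4\) inferred from the forces in the naive-collider example (both assumptions, not documented values), this gives \(\tfrac{1}{2} \cdot 10^4 \cdot 0.25 = 1250\), matching the printed result. A pure-Python sketch with a hypothetical function name:

```python
import math


def spring_potential_energy(pos, rad, k):
    """Total linear-spring energy: sum over overlapping pairs of k * overlap**2 / 2."""
    total = 0.0
    for i in range(len(pos)):
        for j in range(i + 1, len(pos)):  # each pair counted once
            overlap = rad[i] + rad[j] - math.dist(pos[i], pos[j])
            if overlap > 0.0:
                total += 0.5 * k * overlap**2
    return total


# Same two discs as above with an assumed stiffness k = 1e4.
print(spring_potential_energy([(0.0, 0.0), (1.5, 0.0)], [1.0, 1.0], k=1.0e4))  # 1250.0
```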
How the Collider Fits in the Step Pipeline
During each integration step, the stages run in this order:
1. Domain: applies boundary conditions.
2. Integrator (before force): advances positions a half-step.
3. Collider: evaluates pairwise forces and writes state.force/state.torque.
4. Force manager: adds gravity, external forces, and custom force functions, and performs rigid-body aggregation.
5. Integrator (after force): advances velocities.
The collider resets the forces and writes only the pairwise contact contributions; the force manager then adds everything else on top.
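The stage ordering, and the fact that the collider resets forces before the force manager adds to them, can be mimicked with a toy pipeline of plain Python functions. ToyState and the stage functions are hypothetical stand-ins, the force values are made up, and no physics is integrated.

```python
from dataclasses import dataclass, field


@dataclass
class ToyState:
    force: float = 123.0  # stale value left over from a previous step
    log: list = field(default_factory=list)


def domain(s):
    s.log.append("domain")  # would apply boundary conditions
    return s


def integrator_before(s):
    s.log.append("integrator_before")  # would advance positions
    return s


def collider(s):
    s.log.append("collider")
    s.force = 0.0  # reset: discard forces from the previous step
    s.force += -5.0  # pairwise contact contribution (made-up value)
    return s


def force_manager(s):
    s.log.append("force_manager")
    s.force += -10.0  # body force such as gravity (made-up value)
    return s


def integrator_after(s):
    s.log.append("integrator_after")  # would advance velocities
    return s


PIPELINE = (domain, integrator_before, collider, force_manager, integrator_after)


def toy_step(s):
    for stage in PIPELINE:
        s = stage(s)
    return s


s = toy_step(ToyState())
print(s.log)
print(s.force)  # -15.0
```

If the force manager ran before the collider, its contributions would be wiped by the reset, which is why the collider comes first.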