Colliders

A Collider is the component that detects interacting particle pairs and evaluates the ForceModel for each pair. Different colliders implement different spatial-search strategies, trading generality for speed.

This guide covers:

  • The available collider implementations and when to use each one.

  • How to configure a collider via collider_type / collider_kw.

  • How the collider interacts with force models and the force manager.

  • Computing potential energy through the collider.

  • Neighbor-list creation for diagnostics and caching.

Selecting a Collider

The collider is chosen via collider_type when creating a System. The default is "naive".

import jax.numpy as jnp

import jaxdem as jdem

state = jdem.State.create(
    pos=jnp.array([[0.0, 0.0], [1.5, 0.0], [3.0, 0.0]]),
    rad=jnp.array([1.0, 1.0, 1.0]),
)
system = jdem.System.create(state.shape, collider_type="naive")
print("Collider:", type(system.collider).__name__)
Collider: NaiveSimulator

Available Colliders

JaxDEM provides several collider implementations registered in the Collider factory:

| collider_type | Class | Complexity | Best for |
| --- | --- | --- | --- |
| "naive" | NaiveSimulator | \(O(N^2)\) | Small systems (< 1k–4k particles) |
| "CellList" | DynamicCellList | \(O(N \log N)\) | Low to moderate polydispersity systems and clumps |
| "MultiCellList" | DynamicMultiCellList | \(O(N \cdot max\_hashes \log (N \cdot max\_hashes))\) | Highly polydisperse systems (wide size distributions) |
| "SweepAndPrune" | SweepAndPrune | \(O(2^{dim-1} \cdot N \log N)\) | Low to moderate polydispersity systems; cost grows linearly with density or the number of overlaps. Fastest neighbor-list creator. |
| "NeighborList" | NeighborList | \(O(N)\) amortised | Large systems with infrequent neighbor-list rebuilds |

The registered colliders are:

print("Colliders:", list(jdem.Collider._registry.keys()))
Colliders: ['', 'celllist', 'multicelllist', 'naive', 'neighborlist', 'sweepandprune']
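
Conceptually, collider_type is a key into a string-keyed registry, and collider_kw is forwarded to the selected class's constructor. The sketch below is a minimal, hypothetical version of that factory pattern: ToyCollider, register, and create are illustrative names, not JaxDEM's API, and lowercasing the key is an assumption based on the lowercase names in the printed registry.

```python
class ToyCollider:
    """Hypothetical base class holding a string-keyed factory registry."""

    _registry: dict = {}

    @classmethod
    def register(cls, name):
        """Class decorator that files a subclass under a lowercased name."""
        def decorator(subclass):
            cls._registry[name.lower()] = subclass
            return subclass
        return decorator

    @classmethod
    def create(cls, name, **kwargs):
        """Look up `name` case-insensitively and forward kwargs to the constructor."""
        return cls._registry[name.lower()](**kwargs)


@ToyCollider.register("naive")
class ToyNaive(ToyCollider):
    pass


@ToyCollider.register("CellList")
class ToyCellList(ToyCollider):
    def __init__(self, cell_size=1.0):
        self.cell_size = cell_size


# Analogous to collider_type="CellList" with collider_kw={"cell_size": 2.0}
collider = ToyCollider.create("CellList", cell_size=2.0)
print(type(collider).__name__, collider.cell_size, sorted(ToyCollider._registry))
```

Keeping the lookup case-insensitive lets "CellList", "celllist", and "CELLLIST" resolve to the same class.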

The Naive Collider

The NaiveSimulator evaluates the force model for every pair \((i, j)\), giving \(O(N^2)\) complexity. It requires no configuration and is the default. Because it has no spatial-search overhead, it is usually the fastest option for small systems, but its cost becomes prohibitive as \(N\) grows.

system_naive = jdem.System.create(state.shape, collider_type="naive")
state_out, system_out = system_naive.step(state, system_naive)
print("Forces after one step:\n", state_out.force)
Forces after one step:
 [[-5000.     0.]
 [    0.     0.]
 [ 5000.     0.]]
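
To make the \(O(N^2)\) cost concrete, here is a minimal all-pairs sketch in jax.numpy. It assumes a linear spring contact, \(F = k\delta\) along the line of centers, with a hypothetical stiffness \(k = 10^4\) chosen because it reproduces the forces printed above; naive_spring_forces is an illustrative name, not JaxDEM's implementation. The \((N, N)\) intermediate arrays are exactly what make this approach quadratic in time and memory.

```python
import jax.numpy as jnp


def naive_spring_forces(pos, rad, k=1.0e4):
    """All-pairs O(N^2) linear-spring contact forces (hypothetical k)."""
    rij = pos[:, None, :] - pos[None, :, :]          # (N, N, dim): x_i - x_j
    dist = jnp.linalg.norm(rij, axis=-1)             # (N, N) center distances
    overlap = rad[:, None] + rad[None, :] - dist     # > 0 where spheres touch
    n = pos.shape[0]
    mask = (overlap > 0) & ~jnp.eye(n, dtype=bool)   # drop self-pairs
    safe_dist = jnp.where(mask, dist, 1.0)           # avoid 0/0 on the diagonal
    normal = rij / safe_dist[..., None]              # unit vector pointing from j to i
    fij = jnp.where(mask[..., None], k * overlap[..., None] * normal, 0.0)
    return fij.sum(axis=1)                           # net contact force on each i


pos = jnp.array([[0.0, 0.0], [1.5, 0.0], [3.0, 0.0]])
rad = jnp.array([1.0, 1.0, 1.0])
forces = naive_spring_forces(pos, rad)
print(forces)
```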

The Cell List Collider

DynamicCellList (registered as "CellList") partitions space into a regular grid. Only particles in the same or neighboring cells interact. It uses an implicit infinite grid, so it works for all domain types (periodic, free, etc.).

It probes each cell with a jax.lax.while_loop, making it robust to high or variable cell occupancy—ideal for polydisperse systems and clumps.

Key parameters (all have automatic defaults):

  • cell_size — edge length of each grid cell.

  • box_size — domain size (optional; only needed when the box is small relative to the cell size, so that the periodic-wrap stencil has the correct dimensions).

Important: cell-list colliders sort/reorder the state internally for traversal performance. The returned state follows that sorted ordering.

state_p = jdem.State.create(
    pos=jnp.array([[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]]),
    rad=jnp.array([0.5, 0.5, 0.5]),
)
system_cl = jdem.System.create(
    state_p.shape,
    collider_type="CellList",
    collider_kw={"state": state_p},
)
print("Cell size:", getattr(system_cl.collider, "cell_size", "n/a"))
Cell size: 1.0
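
The idea behind a cell list can be sketched in plain Python. This is a readability-first sketch, not JaxDEM's jit-compatible implementation: instead of sorting the state and probing cells with jax.lax.while_loop, it stores occupied cells in a hash map keyed by integer coordinates (which also gives an implicit infinite grid). cell_list_candidates is a hypothetical name, and periodic wrapping is not handled. The stencil is exact only if cell_size is at least the largest contact distance (for example, the largest diameter), so that any touching pair lies in the same or an adjacent cell.

```python
from collections import defaultdict
from itertools import product

import numpy as np


def cell_list_candidates(pos, cell_size):
    """Candidate pairs (i, j), i < j, whose cells are equal or adjacent."""
    pos = np.asarray(pos, dtype=float)
    dim = pos.shape[1]
    cells = np.floor(pos / cell_size).astype(int)
    # Implicit infinite grid: only occupied cells are stored.
    grid = defaultdict(list)
    for i, c in enumerate(map(tuple, cells)):
        grid[c].append(i)
    pairs = set()
    for i, c in enumerate(map(tuple, cells)):
        # Probe the 3**dim stencil: this cell and all its neighbors.
        for offset in product((-1, 0, 1), repeat=dim):
            neighbor = tuple(ci + oi for ci, oi in zip(c, offset))
            for j in grid.get(neighbor, ()):
                if i < j:
                    pairs.add((i, j))
    return sorted(pairs)


print(cell_list_candidates([[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]], cell_size=1.0))
print(cell_list_candidates([[0.2, 0.2], [0.9, 0.4], [3.0, 3.0]], cell_size=1.0))
```

The result is only a candidate set; the force model still decides whether each candidate pair actually overlaps.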

The Multi-Cell List Collider

DynamicMultiCellList (registered as "MultiCellList") partitions space into a regular grid of size cell_size. Unlike standard cell lists, where the cell size must be at least the largest particle diameter, the multi-cell list lets a particle register in multiple cells (up to max_hashes of them).

This formulation is exceptionally well-suited for systems with extreme polydispersity, as it prevents a few large particles from forcing a large cell size for all the small particles.

Key parameters (all have automatic defaults):

  • cell_size — edge length of each grid cell. If None, it defaults to the median particle diameter.

  • max_hashes — maximum number of cells a single particle is allowed to overlap.

Like the standard cell list, it sorts/reorders the state internally for performance.

system_mcl = jdem.System.create(
    state_p.shape,
    collider_type="MultiCellList",
    collider_kw={"state": state_p},
)
print("Multi-Cell List cell size:", getattr(system_mcl.collider, "cell_size", "n/a"))
print("Multi-Cell List max hashes:", getattr(system_mcl.collider, "max_hashes", "n/a"))
Multi-Cell List cell size: 1.0
Multi-Cell List max hashes: 4
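
The sketch below, assuming each particle is registered in every cell its axis-aligned bounding box touches, shows why this helps with polydispersity. overlapped_cells is a hypothetical helper, and max_hashes = 4 is simply set to mirror the printed default; how JaxDEM handles a particle that would need more than max_hashes cells is not covered in this guide. With a standard cell list, the single large particle (diameter 4.0) would force cell_size ≥ 4.0 for everyone; here the median diameter (1.0) is used instead.

```python
from itertools import product

import numpy as np


def overlapped_cells(center, radius, cell_size):
    """Integer coordinates of every grid cell touched by a particle's bounding box."""
    center = np.asarray(center, dtype=float)
    lo = np.floor((center - radius) / cell_size).astype(int)
    hi = np.floor((center + radius) / cell_size).astype(int)
    # Cartesian product of the per-axis cell ranges covered by [center - r, center + r].
    return list(product(*(range(a, b + 1) for a, b in zip(lo, hi))))


rad = np.array([0.5, 0.5, 2.0])            # two small particles, one large one
cell_size = float(np.median(2.0 * rad))    # median diameter, as described above
max_hashes = 4                             # hypothetical cap mirroring the printed default

small = overlapped_cells([1.0, 1.0], 0.5, cell_size)
large = overlapped_cells([5.0, 5.0], 2.0, cell_size)
print(cell_size, len(small), len(large), len(large) > max_hashes)
```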

The Sweep and Prune Collider

SweepAndPrune (registered as "SweepAndPrune") projects particles onto one axis and sorts them along it. It then performs a 1D sweep to find overlapping candidates, dynamically shifting perpendicular coordinates to support periodic boundary conditions.

This is best suited for low-density systems, sheared flows, or systems with few overlaps, where sorting along an axis is highly efficient.

Key parameters:

  • K — static search window radius (number of sorted neighbors to check in each direction).

system_sap = jdem.System.create(
    state_p.shape,
    collider_type="SweepAndPrune",
    collider_kw={"state": state_p},
)
print("Sweep and Prune K window size:", getattr(system_sap.collider, "K", "n/a"))
Sweep and Prune K window size: 6
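
The sweep can be sketched in plain NumPy under stated simplifications: sweep_and_prune is a hypothetical helper, boundaries are non-periodic (no perpendicular shifting), and each pair is found once by scanning forward from the particle whose interval starts first, at most K sorted neighbors ahead. The truncated flag is an addition of this sketch, analogous to the neighbor-list overflow flag: it reports when the window ended before the sweep could rule out further overlaps, which is how a static K can miss contacts in dense packings.

```python
import numpy as np


def sweep_and_prune(pos, rad, K, axis=0):
    """1D sweep-and-prune contact search with a fixed window of K sorted neighbors."""
    pos = np.asarray(pos, dtype=float)
    rad = np.asarray(rad, dtype=float)
    lo = pos[:, axis] - rad                 # interval start on the sweep axis
    hi = pos[:, axis] + rad                 # interval end on the sweep axis
    order = np.argsort(lo)                  # sort intervals by their start
    n = len(order)
    pairs, truncated = [], False
    for a in range(n):
        i = order[a]
        for b in range(a + 1, min(a + 1 + K, n)):
            j = order[b]
            if lo[j] > hi[i]:
                break                       # later intervals start even further right
            # The 1D intervals overlap; confirm with the full distance test.
            if np.linalg.norm(pos[i] - pos[j]) < rad[i] + rad[j]:
                pairs.append((int(min(i, j)), int(max(i, j))))
        else:
            # Window used up without a break: a candidate beyond K may be missed.
            nxt = a + 1 + K
            if nxt < n and lo[order[nxt]] <= hi[i]:
                truncated = True
    return sorted(pairs), truncated


chain = [[0.0, 0.0], [1.5, 0.0], [3.0, 0.0], [10.0, 0.0]]
print(sweep_and_prune(chain, [1.0] * 4, K=2))
```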

Neighbor-List Creation for All Colliders

Every collider implements create_neighbor_list(). This is useful both for diagnostics and for algorithms that need explicit neighbors.

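The call returns four values. The first two (discarded in the example below) are the state and system, which a sorting collider may have reordered; the last two are: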

  • neighbor_list with shape (N, max_neighbors) padded with -1.

  • overflow flag, which is True if any particle had more than max_neighbors neighbors within the requested cutoff.

Note

Verifying Neighbor List Capacity with the Overflow Flag

Because max_neighbors is a static, user-provided buffer size (JAX needs array shapes at compile time), the returned overflow flag is how you verify that the buffer was large enough. If overflow is True, at least one particle has more neighbors than max_neighbors, so some interactions may be truncated. In that case, increase max_neighbors to restore physical correctness.

Example with a regular collider (here: Cell List):

_, _, nl_cl, overflow_cl = system_cl.collider.create_neighbor_list(
    state_p, system_cl, cutoff=2.0, max_neighbors=8
)
print("Cell-list neighbor list shape:", nl_cl.shape)
print("Cell-list overflow:", bool(overflow_cl))
Cell-list neighbor list shape: (3, 8)
Cell-list overflow: False
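
The output layout and the overflow check can be illustrated with a brute-force sketch. padded_neighbor_list is a hypothetical helper, and it assumes "within the cutoff" means a center-to-center distance below cutoff. The retry loop doubles max_neighbors until nothing overflows; in a jitted JAX setting each new buffer size means a recompile, so doubling keeps the number of recompiles logarithmic.

```python
import numpy as np


def padded_neighbor_list(pos, cutoff, max_neighbors):
    """Brute-force neighbor list in the (N, max_neighbors), -1-padded layout."""
    pos = np.asarray(pos, dtype=float)
    n = len(pos)
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    within = (dist < cutoff) & ~np.eye(n, dtype=bool)   # exclude self
    nl = np.full((n, max_neighbors), -1, dtype=int)     # -1 marks empty slots
    overflow = False
    for i in range(n):
        nbrs = np.flatnonzero(within[i])
        overflow = overflow or len(nbrs) > max_neighbors
        kept = nbrs[:max_neighbors]                     # extra neighbors are dropped
        nl[i, : len(kept)] = kept
    return nl, overflow


pos = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
max_neighbors = 1
nl, overflow = padded_neighbor_list(pos, cutoff=1.5, max_neighbors=max_neighbors)
while overflow:  # grow the static buffer until every neighbor fits
    max_neighbors *= 2
    nl, overflow = padded_neighbor_list(pos, cutoff=1.5, max_neighbors=max_neighbors)
print(max_neighbors, nl.tolist())
```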

The Neighbor List Collider

NeighborList caches a per-particle list of neighbors built with a secondary collider (by default, the cell list). The list is rebuilt only when some particle has moved more than skin / 2. Between rebuilds, the cost is \(O(N)\).

Key parameters:

  • cutoff — physical interaction radius.

  • skin — buffer distance. Must be > 0 so the list can be reused across steps; with skin = 0, any motion triggers a rebuild.

  • max_neighbors — buffer size per particle (auto-estimated if omitted).

  • secondary_collider_type — any registered collider except another "NeighborList".

This design works because every collider exposes create_neighbor_list. A NeighborList wrapping another NeighborList is not meaningful and should be avoided.

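When a rebuild occurs, ordering may change according to the secondary collider’s sorting behaviour. Note that in the example below the printed skin (0.2) and max_neighbors (3) differ from the values passed (0.1 and 8). This is consistent with skin being scaled by the cutoff (0.1 × 2.0 = 0.2) and max_neighbors being capped at the particle count, but consult the NeighborList API reference for the exact convention.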

system_nl = jdem.System.create(
    state_p.shape,
    collider_type="NeighborList",
    collider_kw={
        "state": state_p,
        "cutoff": 2.0,
        "skin": 0.1,
        "secondary_collider_type": "CellList",
        "secondary_collider_kw": {"state": state_p},
        "max_neighbors": 8,
    },
)
print("Neighbor list collider:", type(system_nl.collider).__name__)
print("Cutoff:", float(getattr(system_nl.collider, "cutoff", jnp.nan)))
print("Skin:", float(getattr(system_nl.collider, "skin", jnp.nan)))
print("Max neighbors:", getattr(system_nl.collider, "max_neighbors", "n/a"))
print("Number of builds:", getattr(system_nl.collider, "n_build_times", "n/a"))
print("Last build overflow:", bool(getattr(system_nl.collider, "overflow", False)))
Neighbor list collider: NeighborList
Cutoff: 2.0
Skin: 0.2
Max neighbors: 3
Number of builds: 0
Last build overflow: False
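
The rebuild rule can be sketched as a displacement check, assuming skin is a distance and boundaries are non-periodic; needs_rebuild is a hypothetical helper. The skin / 2 threshold works because two particles that each move at most skin / 2 toward each other close their gap by at most skin, so any pair farther apart than cutoff + skin at build time cannot come within cutoff before the check fires.

```python
import numpy as np


def needs_rebuild(pos, pos_at_last_build, skin):
    """True once any particle has moved more than skin / 2 since the last build."""
    disp = np.linalg.norm(np.asarray(pos) - np.asarray(pos_at_last_build), axis=-1)
    return bool(disp.max() > 0.5 * skin)


pos_at_build = np.array([[0.0, 0.0], [3.0, 0.0]])
print(needs_rebuild(pos_at_build + [[0.05, 0.0], [0.0, 0.0]], pos_at_build, skin=0.2))
print(needs_rebuild(pos_at_build + [[0.15, 0.0], [0.0, 0.0]], pos_at_build, skin=0.2))
```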

Computing Potential Energy

The collider exposes compute_potential_energy(), which evaluates the pairwise interaction energies defined by the force model and sums them per particle:

state_pe = jdem.State.create(
    pos=jnp.array([[0.0, 0.0], [1.5, 0.0]]),
    rad=jnp.array([1.0, 1.0]),
)
system_pe = jdem.System.create(state_pe.shape, force_model_type="spring")

pe = system_pe.collider.compute_potential_energy(state_pe, system_pe)
print("Per particle PE energy:", pe)
Per particle PE energy: 1250.0
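
Assuming the spring model is linear, with force \(F = k\delta\) and energy \(\tfrac{1}{2}k\delta^2\) for an overlap \(\delta\), the printed numbers can be checked by hand. The stiffness \(k = 10^4\) below is inferred from the naive example's output, not stated in the source; the printed 1250.0 then equals the energy of the single overlapping pair.

```latex
\delta = r_i + r_j - \lVert \mathbf{x}_i - \mathbf{x}_j \rVert = 1.0 + 1.0 - 1.5 = 0.5 \\
F = k\,\delta = 5000 \;\Rightarrow\; k = 10^4 \\
E = \tfrac{1}{2} k \delta^2 = \tfrac{1}{2} \cdot 10^4 \cdot 0.5^2 = 1250
```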

How the Collider Fits in the Step Pipeline

During each integration step, the pipeline is:

  1. Domain — applies boundary conditions.

  2. Integrator (before force) — advances positions a half-step.

  3. Collider — evaluates pairwise forces and writes state.force / state.torque.

  4. Force manager — adds gravity, external forces, custom force functions, and performs rigid-body aggregation.

  5. Integrator (after force) — advances velocities.

The collider resets the forces and writes only the pairwise contact contributions; the force manager then adds everything else on top.
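
The ordering of the five stages can be sketched with a toy NumPy step. This is an illustration under assumptions, not JaxDEM's integrator: it uses a velocity-Verlet-style split (half-kick, drift, force evaluation, half-kick), a hypothetical linear-spring collider with k = 10⁴, and gravity as the only force-manager term. JaxDEM's integrators may split the update differently; the point is that the collider starts from zero force and the force manager adds on top.

```python
import numpy as np


def pairwise_spring_collider(pos, rad, k=1.0e4):
    """Stage 3: start from zero force and add only pairwise contact terms."""
    force = np.zeros_like(pos)                 # forces are reset here
    for i in range(len(pos)):
        for j in range(i + 1, len(pos)):
            d = pos[i] - pos[j]
            dist = np.linalg.norm(d)
            overlap = rad[i] + rad[j] - dist
            if overlap > 0.0 and dist > 0.0:
                f = k * overlap * d / dist     # pushes i away from j
                force[i] += f
                force[j] -= f                  # Newton's third law
    return force


def step(pos, vel, force, rad, mass, gravity, dt):
    """One velocity-Verlet-style step following the five stages above."""
    m = mass[:, None]
    # 1. Domain: free boundaries in this sketch, so nothing to apply.
    # 2. Integrator (before force): half-kick with the old force, then drift.
    vel = vel + 0.5 * dt * force / m
    pos = pos + dt * vel
    # 3. Collider: fresh pairwise contact forces.
    force = pairwise_spring_collider(pos, rad)
    # 4. Force manager: add body forces (here only gravity) on top.
    force = force + m * gravity
    # 5. Integrator (after force): second half-kick with the new force.
    vel = vel + 0.5 * dt * force / m
    return pos, vel, force


pos = np.array([[0.0, 0.0], [1.5, 0.0]])
vel = np.zeros_like(pos)
force = np.zeros_like(pos)
rad = np.array([1.0, 1.0])
mass = np.array([1.0, 1.0])
gravity = np.array([0.0, -9.81])
pos, vel, force = step(pos, vel, force, rad, mass, gravity, dt=1e-4)
print(force)
```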
