jaxdem.colliders

Collision-detection interfaces and implementations.

Functions

valid_interaction_mask(clump_i, clump_j, ...)

Pair mask shared by all colliders.

Classes

Collider(*, overflow)

The base interface for defining how contact detection and force computations are performed in a simulation.

class jaxdem.colliders.Collider(*, overflow: Array = <factory>)

Bases: Factory, ABC

The base interface for defining how contact detection and force computations are performed in a simulation.

Concrete subclasses of Collider implement the specific algorithms for calculating the interactions.

Notes:

Self-interaction (i.e., calling the force/energy computation for i=j) is allowed, and the underlying force_model is responsible for correctly handling or ignoring this case.

Example:

To define a custom collider, inherit from Collider, register it and implement its abstract methods:

>>> @Collider.register("CustomCollider")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class CustomCollider(Collider):
...     ...

Then, instantiate it:

>>> jaxdem.Collider.create("CustomCollider", **custom_collider_kw)

overflow: Array

Boolean flag indicating if a collider overflow occurred.

static compute_force(state: State, system: System) → tuple[State, System]

Abstract method to compute the total force acting on each particle in the simulation.

Implementations should calculate inter-particle forces and torques based on the current state and system configuration, then update the force and torque attributes of the state object with the resulting total force and torque for each particle.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object (with computed forces) and the System object.

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Compute the total (scalar) non-bonded potential energy of the system.

Implementations sum every pair-interaction contribution defined by system.force_model and return a single scalar. Pair energies are accumulated with the standard 0.5 factor so each pair counts once even when the underlying neighbor list visits (i, j) and (j, i) separately.
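A minimal NumPy sketch of why the 0.5 factor is needed: summing over ordered pairs with the factor gives the same total as summing once over unordered pairs \(i < j\). The harmonic pair_energy is a hypothetical stand-in for system.force_model, and the explicit i != j check stands in for whatever self-interaction handling the force model performs.

```python
import numpy as np


def pair_energy(xi, xj, ri, rj, k=1.0):
    """Hypothetical harmonic contact energy 0.5 * k * overlap**2 (zero when apart)."""
    overlap = max(0.0, ri + rj - float(np.linalg.norm(xi - xj)))
    return 0.5 * k * overlap**2


def energy_ordered(pos, rad):
    """Visit (i, j) and (j, i) separately, as a neighbor list would, and halve each term."""
    n = len(pos)
    total = 0.0
    for i in range(n):
        for j in range(n):
            if i != j:  # this sketch skips self-interaction explicitly
                total += 0.5 * pair_energy(pos[i], pos[j], rad[i], rad[j])
    return total


def energy_unordered(pos, rad):
    """Reference: each unordered pair i < j counted exactly once, no 0.5 factor."""
    n = len(pos)
    return sum(
        pair_energy(pos[i], pos[j], rad[i], rad[j])
        for i in range(n)
        for j in range(i + 1, n)
    )


pos = np.array([[0.0, 0.0], [1.5, 0.0], [0.0, 1.8]])
rad = np.array([1.0, 1.0, 1.0])
print(energy_ordered(pos, rad), energy_unordered(pos, rad))
```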

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A scalar JAX array (shape ()) — the total non-bonded potential energy of the system.

Return type:

jax.Array

Example

>>> potential_energy = system.collider.compute_potential_energy(state, system)
>>> print(f"Total potential energy: {float(potential_energy):.4f}")
>>> print(potential_energy.shape)  # ()

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Build a neighbor list for the current collider.

This is primarily used by neighbor-list-based algorithms and diagnostics. Implementations should match the cell-list semantics:

  • Returns a neighbor list of shape (N, max_neighbors) padded with -1.

  • Neighbor indices must refer to the returned (possibly sorted) state.

  • Also returns an overflow boolean flag (True if any particle exceeded max_neighbors neighbors within the cutoff).
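The padded (N, max_neighbors) layout and the overflow flag can be illustrated with a brute-force NumPy sketch. It assumes a free (non-periodic) domain and a strict distance < cutoff test, and it does not sort the state; naive_neighbor_list is a hypothetical helper, not part of jaxdem.

```python
import numpy as np


def naive_neighbor_list(pos, cutoff, max_neighbors):
    """Brute-force (N, max_neighbors) neighbor list padded with -1, plus overflow flag."""
    n = pos.shape[0]
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    within = (dist < cutoff) & ~np.eye(n, dtype=bool)  # exclude i == j
    counts = within.sum(axis=1)

    # A stable argsort of ~within moves neighbor indices to the front of each
    # row while keeping them in increasing order, i.e. it compacts the mask
    # into a fixed-shape array (the kind of layout a JIT-compiled kernel needs).
    k = min(max_neighbors, n)
    order = np.argsort(~within, axis=1, kind="stable")[:, :k]
    is_neighbor = np.take_along_axis(within, order, axis=1)

    neighbor_list = np.full((n, max_neighbors), -1, dtype=np.int64)
    neighbor_list[:, :k] = np.where(is_neighbor, order, -1)
    overflow = bool((counts > max_neighbors).any())
    return neighbor_list, overflow


pos = np.array([[0.0, 0.0], [0.5, 0.0], [1.0, 0.0], [5.0, 0.0]])
nl, overflow = naive_neighbor_list(pos, cutoff=0.6, max_neighbors=2)
print(nl)
print(overflow)
```

Rows with more than max_neighbors candidates are truncated; the overflow flag is what tells the caller to enlarge the buffer.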

static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]

Build a cross-neighbor list between two sets of positions.

For each point in pos_a, finds all neighbors from pos_b within the given cutoff distance. This is useful for coupling different particle systems or computing interactions between distinct sets of objects.

The default implementation uses a naive \(O(N_A \times N_B)\) all-pairs search. Subclasses may override this with more efficient algorithms.

Parameters:
  • pos_a (jax.Array) – Query positions, shape (N_A, dim).

  • pos_b (jax.Array) – Database positions, shape (N_B, dim).

  • system (System) – The configuration of the simulation (used for domain displacement).

  • cutoff (float) – Search radius.

  • max_neighbors (int) – Maximum number of neighbors to store per query point.

Returns:

A tuple containing:

  • neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.

  • overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.

Return type:

Tuple[jax.Array, jax.Array]
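A minimal NumPy sketch of the naive \(O(N_A \times N_B)\) cross search described above. How System supplies the domain displacement is not shown here, so the sketch assumes an optional rectangular periodic box and applies the minimum-image convention; naive_cross_neighbor_list and its box argument are hypothetical.

```python
import numpy as np


def naive_cross_neighbor_list(pos_a, pos_b, cutoff, max_neighbors, box=None):
    """For each query in pos_a, list indices into pos_b within cutoff (padded with -1)."""
    disp = pos_a[:, None, :] - pos_b[None, :, :]  # shape (N_A, N_B, dim)
    if box is not None:
        box = np.asarray(box, dtype=float)
        disp = disp - box * np.round(disp / box)  # minimum-image convention
    within = np.linalg.norm(disp, axis=-1) < cutoff  # shape (N_A, N_B)

    neighbor_list = np.full((pos_a.shape[0], max_neighbors), -1, dtype=np.int64)
    for a, row in enumerate(within):
        idx = np.flatnonzero(row)[:max_neighbors]  # truncate; overflow reported below
        neighbor_list[a, : idx.size] = idx
    overflow = bool((within.sum(axis=1) > max_neighbors).any())
    return neighbor_list, overflow


pos_a = np.array([[0.2, 5.0]])
pos_b = np.array([[9.9, 5.0], [5.0, 5.0], [0.5, 5.0]])
print(naive_cross_neighbor_list(pos_a, pos_b, 0.5, 2, box=[10.0, 10.0]))
print(naive_cross_neighbor_list(pos_a, pos_b, 0.5, 2))
```

With the periodic box, the point at x = 9.9 is only 0.3 away from the query through the boundary, so it appears in the list; without the box it does not.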

class jaxdem.colliders.DynamicCellList(neighbor_mask: Array, cell_size: Array, *, overflow: Array = <factory>)

Bases: Collider

Implicit cell-list (spatial hashing) collider using dynamic while-loops.

This collider accelerates short-range pair interactions by partitioning the domain into a regular grid of cubic/square cells of side length cell_size. Each particle is assigned to a cell, particles are sorted by cell hash, and interactions are evaluated only against particles in the same or neighboring cells given by neighbor_mask.

This implementation does not use a fixed max_occupancy array padding. Instead, it uses a dynamic jax.lax.while_loop to iterate over the exact number of particles present in each neighboring cell.

The operation of this collider can be understood as the following nested loop:

for particle in particles: # parallel
    for hash in stencil(particle): # parallel
        while next_neighbor in cell(hash): # sequential
            ...
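The nested loop can be made concrete with a plain NumPy/Python sketch. Everything runs sequentially here, so it shows the data flow rather than the GPU parallelism. It assumes a free (non-periodic) domain, a row-major hash over a padded bounding grid, and binary search (np.searchsorted) to locate each cell's run in the sorted hash array; dynamic_cell_list_contacts is a hypothetical helper, not the jaxdem implementation.

```python
import itertools

import numpy as np


def dynamic_cell_list_contacts(pos, rad, cell_size, search_range):
    """Return sorted contacting pairs (i, j), i < j, found via an implicit cell list."""
    n, dim = pos.shape
    cell = np.floor(pos / cell_size).astype(np.int64)  # integer cell coordinates
    # Pad the bounding grid so every stencil cell has a valid row-major hash.
    lo = cell.min(axis=0) - search_range
    shape = tuple(int(s) for s in cell.max(axis=0) + search_range - lo + 1)

    def cell_hash(c):
        return np.ravel_multi_index(tuple(np.asarray(c - lo).T), shape)

    hashes = cell_hash(cell)
    order = np.argsort(hashes, kind="stable")  # sort particles by cell hash
    sorted_hashes = hashes[order]
    stencil = list(itertools.product(range(-search_range, search_range + 1), repeat=dim))

    contacts = set()
    for i in range(n):  # parallel over particles on a GPU
        for offset in stencil:  # parallel over the stencil on a GPU
            h = cell_hash(cell[i] + np.array(offset))
            k = int(np.searchsorted(sorted_hashes, h, side="left"))
            end = int(np.searchsorted(sorted_hashes, h, side="right"))
            while k < end:  # sequential walk over exactly the particles in this cell
                j = int(order[k])
                if j != i and np.linalg.norm(pos[i] - pos[j]) < rad[i] + rad[j]:
                    contacts.add((min(i, j), max(i, j)))
                k += 1
    return sorted(contacts)


rng = np.random.default_rng(0)
pos = rng.uniform(0.0, 10.0, size=(60, 2))
rad = rng.uniform(0.2, 0.5, size=60)
# Cell size 2 * r_max with a search range of 1 covers every possible contact.
contacts = dynamic_cell_list_contacts(pos, rad, cell_size=2 * rad.max(), search_range=1)
print(len(contacts))
```

The while loop touches only the particles actually present in each cell, which is why the cost follows the average occupancy rather than a padded maximum.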

Because the innermost loop is evaluated sequentially, the computational cost is driven by the average cell occupancy rather than the maximum possible occupancy. This makes the total theoretical cost:

\[O(N \cdot \text{neighbor\_mask\_size} \cdot \langle K \rangle)\]

where \(\langle K \rangle\) is the average cell occupancy. To understand how this scales, let’s analyze the cost components:

  • Stencil size:

    The stencil size depends on the ratio between the cell size (\(L\)) and the radius of the largest particle (\(r_{max}\)).

    \[\text{neighbor\_mask\_size} = \left( 2\left\lceil \frac{2r_{max}}{L} \right\rceil + 1 \right)^{dim}\]
  • Average occupancy:

    The average number of particles that occupy a cell depends on the cell volume and the macroscopic number density (\(\rho\)):

    \[\langle K \rangle = \rho L^{dim}\]

To express this in terms of the local volume fraction \(\phi\) (the ratio of volume actually occupied by particles to the total cell volume) and our normalized cell size \(L^\prime = L/r_{max}\), we use the average particle volume \(\langle V \rangle\):

\[\langle K \rangle = \phi \frac{L^{dim}}{\langle V \rangle} = \phi \frac{(L^\prime r_{max})^{dim}}{\langle V \rangle}\]

Knowing that the volume of the largest particle is \(V_{max} = k_v r_{max}^{dim}\) (where \(k_v\) is the geometric volume factor, such as \(4\pi/3\) in 3D or \(\pi\) in 2D), we find the final theoretical cost:

\[\text{cost} \approx N \left( 2\left\lceil \frac{2}{L^\prime} \right\rceil + 1 \right)^{dim} \left( \frac{\phi}{k_v} \frac{V_{max}}{\langle V \rangle} (L^\prime)^{dim} \right)\]
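The cost estimate can be evaluated directly. The sketch below transcribes the formula above (per particle, in units of pair checks); the function name and the table of \(k_v\) values are the only additions, and like the formula it ignores sorting and stencil bookkeeping.

```python
import math


def dynamic_cell_list_cost_per_particle(l_prime, phi, dim, vmax_over_vmean):
    """Stencil size times average occupancy <K>, following the cost formula."""
    k_v = {2: math.pi, 3: 4.0 * math.pi / 3.0}[dim]  # geometric volume factor
    stencil = (2 * math.ceil(2.0 / l_prime) + 1) ** dim
    occupancy = (phi / k_v) * vmax_over_vmean * l_prime**dim
    return stencil * occupancy


# Monodisperse 3D packing at phi = 0.64 with cell size L = 2 r_max (L' = 2).
print(dynamic_cell_list_cost_per_particle(2.0, 0.64, 3, 1.0))
```

Because \(\lceil 2/L^\prime \rceil\) changes only in integer steps, the estimate is piecewise in \(L^\prime\). It counts pair checks only, so it does not capture sorting or compilation overhead.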

  • The Polydispersity Advantage:

    In the static cell list, cost scales with the ratio of the largest to smallest particle volume (\(V_{max}/V_{min} \propto \alpha^{dim}\), where \(\alpha = r_{max}/r_{min}\)). In this dynamic list, the cost scales with the ratio of the largest to the average particle volume (\(V_{max}/\langle V \rangle\)). Because \(\langle V \rangle \ge V_{min}\), this ratio never exceeds \(V_{max}/V_{min}\), so the severe \(O(\alpha^{dim})\) padding penalty is reduced or, when \(\langle V \rangle\) is close to \(V_{max}\), largely offset.
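To see the size of the reduction, the sketch below computes both ratios for a sample of radii. The 50:50 bidisperse mixture with \(\alpha = 3\) in 3D is an illustrative assumption; the helper name is hypothetical.

```python
import numpy as np


def volume_ratios(radii, dim):
    """Return (V_max / V_min, V_max / <V>); the volume factor k_v cancels."""
    v = np.asarray(radii, dtype=float) ** dim
    return v.max() / v.min(), v.max() / v.mean()


# 50:50 bidisperse mixture with alpha = r_max / r_min = 3 in 3D.
radii = np.array([1.0] * 500 + [3.0] * 500)
static_ratio, dynamic_ratio = volume_ratios(radii, dim=3)
print(static_ratio, dynamic_ratio)
```

Here the static ratio is 27, while the dynamic ratio is 27/14, since \(\langle V \rangle\) is the mean of 1 and 27 in units of \(V_{min}\).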

Constructor Parameters

  • cell_size: Linear size of the grid cells. A larger cell size reduces the neighbor stencil size but increases cell occupancy (longer sequential loops). A smaller cell size reduces occupancy but enlarges the stencil, whose size grows as \(\left( 2\lceil 2r_{max}/L \rceil + 1 \right)^{dim}\), which increases compilation overhead. If None, defaults to \(2 r_{max}\) (for systems with low polydispersity, \(\alpha < 2.5\)), or \(0.5 r_{max}\) (for highly polydisperse systems).

  • search_range: Neighborhood range in cell units. Dictates how many cells are searched along each dimension. If None, it is dynamically computed to guarantee that all potential contacts within \(2 r_{max}\) are visited. Setting this higher expands the search stencil.

  • box_size: Bounding dimensions of the physical domain. This is only needed when the physical box size is small compared with the cell size (to ensure the minimum grid size requirement of 2 * search_range + 1 cells per axis is met under periodic boundary conditions).
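The default rules above can be written out as follows. This is a sketch, assuming \(\alpha\) is computed as \(r_{max}/r_{min}\) from the reference radii and that the search range is the smallest integer covering a \(2 r_{max}\) contact distance (matching the stencil-size formula); default_cell_list_params and periodic_grid_ok are hypothetical helpers, and the exact rounding used by Create may differ.

```python
import math

import numpy as np


def default_cell_list_params(radii):
    """Hypothetical restatement of the documented cell_size / search_range defaults."""
    r_max = float(np.max(radii))
    alpha = r_max / float(np.min(radii))
    cell_size = 2.0 * r_max if alpha < 2.5 else 0.5 * r_max
    search_range = math.ceil(2.0 * r_max / cell_size)  # covers contacts up to 2 r_max
    return cell_size, search_range


def periodic_grid_ok(box_size, cell_size, search_range):
    """True if every axis holds at least 2 * search_range + 1 cells."""
    cells = np.floor(np.asarray(box_size, dtype=float) / cell_size)
    return bool(np.all(cells >= 2 * search_range + 1))


print(default_cell_list_params([1.0, 1.0, 1.2]))  # low polydispersity
print(default_cell_list_params([0.5, 2.0]))  # alpha = 4, highly polydisperse
```

The small cell size chosen for high polydispersity is what pushes the search range, and hence the stencil, up.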

This collider is suitable for large systems with low to moderate polydispersity (\(\alpha < 2.5\)) and medium to high packing fractions. Highly polydisperse systems (\(\alpha \ge 3.0\)) or systems containing rigid clumps with large internal overlaps will reduce performance significantly. This is because overlaps artificially inflate the local cell occupancy \(\langle K \rangle\) far beyond the macroscopic physical volume fraction \(\phi\), leading to longer sequential loops and reduced GPU thread efficiency.

Complexity

  • Time: \(O(N)\) to \(O(N \log N)\) for sorting (the state remains close to sorted from one frame to the next), plus \(O(N \cdot M \cdot \langle K \rangle)\) for neighbor probing (M = neighbor_mask_size, \(\langle K \rangle\) = average occupancy).

  • Memory: \(O(N)\).

Notes

  • Batching with vmap: If you use jax.vmap to evaluate multiple simulation environments simultaneously, be aware of JAX’s SIMD execution model. Because the innermost while loop executes sequentially, the loop must continue running for all environments in the batch until the environment with the highest local cell occupancy finishes its iterations. Consequently, the computational cost of a batched execution is bottlenecked by the single worst-case occupancy across the entire batch.

neighbor_mask: Array

Integer offsets defining the neighbor stencil, shape (M, dim).

cell_size: Array

Linear size of a grid cell (scalar).

classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, box_size: ArrayLike | None = None) → Self

Creates a DynamicCellList instance based on the reference state.

Parameters:
  • state (State) – Reference state containing positions and radii.

  • cell_size (float, optional) – Grid cell size.

  • search_range (int, optional) – Number of neighboring cells to search.

  • box_size (ArrayLike, optional) – Bounding dimensions of physical box. Only needed when the box size is small compared with the cell size.

Returns:

A configured DynamicCellList instance.

Return type:

DynamicCellList

static compute_force(state: State, system: System) → tuple[State, System]

Computes pairwise contact forces and torques using DynamicCellList.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated state and unmodified system.

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the total non-bonded potential energy of the system.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

Scalar potential energy.

Return type:

jax.Array

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Creates a neighbor list of shape (N, max_neighbors) using DynamicCellList.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Verlet search cutoff radius.

  • max_neighbors (int) – Static size of neighbor buffer per particle.

Returns:

Sorted state, system, neighbor list, and overflow flag.

Return type:

Tuple[State, System, jax.Array, jax.Array]

static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]

Creates a cross-neighbor list between pos_a (query) and pos_b (database).

Parameters:
  • pos_a (jax.Array) – Query positions, shape (N_A, dim).

  • pos_b (jax.Array) – Database positions, shape (N_B, dim).

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Verlet search cutoff radius.

  • max_neighbors (int) – Static size of neighbor buffer per particle.

Returns:

Cross-neighbor list of shape (N_A, max_neighbors) and overflow flag.

Return type:

Tuple[jax.Array, jax.Array]

class jaxdem.colliders.DynamicMultiCellList(cell_size: Array, max_hashes: int, *, overflow: Array = <factory>)

Bases: Collider

Implicit multi-cell list (spatial hashing) collider using dynamic while-loop traversal.

This collider partitions the domain into a regular grid of cubic/square cells of side length cell_size. In a standard cell list, the cell size must be at least the largest particle diameter so that only the immediate neighbor stencil needs to be searched; the multi-cell list instead allows particles to span multiple cells. Each particle is registered in every grid cell that overlaps its axis-aligned bounding box (AABB), up to a maximum of max_hashes cells.

This formulation decouples the grid cell size from the largest particle size, making it exceptionally well-suited for systems with extreme polydispersity or large aspect ratios.

When particles span multiple cells, a pair of overlapping particles \((i, j)\) can share several grid cells. To avoid evaluating their interaction multiple times, a canonical cell check is performed. Let \(\mathbf{x}_i\) and \(\mathbf{x}_j\) be the positions of two particles. An interaction is evaluated in cell \(c\) if and only if:

\[c = \text{canonical\_hash}(\mathbf{x}_i, \mathbf{x}_j)\]

where the canonical hash is uniquely determined by the spatial coordinates of the interaction midpoint under the domain’s boundary conditions.
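The deduplication rule can be checked with a small sketch. The exact "interaction midpoint" is not spelled out here, so the sketch assumes it is the midpoint of the segment where the two particles overlap along the line of centres; that point lies inside both particles and therefore inside both AABBs. It also assumes free boundaries; aabb_cells, canonical_cell and evaluation_counts are hypothetical helpers.

```python
import itertools

import numpy as np


def aabb_cells(x, r, cell_size):
    """Set of integer cell coordinates overlapped by the particle's bounding box."""
    lo = np.floor((x - r) / cell_size).astype(int)
    hi = np.floor((x + r) / cell_size).astype(int)
    return set(itertools.product(*(range(a, b + 1) for a, b in zip(lo, hi))))


def canonical_cell(xi, xj, ri, rj, cell_size):
    """Cell containing the assumed interaction midpoint of particles i and j."""
    d_vec = xj - xi
    d = float(np.linalg.norm(d_vec))
    # Along the line of centres (t = 0 at xi, t = d at xj), particle i covers
    # [-ri, ri] and particle j covers [d - rj, d + rj]; take the midpoint of
    # their intersection, which lies inside both particles when they overlap.
    t_mid = 0.5 * (max(-ri, d - rj) + min(ri, d + rj))
    point = xi + d_vec / d * t_mid
    return tuple(int(v) for v in np.floor(point / cell_size))


def evaluation_counts(pos, rad, cell_size):
    """How many times each pair would be evaluated under the canonical-cell rule."""
    cells = [aabb_cells(pos[i], rad[i], cell_size) for i in range(len(pos))]
    counts = {}
    for i, j in itertools.combinations(range(len(pos)), 2):
        for c in cells[i] & cells[j]:  # every cell in which i and j meet
            if c == canonical_cell(pos[i], pos[j], rad[i], rad[j], cell_size):
                counts[(i, j)] = counts.get((i, j), 0) + 1
    return counts


rng = np.random.default_rng(1)
pos = rng.uniform(0.0, 8.0, size=(40, 2))
rad = rng.uniform(0.1, 1.0, size=40)  # strongly polydisperse
counts = evaluation_counts(pos, rad, cell_size=0.5)
print(len(counts))
```

A pair may meet in many cells, but only the one containing its canonical point triggers the evaluation.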

Runtime and Cost Analysis

Let \(L\) be the cell size, \(r\) represent particle radii, and \(dim\) be the spatial dimension. The number of cells overlapped by a particle of radius \(r\) is given by:

\[M(r) \approx \left( \frac{2r}{L} + 1 \right)^{dim}\]

The maximum number of cell hashes a particle can occupy is bounded by the largest particle:

\[M_{max} \approx \left( \frac{2r_{max}}{L} + 1 \right)^{dim}\]

We set the static padding parameter max_hashes \(\ge M_{max}\). The spatial partitioning step hashes and sorts \(N \cdot \text{max\_hashes}\) cell-particle references, introducing a sorting cost of \(O(N \cdot \text{max\_hashes} \log(N \cdot \text{max\_hashes}))\).

The average occupancy of a grid cell (the average number of particle references occupying a cell), denoted as \(\langle K \rangle\), is:

\[\langle K \rangle = \frac{N \langle M \rangle}{N_{cells}} = \rho L^{dim} \langle M \rangle\]

where \(\rho = N / V_{domain}\) is the macroscopic number density and \(\langle M \rangle\) is the average number of cells overlapped by a particle. Expressing \(\rho\) in terms of the packing fraction \(\phi\) and the average particle volume \(\langle V \rangle\) (\(\rho = \phi / \langle V \rangle\)):

\[\langle K \rangle = \phi \frac{L^{dim}}{\langle V \rangle} \langle M \rangle\]

Since each particle \(i\) queries \(M(r_i)\) cells during traversal, the expected number of pairwise checks per particle of radius \(r_i\) is \(M(r_i) \langle K \rangle\). Averaged over all particles, the total contact detection cost scales as:

\[\text{cost} \approx N \langle M \rangle \langle K \rangle = N \phi \frac{L^{dim}}{\langle V \rangle} \langle M \rangle^2\]

  • Optimal Cell Size: There is a clear trade-off in selecting the cell size \(L\):

    • As \(L \to 0\), the number of overlapped cells grows as \(\langle M \rangle \propto L^{-dim}\), increasing the sorting cost and the number of cells each particle queries.

    • As \(L \to \infty\), the cell occupancy \(\langle K \rangle\) increases, leading to longer sequential dynamic loops.

    The optimal cell size is typically chosen to be comparable to the median particle diameter (e.g. \(L \approx 2 r_{median}\)), which balances the two costs.

  • The Polydispersity Advantage: In a standard cell list, a single giant particle forces the cell size \(L \ge 2r_{max}\). In highly polydisperse systems where \(\alpha = r_{max}/r_{min} \gg 1\), this results in cells that are huge relative to the small particles, leading to very high cell occupancies and redundant distance checks. By contrast, DynamicMultiCellList allows \(L\) to remain small (scaled to \(r_{median}\)). Small particles occupy only between \(1\) and \(2^{dim}\) cells, while large particles occupy many cells. This avoids the stencil explosion for small particles, keeping the average occupancy low and significantly outperforming standard cell lists at high polydispersity. Moreover, because the grid is discrete, increasing the polydispersity carries no additional padding cost until \(\lceil 2r_{max}/L \rceil\) (and therefore max_hashes) increases. Two systems with different polydispersity but the same cell size and the same \(\lceil 2r_{max}/L \rceil\) can therefore have similar performance.
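The trade-off can be explored numerically by transcribing the cost model above, \(\text{cost}/N \approx \phi L^{dim} \langle M \rangle^2 / \langle V \rangle\) with \(M(r) \approx (2r/L + 1)^{dim}\). For a monodisperse sample this model has its minimum at \(L = 2r\) in any dimension, consistent with the \(L \approx 2 r_{median}\) guideline. multicell_cost_per_particle is a hypothetical helper and, like the model, ignores the sorting term.

```python
import math

import numpy as np


def multicell_cost_per_particle(radii, cell_size, phi, dim):
    """phi * L**dim * <M>**2 / <V>, with M(r) = (2r/L + 1)**dim (cost model above)."""
    radii = np.asarray(radii, dtype=float)
    k_v = {2: math.pi, 3: 4.0 * math.pi / 3.0}[dim]
    mean_volume = k_v * np.mean(radii**dim)
    mean_m = np.mean((2.0 * radii / cell_size + 1.0) ** dim)
    return phi * cell_size**dim * mean_m**2 / mean_volume


# Sweep the cell size for a monodisperse 3D sample with r = 1.
sizes = np.linspace(0.2, 8.0, 391)
costs = [multicell_cost_per_particle([1.0], L, phi=0.6, dim=3) for L in sizes]
print(sizes[int(np.argmin(costs))])
```

Small cell sizes lose to the \(\langle M \rangle^2\) factor and large ones to the \(L^{dim}\) occupancy factor, which is the trade-off described in the list above.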

Constructor Parameters

  • cell_size: Linear size of the grid cells. Smaller cell sizes allow finer-grained partitioning (lower cell occupancies) but cause large particles to overlap more cells (inflating the sorting array). A larger cell size reduces the cells overlapped by large particles but increases cell occupancies. If None, it defaults to the median particle diameter \(2 r_{median}\).

  • max_hashes: Static padding parameter representing the maximum number of grid cells a single particle is allowed to overlap. Must be set high enough to cover the largest particle’s overlap capacity. If None, it defaults to the geometric cell overlap limit: \(\text{max\_hashes} = (\lceil 2r_{max}/L \rceil + 1)^{dim}\). Setting this higher increases compilation time and memory allocation, while setting it too small results in clipping of the AABB, causing missed interactions.
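The documented defaults can be sketched as below, together with a check that the default max_hashes covers every AABB of the largest particle. The helper names and the use of np.median are assumptions; only the two formulas come from the text above.

```python
import math

import numpy as np


def default_multicell_params(radii, dim):
    """cell_size = 2 * r_median, max_hashes = (ceil(2 r_max / L) + 1) ** dim."""
    radii = np.asarray(radii, dtype=float)
    cell_size = 2.0 * float(np.median(radii))
    max_hashes = (math.ceil(2.0 * radii.max() / cell_size) + 1) ** dim
    return cell_size, max_hashes


def cells_overlapped(x, r, cell_size):
    """Number of grid cells touched by the AABB of a particle at x with radius r."""
    lo = np.floor((x - r) / cell_size)
    hi = np.floor((x + r) / cell_size)
    return int(np.prod(hi - lo + 1))


# Nine small particles and one particle six times larger, in 2D.
radii = np.array([0.5] * 9 + [3.0])
cell_size, max_hashes = default_multicell_params(radii, dim=2)
print(cell_size, max_hashes)
```

Per axis, an AABB of width \(2r\) touches at most \(\lceil 2r/L \rceil + 1\) cells, which is where the formula comes from.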

This collider is suitable for systems with high polydispersity (\(\alpha \ge 3.0\)). It allows the cell size to remain small without suffering from stencil explosions. While clumps with large internal overlaps inflate insertion counts, the canonical midpoint deduplication ensures that each pair interaction is evaluated exactly once. However, the high density of overlapping constituent spheres still increases local cell occupancies \(\langle K \rangle\). It is less suitable for monodisperse systems with few overlaps where standard cell lists are faster.

Complexity

  • Time: \(O(N \cdot M_{max} \log(N \cdot M_{max}))\) sorting overhead, plus \(O(N \cdot \langle M \rangle \cdot \langle K \rangle)\) traversal.

  • Memory: \(O(N \cdot M_{max})\) to store the padded spatial hashing table.

cell_size: Array

Linear size of a grid cell (scalar).

max_hashes: int

Static padding parameter representing the maximum number of grid cells a particle is allowed to overlap.

classmethod Create(state: State, cell_size: ArrayLike | None = None, max_hashes: int | None = None) → Self

Creates a DynamicMultiCellList instance based on the reference state.

Parameters:
  • state (State) – Reference state containing positions and radii.

  • cell_size (float, optional) – Grid cell size.

  • max_hashes (int, optional) – Maximum cells a single particle can overlap.

Returns:

A configured DynamicMultiCellList instance.

Return type:

DynamicMultiCellList

static compute_force(state: State, system: System) → tuple[State, System]

Computes pairwise contact forces and torques using DynamicMultiCellList.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated state and unmodified system.

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the total non-bonded potential energy of the system.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

Scalar potential energy.

Return type:

jax.Array

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Creates a neighbor list of shape (N, max_neighbors) using DynamicMultiCellList.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Verlet search cutoff radius.

  • max_neighbors (int) – Static size of neighbor buffer per particle.

Returns:

Unmodified state, system, neighbor list, and overflow flag.

Return type:

Tuple[State, System, jax.Array, jax.Array]

static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]

Creates a cross-neighbor list between pos_a (query) and pos_b (database).

Parameters:
  • pos_a (jax.Array) – Query positions, shape (N_A, dim).

  • pos_b (jax.Array) – Database positions, shape (N_B, dim).

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Verlet search cutoff radius.

  • max_neighbors (int) – Static size of neighbor buffer per particle.

Returns:

Cross-neighbor list of shape (N_A, max_neighbors) and overflow flag.

Return type:

Tuple[jax.Array, jax.Array]

class jaxdem.colliders.NaiveSimulator(*, overflow: Array = <factory>)

Bases: Collider

Implementation that computes forces and potential energies using a naive \(O(N^2)\) all-pairs loop.

This collider evaluates interactions between all particle pairs directly, without any spatial partitioning or binning.

The total force acting on particle \(i\) is the direct sum of its interactions with all other particles \(j\) in the system:

\[\mathbf{F}_i = \sum_{j=0}^{N-1} \mathbf{F}_{ij}(\mathbf{x}_i, \mathbf{x}_j, r_i, r_j) \cdot M_{ij}\]

where \(\mathbf{F}_{ij}\) is the force vector computed by the physical force model, and \(M_{ij}\) is the interaction eligibility mask determined by:

  • Clump member exclusions (internal clump particles do not exert forces on each other)

  • Bond connectivity exclusions

  • Contact overlap/cutoff checks
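A vectorized NumPy sketch of the masked all-pairs sum \(\mathbf{F}_i = \sum_j \mathbf{F}_{ij} M_{ij}\). The linear spring \(k \cdot \text{overlap}\) is a hypothetical stand-in for system.force_model, and the mask covers only the clump and overlap conditions (bond exclusions are omitted); naive_forces is not the jaxdem API.

```python
import numpy as np


def naive_forces(pos, rad, clump_id, k=1.0):
    """All-pairs repulsive forces with a clump/overlap eligibility mask M_ij."""
    disp = pos[:, None, :] - pos[None, :, :]  # x_i - x_j, shape (N, N, dim)
    dist = np.linalg.norm(disp, axis=-1)
    overlap = rad[:, None] + rad[None, :] - dist
    # M_ij: different clumps (this also removes i == j) and actual overlap.
    mask = (clump_id[:, None] != clump_id[None, :]) & (overlap > 0.0)
    magnitude = np.where(mask, k * overlap, 0.0)  # hypothetical linear spring
    safe_dist = np.where(dist > 0.0, dist, 1.0)  # avoid 0/0 on the diagonal
    return ((magnitude / safe_dist)[..., None] * disp).sum(axis=1)


pos = np.array([[0.0, 0.0], [1.5, 0.0], [0.5, 0.0]])
rad = np.array([1.0, 1.0, 1.0])
clump_id = np.array([0, 1, 0])  # particles 0 and 2 belong to the same clump
forces = naive_forces(pos, rad, clump_id)
print(forces)
```

Particles 0 and 2 overlap but share a clump, so the mask removes their pair; the remaining forces still sum to zero.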

Runtime and Cost Analysis

The total number of pair checks evaluated by this collider is fixed and equal to:

\[\text{cost} \approx N^2 \cdot C_{interaction}\]

where \(C_{interaction}\) represents the computational cost of a single pairwise force/energy query.

Because the algorithm does not partition space into cells or project coordinates onto axes, its execution time is completely independent of:

  • The spatial distribution or packing fraction \(\phi\) of the system

  • The particle polydispersity \(\alpha\)

  • Performance Trade-off:

    • For small systems (\(N \le 10^3\) to \(2 \cdot 10^3\), depending on the GPU): NaiveSimulator is often the fastest collider because it requires zero sorting, hashing, or bookkeeping overhead, allowing perfect GPU thread utilization and minimal JIT compilation times.

    • For large systems (\(N \ge 10^4\)): The quadratic complexity \(O(N^2)\) leads to a severe performance bottleneck, making spatial partitioning colliders significantly faster.

Complexity

  • Time: \(O(N^2)\).

  • Memory: \(O(N)\) (no auxiliary neighbor tables or grid structures are stored).

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the total potential energy of the system using a naive \(O(N^2)\) all-pairs loop.

This method iterates over all particle pairs (i, j) and sums the potential energy contributions computed by the system.force_model.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

Scalar containing the total potential energy of the system.

Return type:

jax.Array

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Computes a neighbor list using a naive \(O(N^2)\) all-pairs search.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – The interaction radius (force cutoff).

  • max_neighbors (int) – Maximum number of neighbors to store per particle.

Returns:

A tuple containing:

  • state: The simulation state.

  • system: The simulation system.

  • neighbor_list: Array of shape (N, max_neighbors) containing neighbor indices.

  • overflow: Boolean flag indicating if any particle exceeded max_neighbors.

Return type:

Tuple[State, System, jax.Array, jax.Array]

static compute_force(state: State, system: System) → tuple[State, System]

Computes the total force acting on each particle using a naive \(O(N^2)\) all-pairs loop.

This method sums the force contributions from all particle pairs (i, j) as computed by the system.force_model and updates the particle forces.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object with computed forces and the unmodified System object.

Return type:

Tuple[State, System]

class jaxdem.colliders.NeighborList(secondary_collider: Collider, neighbor_list: Array, old_pos: Array, n_build_times: Array, cutoff: Array, skin: Array, max_neighbors: int, *, overflow: Array = <factory>)

Bases: Collider

Implementation of a Verlet neighbor list collider.

Verlet neighbor lists cache candidate interaction pairs over multiple simulation timesteps. This bypasses the need to execute full spatial partitioning queries (sorting and slab/cell hashing) at every timestep, dramatically reducing contact detection overhead.

Mathematical Formalism & Rebuild Criteria

The neighbor list is constructed with a search radius containing a buffer distance known as the skin:

\[r_{search} = \text{cutoff} + \text{skin}\]

Let \(\mathbf{x}_i^0\) represent the position of particle \(i\) at the time of the last neighbor list rebuild. At any subsequent timestep, the displacement of particle \(i\) from its reference position is:

\[\Delta \mathbf{x}_i = \mathbf{x}_i - \mathbf{x}_i^0\]

By the triangle inequality, the change in distance between any two particles \(i\) and \(j\) since the last rebuild is bounded by:

\[|d_{ij} - d_{ij}^0| \le \|\Delta \mathbf{x}_i\| + \|\Delta \mathbf{x}_j\| \le 2 \max_{k} \|\Delta \mathbf{x}_k\|\]

To guarantee that no pair of particles can come closer than the interaction range \(\text{cutoff}\) without being captured in the neighbor list, a rebuild is triggered as soon as:

\[\max_{k} \|\Delta \mathbf{x}_k\| > \frac{\text{skin}}{2}\]
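The skin/2 criterion can be checked empirically. The sketch below builds a candidate set with radius cutoff + skin, moves every particle by less than skin / 2, and confirms that every pair now inside cutoff was already cached. It assumes a free domain (no periodic images); pairs_within and needs_rebuild are hypothetical helpers.

```python
import numpy as np


def pairs_within(pos, radius):
    """Set of unordered pairs (i, j), i < j, closer than radius."""
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    i, j = np.nonzero(np.triu(dist < radius, k=1))
    return set(zip(i.tolist(), j.tolist()))


def needs_rebuild(pos, old_pos, skin):
    """Rebuild once any particle has moved farther than skin / 2."""
    return bool(np.linalg.norm(pos - old_pos, axis=-1).max() > 0.5 * skin)


rng = np.random.default_rng(1)
cutoff, skin = 0.6, 0.2
old_pos = rng.uniform(0.0, 5.0, size=(50, 2))
cached = pairs_within(old_pos, cutoff + skin)

# Move every particle in a random direction by strictly less than skin / 2.
direction = rng.normal(size=(50, 2))
direction /= np.linalg.norm(direction, axis=1, keepdims=True)
new_pos = old_pos + direction * rng.uniform(0.0, 0.5 * skin, size=(50, 1))

print(needs_rebuild(new_pos, old_pos, skin), pairs_within(new_pos, cutoff) <= cached)
```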

Runtime and Cost Analysis

The computational cost of simulations using neighbor lists consists of two parts:

  1. Rebuild Cost: Occurs occasionally, when the maximum displacement threshold is exceeded. Any registered collider (e.g., NaiveSimulator, DynamicCellList, DynamicMultiCellList, or SweepAndPruneShifted) can be configured and used to perform spatial queries during this rebuild phase. The complexity of the rebuild step is directly determined by the chosen underlying collider (e.g., \(O(N^2)\) for NaiveSimulator, \(O(N \log N)\) for DynamicCellList/DynamicMultiCellList, or \(O(N)\) for SweepAndPruneShifted).

  2. Step Evaluation Cost: Occurs at every timestep. We iterate directly over the static cached neighbor buffer of size max_neighbors.

    \[\text{cost}_{step} \approx N \cdot \text{max\_neighbors}\]

  • Estimating Buffer Size: The neighbor buffer size max_neighbors is estimated from the search volume and number density:

    \[\text{max\_neighbors} \approx \gamma \cdot \rho \cdot V_{search}\]

    where \(\gamma\) is a safety factor (default 1.2), \(\rho = N / V_{domain} = \phi / \langle V \rangle\) is the macroscopic number density, and \(V_{search}\) is the volume of the search sphere of radius \(r_{search}\):

    \[V_{search} = \begin{cases} \pi r_{search}^2 & \text{in 2D} \\ \frac{4}{3}\pi r_{search}^3 & \text{in 3D} \end{cases}\]

    Typically, a skin of \(0.1 \text{ to } 0.4\) times the particle diameter provides a good balance.
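The buffer-size heuristic in numbers. Rounding up with math.ceil is an assumption of this sketch, as is the helper name; the defaults skin = 0.05, number_density = 1.0 and safety_factor = 1.2 are those listed for Create.

```python
import math


def estimate_max_neighbors(cutoff, skin=0.05, number_density=1.0, safety_factor=1.2, dim=3):
    """gamma * rho * V_search with r_search = cutoff + skin, rounded up (assumed)."""
    r_search = cutoff + skin
    if dim == 2:
        v_search = math.pi * r_search**2
    elif dim == 3:
        v_search = 4.0 / 3.0 * math.pi * r_search**3
    else:
        raise ValueError("dim must be 2 or 3")
    return math.ceil(safety_factor * number_density * v_search)


print(estimate_max_neighbors(1.0, dim=3), estimate_max_neighbors(1.0, dim=2))
print(estimate_max_neighbors(1.0, skin=0.4, dim=3))  # a larger skin inflates the buffer
```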

Constructor Parameters

  • cutoff: The physical contact interaction range. Larger cutoffs enlarge the search volume, which grows as \((\text{cutoff} + \text{skin})^{dim}\), and therefore the neighbor buffer.

  • skin: The buffer distance. Default is 0.05. Larger skin reduces rebuild frequency but inflates max_neighbors, increasing step time and memory.

  • max_neighbors: The static neighbor buffer size per particle. If not provided, it is estimated from number_density and safety_factor using the buffer-size heuristic. Setting this too small causes list overflows, while setting it too large wastes GPU memory.

  • number_density: Macroscopic number density used to estimate neighbor counts when not provided. Default is 1.0.

  • safety_factor: Multiplier applied to the estimated neighbor count to account for local density fluctuations. Default is 1.2.

  • secondary_collider_type: The identifier of the underlying collider used to execute the spatial queries during rebuilds (e.g. "CellList", "sap_shifted", "naive", or "DynamicMultiCellList"). Any registered Collider subclass in the library can be used for the rebuild phase, allowing the rebuild cost to be optimized based on system characteristics.

  • secondary_collider_kw: Keyword args for the underlying collider constructor.

This collider is suitable for dense assemblies, static packings, slow shear flows, gravity settling, or any low-velocity systems. It is less suitable for high-speed granular flows or high-temperature systems where rapid particle motion triggers frequent neighbor list rebuilds, neutralizing the caching advantage. Furthermore, systems of rigid clumps with large overlaps require allocating larger neighbor buffers to accommodate excluded constituent pairs, which increases the memory footprint and the step traversal cost.

Temperature & Rebuild Frequency Discussion

In particle systems (analogous to molecular dynamics), the “temperature” \(T\) is proportional to the mean squared velocity (kinetic energy) of the particles:

\[\langle v^2 \rangle \sim T \implies v_{rms} \propto \sqrt{T}\]

The rebuild criterion is triggered when the maximum particle displacement exceeds half the skin distance:

\[\max_k \|\Delta \mathbf{x}_k\| > \frac{\text{skin}}{2}\]

Approximating the particle displacement over time as \(\|\Delta \mathbf{x}\| \approx v \cdot t\), the average time interval between rebuilds \(\tau\) can be estimated as:

\[\tau \approx \frac{\text{skin}}{2 \cdot v_{rms}} \propto \frac{\text{skin}}{\sqrt{T}}\]

As a result, the rebuild frequency (\(f_{rebuild} = 1/\tau\)) scales as:

\[f_{rebuild} \propto \frac{\sqrt{T}}{\text{skin}}\]

In high-temperature systems, the rebuild frequency becomes extremely high, so the reconstruction (whose cost is set by the secondary collider, e.g. \(O(N \log N)\)) runs often. When \(f_{rebuild}\), measured in rebuilds per timestep, approaches \(1\) (rebuilding every step), the neighbor list becomes slower than direct spatial partitioning colliders because of the redundant list buffering.
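A back-of-the-envelope version of this estimate, expressed in timesteps. The equipartition relation \(v_{rms} = \sqrt{dim \cdot T / m}\) (with \(k_B = 1\)) is an assumption used to make the \(\sqrt{T}\) scaling concrete; rebuild_interval_steps is a hypothetical helper.

```python
import math


def rebuild_interval_steps(skin, temperature, dt, mass=1.0, dim=3):
    """tau / dt with tau ~ skin / (2 * v_rms) and v_rms = sqrt(dim * T / m) (assumed)."""
    v_rms = math.sqrt(dim * temperature / mass)
    return skin / (2.0 * v_rms * dt)


cold = rebuild_interval_steps(skin=0.05, temperature=1e-4, dt=1e-3)
hot = rebuild_interval_steps(skin=0.05, temperature=4e-4, dt=1e-3)
print(cold, hot)  # quadrupling T halves the interval between rebuilds
```

Once this interval drops toward one step, every step pays for a rebuild plus the cached-list traversal.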

secondary_collider: Collider

The underlying collider used to build the list via create_neighbor_list.

neighbor_list: Array

Shape (N, max_neighbors). Contains the IDs of neighboring particles, padded with -1.

old_pos: Array

Shape (N, dim). Positions of particles at the last build time.

n_build_times: Array

Counter for how many times the list has been rebuilt.

cutoff: Array

The interaction radius (force cutoff).

skin: Array

Buffer distance. The list is built with radius = cutoff + skin and rebuilt when max_displacement > skin / 2.

max_neighbors: int

Static buffer size for the neighbor list.

classmethod Create(state: State, cutoff: float, skin: float = 0.05, max_neighbors: int | None = None, number_density: float = 1.0, safety_factor: float = 1.2, secondary_collider_type: str = 'CellList', secondary_collider_kw: dict[str, Any] | None = None) → Self

Creates a NeighborList collider.

Parameters:
  • state (State) – The initial simulation state used to determine system dimensions and particle count.

  • cutoff (float) – The physical interaction cutoff radius.

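  • skin (float, default 0.05) – The buffer distance added to the cutoff for the neighbor list. Should be strictly positive: with skin = 0 the rebuild criterion fires whenever any particle moves, so the list is rebuilt almost every step.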

  • max_neighbors (int, optional) – Maximum number of neighbors to store per particle. If not provided, it is estimated from the number_density.

  • number_density (float, default 1.0) – Number density of the system used to estimate max_neighbors if not explicitly provided.

  • safety_factor (float, default 1.2) – Multiplier applied to the estimated number of neighbors to account for fluctuations in local density.

  • secondary_collider_type (str, default "CellList") – Registered collider type used internally to build the neighbor lists.

  • secondary_collider_kw (dict[str, Any], optional) – Keyword arguments passed to the constructor of the internal collider. If None, cell_size is set to cutoff + skin.
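The parameter list does not spell out how max_neighbors is estimated from number_density. One plausible estimate, shown only as a hypothetical sketch (estimate_max_neighbors is not a jaxdem function, and the library's formula may differ), multiplies the number density by the volume of a dim-dimensional ball of radius cutoff + skin and by safety_factor:

```python
import math


def estimate_max_neighbors(cutoff: float, skin: float, number_density: float = 1.0,
                           safety_factor: float = 1.2, dim: int = 3) -> int:
    """Hypothetical estimate: rho * V_ball(cutoff + skin) * safety_factor, rounded up."""
    r = cutoff + skin
    # Volume of a dim-dimensional ball: pi^(d/2) / Gamma(d/2 + 1) * r^d.
    ball_volume = math.pi ** (dim / 2) / math.gamma(dim / 2 + 1) * r ** dim
    return math.ceil(number_density * ball_volume * safety_factor)


print(estimate_max_neighbors(cutoff=1.0, skin=0.05, dim=3))
print(estimate_max_neighbors(cutoff=1.0, skin=0.05, dim=2))
```

Since such an estimate is only an expected count, local density fluctuations can still exceed it, so the overflow flag remains worth checking.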

Returns:

A configured NeighborList collider instance.

Return type:

NeighborList

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Returns the current neighbor list from this collider.

This method refreshes the cached list when it has not been built yet or when any particle has moved farther than half the skin distance from the last build position. Otherwise it returns the cached neighbor_list and overflow flag stored in the collider.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Ignored; the collider’s configured cutoff is used.

  • max_neighbors (int) – Ignored; the collider’s configured buffer size is used.

Returns:

A tuple containing:

  • state: The simulation state.

  • system: The simulation system.

  • neighbor_list: The cached neighbor list of shape (N, max_neighbors).

  • overflow: Boolean flag indicating if the list overflowed during the last build.

Return type:

Tuple[State, System, jax.Array, jax.Array]

Notes

  • The returned neighbor indices refer to the internal particle ordering established during the most recent rebuild inside compute_force.

static compute_force(state: State, system: System) → tuple[State, System]

Computes total forces acting on each particle, rebuilding the neighbor list if necessary.

This method checks if any particle has moved enough to trigger a rebuild (displacement > skin/2). If so, it invokes the internal spatial partitioner to refresh the neighbor list. It then sums force contributions using the cached list.
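Under jax.jit, the "if so" step cannot be a Python if on a traced value. A common JAX pattern is jax.lax.cond, sketched here with a user-supplied build_fn standing in for the secondary collider. maybe_rebuild and build_fn are hypothetical names, and jaxdem's internal control flow may differ. Both branches must return arrays with identical shapes and dtypes, which is one reason the cached list has a static max_neighbors width.

```python
import jax
import jax.numpy as jnp


def maybe_rebuild(pos, old_pos, neighbor_list, overflow, skin, build_fn):
    """Rebuild the cached list only if the max displacement exceeds skin / 2.

    build_fn(pos) -> (neighbor_list, overflow) is a hypothetical stand-in for
    the secondary collider; it must match the cache's shapes and dtypes.
    Returns (neighbor_list, overflow, reference_positions).
    """
    max_disp = jnp.max(jnp.linalg.norm(pos - old_pos, axis=-1))

    def rebuild(carry):
        p, _, _, _ = carry
        new_list, new_overflow = build_fn(p)
        return new_list, new_overflow, p  # reference positions reset to now

    def keep(carry):
        _, old, nl, ovf = carry
        return nl, ovf, old  # reuse the cached list unchanged

    return jax.lax.cond(max_disp > 0.5 * skin, rebuild, keep,
                        (pos, old_pos, neighbor_list, overflow))
```

When the function is jitted, build_fn should be marked static (for example via static_argnames), since it is a Python callable rather than an array.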

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object with computed forces and the updated System object (with refreshed collider cache).

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the total potential energy of the system by summing pair contributions over the cached neighbor list.

This method iterates over the cached neighbors for each particle and sums the potential energy contributions computed by the system.force_model.
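The masked reduction over a -1-padded list can be sketched as follows. The linear-spring overlap energy, the stiffness k, and potential_energy_from_list are hypothetical stand-ins for system.force_model. The sketch assumes a full (symmetric) list in which every pair appears in both rows, hence the final factor 1/2, and non-periodic boundaries. JAX wraps negative indices (index -1 selects the last particle), so padded slots are redirected to a valid index and then masked out.

```python
import jax
import jax.numpy as jnp


def potential_energy_from_list(pos, rad, neighbor_list, k=1.0):
    """Sum a pairwise overlap energy over a -1-padded (N, max_neighbors) list."""
    valid = neighbor_list >= 0                   # padded slots hold -1
    j = jnp.where(valid, neighbor_list, 0)       # safe gather index for padding
    r_ij = pos[:, None, :] - pos[j]              # (N, M, dim)
    dist = jnp.linalg.norm(r_ij, axis=-1)        # (N, M)
    overlap = jnp.maximum(rad[:, None] + rad[j] - dist, 0.0)
    e_pair = 0.5 * k * overlap**2                # hypothetical spring energy
    # Each pair is stored in both rows of a full list, so halve the total.
    return 0.5 * jnp.sum(jnp.where(valid, e_pair, 0.0))
```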

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

Scalar containing the total potential energy of the system.

Return type:

jax.Array

static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]

Build a cross-neighbor list between two sets of positions.

Delegates to the internal secondary_collider’s create_cross_neighbor_list method.

Parameters:
  • pos_a (jax.Array) – Query positions, shape (N_A, dim).

  • pos_b (jax.Array) – Database positions, shape (N_B, dim).

  • system (System) – The configuration of the simulation (used for domain displacement).

  • cutoff (float) – Search radius.

  • max_neighbors (int) – Maximum number of neighbors to store per query point.

Returns:

A tuple containing:

  • neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.

  • overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.

Return type:

Tuple[jax.Array, jax.Array]
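The output contract (query-major layout, -1 padding, overflow flag) can be illustrated with an \(O(N_A \cdot N_B)\) brute-force sketch. brute_force_cross_neighbor_list is hypothetical: it ignores periodic boundaries (so it takes no system) and is not how the secondary collider builds the list.

```python
import jax.numpy as jnp


def brute_force_cross_neighbor_list(pos_a, pos_b, cutoff, max_neighbors):
    """Return (neighbor_list, overflow) with indices into pos_b, padded with -1."""
    n_b = pos_b.shape[0]
    # Distances between every query point and every database point: (N_A, N_B).
    dist = jnp.linalg.norm(pos_a[:, None, :] - pos_b[None, :, :], axis=-1)
    within = dist < cutoff
    # Out-of-range candidates get the sentinel n_b; sorting moves hits first.
    cand = jnp.sort(jnp.where(within, jnp.arange(n_b)[None, :], n_b), axis=1)
    if max_neighbors > n_b:  # static shapes: pad to exactly max_neighbors columns
        cand = jnp.pad(cand, ((0, 0), (0, max_neighbors - n_b)), constant_values=n_b)
    cand = cand[:, :max_neighbors]
    overflow = jnp.any(jnp.sum(within, axis=1) > max_neighbors)
    neighbor_list = jnp.where(cand == n_b, -1, cand).astype(jnp.int32)
    return neighbor_list, overflow
```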

class jaxdem.colliders.SweepAndPrune(K: int, *, overflow: Array = <factory>)

Bases: Collider

PCA-aligned 1D slab partitioning Sweep and Prune with shifted multi-pass approach.

This collider implements a variant of the Sweep and Prune algorithm based on the paper:

Real-time Collision Culling of a Million Bodies on Graphics Processing Units, Liu et al. (2011)

Mathematical Formalism & Slab Offsets

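In \(dim\) dimensions, the space is partitioned into slabs of width \(bin\_size = 4 r_{max}\) along each of the \(dim-1\) axes perpendicular to the sweeping direction. To avoid missing pairs that straddle a slab boundary, the algorithm performs \(2^{dim-1}\) parallel passes, one for each combination of a perpendicular coordinate shift of \(0.0\) or \(2 r_{max}\) per axis. Because two particles in contact are at most \(2 r_{max}\) apart along every perpendicular axis, at least one of the two shifts places them in the same slab along that axis, so every contacting pair shares a cell in at least one pass.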

In each pass \(p\), a particle’s perpendicular coordinates are mapped to a 1D cell index \(\text{HASH}_p\). To sweep all slabs in parallel without dynamic coordinate queries, the coordinates along the principal (sweeping) axis, \(x_{proj}\), are offset by:

\[x_{proj, shifted} = x_{proj} + \text{HASH}_p \cdot L_{proj}\]

where \(L_{proj} = L_{box} + 2 \cdot \text{cutoff} + 1.0\) is a spacing buffer ensuring that slabs are arranged end-to-end along a single concatenated line with no overlap. Particles are then sorted along this concatenated line, grouping all particles in the same slab together.

For each particle \(i\), we query a static window of \(2K\) neighboring indices in the sorted array (\(K\) to the left and \(K\) to the right). Interactions are evaluated only if both particles reside in the same slab and pass a canonical deduplication check, so that a pair sharing a slab in several passes is evaluated only once.
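To make the hashing, the shifted sort key, and the static window concrete, here is a minimal 2D sketch under several assumptions: the sweep axis is simply x (no PCA alignment), the perpendicular axis y is cut into slabs of width \(4 r_{max}\), positions lie in [0, L_box) with no periodic wrapping, and two particles are in contact when their distance is below \(r_i + r_j\). sap_pass_2d and contact_pairs are hypothetical helpers, and the Python set in contact_pairs merely stands in for the canonical deduplication check.

```python
import jax
import jax.numpy as jnp
import numpy as np


def sap_pass_2d(pos, rad, shift, K, L_box):
    """One shifted pass; returns (i, j, valid), each of shape (N, 2K)."""
    n = pos.shape[0]
    r_max = jnp.max(rad)
    width = 4.0 * r_max                     # slab width W = 4 r_max
    cutoff = 2.0 * r_max                    # largest possible contact distance
    cell = jnp.floor((pos[:, 1] + shift) / width).astype(jnp.int32)  # HASH_p
    l_proj = L_box + 2.0 * cutoff + 1.0     # spacing that keeps slabs apart
    key = pos[:, 0] + cell * l_proj         # x_proj + HASH_p * L_proj
    order = jnp.argsort(key)                # sorted slot -> particle index

    # Static window: K sorted neighbours on each side of every slot.
    offsets = jnp.concatenate([jnp.arange(-K, 0), jnp.arange(1, K + 1)])
    slot = jnp.arange(n)[:, None] + offsets[None, :]
    in_range = (slot >= 0) & (slot < n)
    i = jnp.broadcast_to(order[:, None], slot.shape)
    j = order[jnp.clip(slot, 0, n - 1)]

    same_cell = cell[i] == cell[j]
    dist = jnp.linalg.norm(pos[i] - pos[j], axis=-1)
    touching = dist < rad[i] + rad[j]
    return i, j, in_range & same_cell & touching


def contact_pairs(pos, rad, K, L_box):
    """Union of the two passes (shift 0 and 2 r_max) as a set of (i, j), i < j."""
    r_max = float(jnp.max(rad))
    pairs = set()
    for shift in (0.0, 2.0 * r_max):
        i, j, valid = (np.asarray(a) for a in sap_pass_2d(pos, rad, shift, K, L_box))
        for a, b in zip(i[valid], j[valid]):
            pairs.add((min(int(a), int(b)), max(int(a), int(b))))
    return pairs
```

With K at least N the window spans the whole sorted array, so the union of the two passes should reproduce every brute-force contact; with a smaller K, contacts can be missed whenever a slab is more crowded than the window, which is the overflow risk discussed in this section. Under jax.jit, K must be a static argument because it fixes the window shape.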

Runtime and Cost Analysis

Since the search window \(2K\) is a static compilation parameter, the number of candidate checks performed per particle is fixed. The total pair evaluation cost scales as:

\[\text{cost} \approx N \cdot 2^{dim-1} \cdot 2K\]
  • Advantages: Unlike cell lists, the number of distance evaluations per particle is constant and independent of particle size, polydispersity, packing fraction, or cell size.

    Sorting costs \(O(2^{dim-1} \cdot N \log N)\) per time step, one sort per pass.

  • Polydispersity Penalty and Window Overflow: The collider only avoids missed contacts if the window size \(K\) is large enough to cover all particles that can overlap a given particle within its projected slab. Let \(W = 4r_{max}\) be the slab width. The expected number of particles in a search region of cross-section \(W^{dim-1}\) and length \(2r_{max}\) along the sweeping axis (one side of the window) is:

    \[\lambda = \rho (2 \cdot W^{dim-1} \cdot r_{max}) = \frac{2 \cdot 4^{dim-1}}{k_v} \cdot \phi \cdot \frac{V_{max}}{\langle V \rangle}\]

    where \(\phi\) is the volume fraction, \(k_v\) is the geometric shape factor relating a particle's volume to its radius (\(V = k_v r^{dim}\), e.g. \(k_v = 4\pi/3\) for spheres), and \(V_{max} / \langle V \rangle\) is the ratio of the largest to the average particle volume. For highly polydisperse systems with \(\alpha = r_{max}/r_{min} \gg 1\), this volume ratio can grow as large as \(O(\alpha^{dim})\), so a large window size \(K\) is needed to guarantee correctness, which increases the constant overhead of the search loops.

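    Because space is partitioned into slabs, however, the required window \(K\) does not grow with system size, unlike in the standard Sweep and Prune algorithm, where the number of particles overlapping a given particle along the sweeping axis grows with the cross-sectional area of the domain.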
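The occupancy estimate \(\lambda\) and the cost formula \(N \cdot 2^{dim-1} \cdot 2K\) can be evaluated directly. The helper names below are hypothetical, and \(\lambda\) is only a mean count, so \(K\) needs some margin above it to absorb local density fluctuations.

```python
import math


def expected_window_occupancy(phi: float, vol_ratio: float, dim: int, k_v: float) -> float:
    """lambda = 2 * 4**(dim - 1) / k_v * phi * V_max / <V> (one side of the window)."""
    return 2.0 * 4 ** (dim - 1) / k_v * phi * vol_ratio


def pair_checks(n: int, dim: int, K: int) -> int:
    """Candidate checks per step: N * 2**(dim - 1) * 2K."""
    return n * 2 ** (dim - 1) * 2 * K


k_sphere = 4.0 * math.pi / 3.0  # V = k_v r^3 for spheres
# Monodisperse spheres near random close packing (phi ~ 0.64).
lam_mono = expected_window_occupancy(0.64, 1.0, 3, k_sphere)
# Volume ratio of 8: the upper bound alpha^3 for alpha = 2 in 3D.
lam_poly = expected_window_occupancy(0.64, 8.0, 3, k_sphere)
print(f"lambda (mono) = {lam_mono:.2f}, lambda (poly) = {lam_poly:.2f}")
print(f"checks for N = 1e6, K = 8: {pair_checks(10**6, 3, 8):,}")
```

For the monodisperse case the estimate is about 4.9, below the default K = 8; a volume ratio of 8 raises it to about 39, so K would have to grow roughly in proportion.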

Constructor Parameters

  • K: The static search window size (number of sorted neighbors to check in each direction). A larger K avoids candidate window overflow warnings in clustered regions but increases execution overhead. Default is 8.

This collider is well suited to systems with moderate to high polydispersity. It also handles systems of overlapping rigid clumps, at the cost of a larger \(K\).

Complexity

  • Time: \(O(2^{dim-1} \cdot N \log N)\) sorting, plus \(O(2^{dim} \cdot N \cdot K)\) traversal.

  • Memory: \(O(2^{dim} \cdot N \cdot K)\) to store candidate indices and masks.

K: int

The static search window radius (number of sorted neighbors to check in each direction).

classmethod Create(state: State, K: int | None = None) → Self

Creates a SweepAndPrune instance based on the reference state.

Parameters:
  • state (State) – Reference state containing positions and radii.

  • K (int, optional) – Static search window radius size.

Returns:

A configured SweepAndPrune instance.

Return type:

SweepAndPrune

static compute_force(state: State, system: System) → tuple[State, System]

Computes pairwise contact forces and torques using the shifted multi-pass Sweep and Prune.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated state and unmodified system.

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the total non-bonded potential energy of the system.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

Scalar potential energy.

Return type:

jax.Array

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Creates a neighbor list of shape (N, max_neighbors) using the shifted multi-pass Sweep and Prune.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Verlet search cutoff radius.

  • max_neighbors (int) – Static size of neighbor buffer per particle.

Returns:

Unmodified state, system, neighbor list, and overflow flag.

Return type:

Tuple[State, System, jax.Array, jax.Array]

static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]

Creates a cross-neighbor list between pos_a (query) and pos_b (database).

Parameters:
  • pos_a (jax.Array) – Query positions, shape (N_A, dim).

  • pos_b (jax.Array) – Database positions, shape (N_B, dim).

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Verlet search cutoff radius.

  • max_neighbors (int) – Static size of neighbor buffer per particle.

Returns:

Cross-neighbor list of shape (N_A, max_neighbors) and overflow flag.

Return type:

Tuple[jax.Array, jax.Array]

jaxdem.colliders.valid_interaction_mask(clump_i: Array, clump_j: Array, bond_id_i: Array, unique_id_j: Array) → Array

Pair mask shared by all colliders.

Interactions are always disabled for particles in the same clump. Interactions for particles connected by a bond are also disabled.
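A minimal sketch of a mask with the described behavior, under assumptions inferred from the parameter names only: each particle carries a clump ID, bond_id_i holds the unique ID of the particle that i is bonded to (or -1 if unbonded), and True means the pair may interact. sketch_interaction_mask is hypothetical, and the actual jaxdem function may interpret its arguments differently.

```python
import jax.numpy as jnp


def sketch_interaction_mask(clump_i, clump_j, bond_id_i, unique_id_j):
    """Hypothetical pair mask: True where particles i and j may interact."""
    same_clump = clump_i == clump_j      # members of one rigid clump never interact
    bonded = bond_id_i == unique_id_j    # assumption: i's bond partner is j
    return ~same_clump & ~bonded


clump = jnp.array([0, 0, 1, 2])
uid = jnp.array([10, 11, 12, 13])
bond = jnp.array([-1, -1, 13, 12])      # particles 2 and 3 are bonded
# Broadcasting the per-particle arrays yields the full (N, N) pair mask.
mask = sketch_interaction_mask(clump[:, None], clump[None, :], bond[:, None], uid[None, :])
```

Since a particle always shares its own clump, this form also masks out the self-pair i = j.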

Modules

cell_list

Cell List \(O(N \log N)\) collider implementation.

multi_cell_list

Multi-Cell List collider implementations.

naive

Naive \(O(N^2)\) collider implementation.

neighbor_list

Neighbor List Collider implementation.

sweep_and_prune

Sweep and prune \(O(N \log N)\) collider implementations.