jaxdem.colliders.neighbor_list

Neighbor List Collider implementation.

Classes

NeighborList(secondary_collider, ...)

Implementation of a Verlet neighbor list collider.

class jaxdem.colliders.neighbor_list.NeighborList(secondary_collider: Collider, neighbor_list: Array, old_pos: Array, n_build_times: Array, cutoff: Array, skin: Array, max_neighbors: int, *, overflow: Array = <factory>)

Bases: Collider

Implementation of a Verlet neighbor list collider.

Verlet neighbor lists cache candidate interaction pairs across multiple simulation timesteps. A full spatial partitioning query (sorting and slab/cell hashing) is then needed only when the list is rebuilt rather than at every timestep, which substantially reduces contact detection overhead in slowly evolving systems.

Mathematical Formalism & Rebuild Criteria

The neighbor list is built with a search radius that extends the interaction cutoff by a buffer distance known as the skin:

\[r_{search} = \text{cutoff} + \text{skin}\]

Let \(\mathbf{x}_i^0\) represent the position of particle \(i\) at the time of the last neighbor list rebuild. At any subsequent timestep, the displacement of particle \(i\) from its reference position is:

\[\Delta \mathbf{x}_i = \mathbf{x}_i - \mathbf{x}_i^0\]

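Let \(d_{ij} = \|\mathbf{x}_i - \mathbf{x}_j\|\) and \(d_{ij}^0 = \|\mathbf{x}_i^0 - \mathbf{x}_j^0\|\). By the triangle inequality, the change in distance between any two particles \(i\) and \(j\) since the last rebuild is bounded by: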

\[|d_{ij} - d_{ij}^0| \le \|\Delta \mathbf{x}_i\| + \|\Delta \mathbf{x}_j\| \le 2 \max_{k} \|\Delta \mathbf{x}_k\|\]

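A pair missing from the list had \(d_{ij}^0 > \text{cutoff} + \text{skin}\), so it can come within \(\text{cutoff}\) only if its distance has shrunk by more than \(\text{skin}\), which by the bound above requires \(2 \max_k \|\Delta \mathbf{x}_k\| > \text{skin}\). To guarantee that no pair enters the interaction range without being captured in the neighbor list, a rebuild is therefore triggered as soon as: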

\[\max_{k} \|\Delta \mathbf{x}_k\| > \frac{\text{skin}}{2}\]
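The criterion amounts to a single reduction over particle displacements. The NumPy sketch below assumes free (non-periodic) boundaries, so no minimum-image correction is applied; needs_rebuild is a hypothetical helper for illustration, not part of the jaxdem API. Under jax.jit one would keep the comparison as a traced boolean (e.g., for jax.lax.cond) instead of calling bool().

```python
import numpy as np


def needs_rebuild(pos: np.ndarray, old_pos: np.ndarray, skin: float) -> bool:
    """Return True if any particle moved more than skin / 2 since the last build.

    Assumes non-periodic boundaries (no minimum-image correction).
    """
    # ||Δx_k|| for every particle k, shape (N,)
    disp = np.linalg.norm(pos - old_pos, axis=-1)
    return bool(disp.max() > 0.5 * skin)


old_pos = np.zeros((3, 2))
pos = old_pos.copy()
pos[1] = [0.02, 0.0]                       # largest displacement 0.02 <= 0.025
print(needs_rebuild(pos, old_pos, 0.05))   # False
pos[2] = [0.0, 0.03]                       # largest displacement 0.03 > 0.025
print(needs_rebuild(pos, old_pos, 0.05))   # True
```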

Runtime and Cost Analysis

The computational cost of simulations using neighbor lists consists of two parts:

  1. Rebuild Cost: Incurred only when the maximum displacement threshold is exceeded. The spatial query is delegated to a secondary collider, which can be any registered collider (e.g., NaiveSimulator, DynamicCellList, DynamicMultiCellList, or SweepAndPruneShifted). The complexity of a rebuild is therefore that of the chosen collider (e.g., \(O(N^2)\) for NaiveSimulator, \(O(N \log N)\) for DynamicCellList/DynamicMultiCellList, or \(O(N)\) for SweepAndPruneShifted).

  2. Step Evaluation Cost: Incurred at every timestep. Forces are evaluated by iterating over the fixed-size cached neighbor buffer (max_neighbors slots per particle, padding included), in addition to an \(O(N)\) displacement check:

    \[\text{cost}_{step} \approx N \cdot \text{max\_neighbors}\]
  • Estimating Buffer Size: The size of the neighbor buffer max_neighbors is estimated based on the search volume and number density:

    \[\text{max\_neighbors} \approx \gamma \cdot \rho \cdot V_{search}\]

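    where \(\gamma\) is a safety factor (safety_factor, default 1.2), \(\rho = N / V_{domain} = \phi / \langle V \rangle\) is the macroscopic number density (number_density, default 1.0), with \(\phi\) the packing fraction and \(\langle V \rangle\) the mean particle volume (area in 2D), and \(V_{search}\) is the volume of the search sphere of radius \(r_{search}\):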

    \[\begin{split}V_{search} = \begin{cases} \pi r_{search}^2 & \text{in 2D} \\ \frac{4}{3}\pi r_{search}^3 & \text{in 3D} \end{cases}\end{split}\]

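    Typically, a skin of \(0.1 \text{ to } 0.4\) times the particle diameter balances rebuild frequency against per-step cost. The default skin of 0.05 is expressed in the simulation's length units, not as a fraction of the diameter.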
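The buffer-size heuristic and the resulting per-step cost can be sketched as a small function. estimate_max_neighbors is a hypothetical name, and rounding up with math.ceil is an assumption of this sketch; the library may round or clamp the estimate differently.

```python
import math


def estimate_max_neighbors(cutoff: float, skin: float, number_density: float,
                           safety_factor: float = 1.2, dim: int = 3) -> int:
    """Heuristic buffer size gamma * rho * V_search, rounded up (rounding is assumed)."""
    r_search = cutoff + skin
    if dim == 2:
        v_search = math.pi * r_search**2              # area of the search disk
    elif dim == 3:
        v_search = 4.0 / 3.0 * math.pi * r_search**3  # volume of the search sphere
    else:
        raise ValueError("dim must be 2 or 3")
    return math.ceil(safety_factor * number_density * v_search)


n_particles = 10_000
max_nb = estimate_max_neighbors(cutoff=1.0, skin=0.2, number_density=1.0, dim=3)
print(max_nb)                 # 9  (1.2 * 4/3 * pi * 1.2**3 is about 8.69, rounded up)
print(n_particles * max_nb)   # 90000 neighbor slots visited per step
```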

Constructor Parameters

  • cutoff: The physical contact interaction range. The search volume grows as \(r_{search}^d\) (quadratically in 2D, cubically in 3D), so larger cutoffs enlarge the neighbor buffer accordingly.

  • skin: The buffer distance. Default is 0.05. Larger skin reduces rebuild frequency but inflates max_neighbors, increasing step time and memory.

  • max_neighbors: The static neighbor buffer size per particle. If not provided, it is estimated from number_density and safety_factor using the heuristic above. Setting it too small causes the list to overflow (reported through the overflow flag), while setting it too large wastes GPU memory and per-step work.

  • number_density: Macroscopic number density used to estimate neighbor counts when not provided. Default is 1.0.

  • safety_factor: Multiplier applied to the estimated neighbor count to account for local density fluctuations. Default is 1.2.

  • secondary_collider_type: The identifier of the underlying collider used to execute the spatial queries during rebuilds (e.g. "CellList", "sap_shifted", "naive", or "DynamicMultiCellList"). Any registered Collider subclass in the library can be used for the rebuild phase, allowing the rebuild cost to be optimized based on system characteristics.

  • secondary_collider_kw: Keyword arguments passed to the underlying collider’s constructor.

This collider is suited to dense assemblies, static packings, slow shear flows, gravitational settling, and other low-velocity systems. It is less suited to high-speed granular flows or high-temperature systems, where rapid particle motion triggers frequent neighbor list rebuilds and cancels the benefit of caching. In addition, systems of rigid clumps with large overlaps need larger neighbor buffers to accommodate excluded constituent pairs, which increases both the memory footprint and the per-step traversal cost.

Temperature & Rebuild Frequency Discussion

In particle systems (analogous to molecular dynamics), the “temperature” \(T\) is proportional to the mean squared velocity (kinetic energy) of the particles:

\[\langle v^2 \rangle \sim T \implies v_{rms} \propto \sqrt{T}\]

The rebuild criterion is triggered when the maximum particle displacement exceeds half the skin distance:

\[\max_k \|\Delta \mathbf{x}_k\| > \frac{\text{skin}}{2}\]

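Assuming ballistic motion, \(\|\Delta \mathbf{x}\| \approx v \cdot t\), and taking \(v \approx v_{rms}\), the average time interval between rebuilds \(\tau\) can be estimated as follows (the fastest particles reach the threshold first, so the actual interval is somewhat shorter):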

\[\tau \approx \frac{\text{skin}}{2 \cdot v_{rms}} \propto \frac{\text{skin}}{\sqrt{T}}\]

As a result, the rebuild frequency (\(f_{rebuild} = 1/\tau\)) scales as:

\[f_{rebuild} \propto \frac{\sqrt{T}}{\text{skin}}\]

In high-temperature systems, rebuilds become very frequent, and each one costs a full query of the secondary collider (e.g., \(O(N \log N)\) for cell lists). When \(\tau\) approaches the timestep \(\Delta t\) (i.e., \(f_{rebuild}\,\Delta t \to 1\), a rebuild at every step), the neighbor list becomes slower than using a spatial partitioning collider directly, because every step pays for the full spatial query plus the extra cost of writing and then traversing the neighbor buffer.
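A rough worked example of these scaling relations, assuming ballistic motion at \(v_{rms}\) and converting \(\tau\) into a number of timesteps. rebuild_interval_steps is a hypothetical helper and the parameter values are illustrative only.

```python
import math


def rebuild_interval_steps(skin: float, v_rms: float, dt: float) -> float:
    """Estimated timesteps between rebuilds: tau / dt with tau = skin / (2 * v_rms)."""
    tau = skin / (2.0 * v_rms)
    return tau / dt


dt = 1e-3
base = rebuild_interval_steps(skin=0.05, v_rms=0.01, dt=dt)   # about 2500 steps
# Quadrupling T doubles v_rms (v_rms ∝ sqrt(T)), which halves the interval.
hot = rebuild_interval_steps(skin=0.05, v_rms=0.01 * math.sqrt(4.0), dt=dt)
print(round(base), round(hot))   # 2500 1250
# Once the interval falls to about one step, the list is rebuilt every step.
```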

secondary_collider: Collider

The underlying collider used to build the list via create_neighbor_list.

neighbor_list: Array

Shape (N, max_neighbors). Contains the IDs of neighboring particles, padded with -1.

old_pos: Array

Shape (N, dim). Positions of particles at the last build time.

n_build_times: Array

Counter for how many times the list has been rebuilt.

cutoff: Array

The interaction radius (force cutoff).

skin: Array

Buffer distance. The list is built with radius = cutoff + skin and rebuilt when max_displacement > skin / 2.

max_neighbors: int

Static buffer size for the neighbor list.

classmethod Create(state: State, cutoff: float, skin: float = 0.05, max_neighbors: int | None = None, number_density: float = 1.0, safety_factor: float = 1.2, secondary_collider_type: str = 'CellList', secondary_collider_kw: dict[str, Any] | None = None) → Self

Creates a NeighborList collider.

Parameters:
  • state (State) – The initial simulation state used to determine system dimensions and particle count.

  • cutoff (float) – The physical interaction cutoff radius.

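  • skin (float, default 0.05) – The buffer distance added to the cutoff for the neighbor list. Must be > 0.0 for performance: with skin = 0, any particle motion satisfies the rebuild criterion, so the list is rebuilt at every step.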

  • max_neighbors (int, optional) – Maximum number of neighbors to store per particle. If not provided, it is estimated from number_density, safety_factor, and the search volume.

  • number_density (float, default 1.0) – Number density of the system used to estimate max_neighbors if not explicitly provided.

  • safety_factor (float, default 1.2) – Multiplier applied to the estimated number of neighbors to account for fluctuations in local density.

  • secondary_collider_type (str, default "CellList") – Registered collider type used internally to build the neighbor lists.

  • secondary_collider_kw (dict[str, Any], optional) – Keyword arguments passed to the constructor of the internal collider. If None, cell_size is set to cutoff + skin.

Returns:

A configured NeighborList collider instance.

Return type:

NeighborList

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Returns the current neighbor list from this collider.

This method refreshes the cached list when it has not been built yet or when any particle has moved farther than half the skin distance from the last build position. Otherwise it returns the cached neighbor_list and overflow flag stored in the collider.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Ignored; the collider’s configured cutoff is used.

  • max_neighbors (int) – Ignored; the collider’s configured buffer size is used.

Returns:

A tuple containing:

  • state: The simulation state.

  • system: The simulation system.

  • neighbor_list: The cached neighbor list of shape (N, max_neighbors).

  • overflow: Boolean flag indicating if the list overflowed during the last build.

Return type:

Tuple[State, System, jax.Array, jax.Array]

Notes

  • The returned neighbor indices refer to the internal particle ordering established during the most recent rebuild inside compute_force.
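The cache-or-rebuild behavior described for create_neighbor_list can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the library's implementation: build_neighbor_list and get_neighbor_list are hypothetical names, the build is a brute-force \(O(N^2)\) search standing in for the secondary collider, boundaries are non-periodic, particles are not reordered, and neighbors beyond max_neighbors are dropped while the overflow flag is set.

```python
import numpy as np


def build_neighbor_list(pos, r_search, max_neighbors):
    """Brute-force O(N^2) build: for each i, indices j != i with |x_i - x_j| <= r_search.

    Rows are padded with -1; neighbors beyond max_neighbors are dropped and flagged.
    """
    n = pos.shape[0]
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    within = (dist <= r_search) & ~np.eye(n, dtype=bool)   # exclude self-pairs
    neighbor_list = np.full((n, max_neighbors), -1, dtype=np.int64)
    overflow = False
    for i in range(n):
        idx = np.flatnonzero(within[i])
        overflow = overflow or idx.size > max_neighbors
        k = min(idx.size, max_neighbors)
        neighbor_list[i, :k] = idx[:k]
    return neighbor_list, overflow


def get_neighbor_list(cache, pos, cutoff, skin, max_neighbors):
    """Return (cache, neighbor_list, overflow), rebuilding only when required.

    cache is None before the first build, otherwise a dict holding the last build.
    """
    if cache is None:
        needs_build, n_builds = True, 0
    else:
        max_disp = np.linalg.norm(pos - cache["old_pos"], axis=-1).max()
        needs_build, n_builds = bool(max_disp > 0.5 * skin), cache["n_build_times"]
    if needs_build:
        neighbor_list, overflow = build_neighbor_list(pos, cutoff + skin, max_neighbors)
        cache = {
            "neighbor_list": neighbor_list,
            "overflow": overflow,
            "old_pos": pos.copy(),          # reference positions x_i^0
            "n_build_times": n_builds + 1,
        }
    return cache, cache["neighbor_list"], cache["overflow"]


pos = np.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]])
cache, nl, overflow = get_neighbor_list(None, pos, cutoff=1.0, skin=0.2, max_neighbors=2)
print(nl.tolist(), overflow)   # [[1, -1], [0, -1], [-1, -1]] False
```

A JAX version would need fixed array shapes throughout and jax.lax.cond in place of the Python branch, since the rebuild decision depends on traced values.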

static compute_force(state: State, system: System) → tuple[State, System]

Computes total forces acting on each particle, rebuilding the neighbor list if necessary.

This method checks if any particle has moved enough to trigger a rebuild (displacement > skin/2). If so, it invokes the internal spatial partitioner to refresh the neighbor list. It then sums force contributions using the cached list.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object with computed forces and the updated System object (with refreshed collider cache).

Return type:

Tuple[State, System]
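Summing forces over a padded, fixed-width list reduces to a gather followed by a mask. The sketch below assumes a full (symmetric) neighbor list, spherical particles, and a hypothetical linear repulsive spring \(F = k\,\delta\) along the contact normal standing in for system.force_model; compute_forces is not a library function.

```python
import numpy as np


def compute_forces(pos, rad, neighbor_list, stiffness=1.0):
    """Sum pairwise forces over a padded neighbor list (padding entries are -1)."""
    valid = neighbor_list >= 0                    # mask out padding slots
    j = np.where(valid, neighbor_list, 0)         # safe gather index for padded slots
    rij = pos[:, None, :] - pos[j]                # (N, M, dim): vector from j to i
    dist = np.linalg.norm(rij, axis=-1)           # (N, M)
    overlap = rad[:, None] + rad[j] - dist        # > 0 when the particles overlap
    active = valid & (overlap > 0.0)
    # Unit normal; the guard avoids division by zero for padded slots.
    n_hat = rij / np.where(dist > 0.0, dist, 1.0)[..., None]
    f_pair = np.where(active, stiffness * overlap, 0.0)[..., None] * n_hat
    return f_pair.sum(axis=1)                     # (N, dim)


pos = np.array([[0.0, 0.0], [0.8, 0.0]])
rad = np.array([0.5, 0.5])
nl = np.array([[1, -1], [0, -1]])
print(compute_forces(pos, rad, nl))   # approximately [[-0.2, 0.0], [0.2, 0.0]]
```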

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the potential energy associated with each particle using the cached neighbor list.

This method iterates over the cached neighbors for each particle and sums the potential energy contributions computed by the system.force_model.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

Scalar containing the total potential energy of the system.

Return type:

jax.Array
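The energy sum follows the same gather-and-mask pattern. This sketch assumes a full list in which every pair appears twice, hence the factor 1/2 when attributing pair energy to particles, and a hypothetical harmonic contact potential \(U = \tfrac{1}{2} k \delta^2\) in place of system.force_model. potential_energy is a hypothetical helper that returns the per-particle contributions together with their sum, the total potential energy.

```python
import numpy as np


def potential_energy(pos, rad, neighbor_list, stiffness=1.0):
    """Return (per-particle energy, total energy) from a padded neighbor list."""
    valid = neighbor_list >= 0
    j = np.where(valid, neighbor_list, 0)                   # safe gather index
    dist = np.linalg.norm(pos[:, None, :] - pos[j], axis=-1)
    overlap = np.where(valid, rad[:, None] + rad[j] - dist, 0.0)
    u_pair = 0.5 * stiffness * np.clip(overlap, 0.0, None) ** 2
    per_particle = 0.5 * u_pair.sum(axis=1)                 # each pair is listed twice
    return per_particle, per_particle.sum()


pos = np.array([[0.0, 0.0], [0.8, 0.0]])
rad = np.array([0.5, 0.5])
nl = np.array([[1, -1], [0, -1]])
per_particle, total = potential_energy(pos, rad, nl)
print(round(float(total), 6))   # 0.02
```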

static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]

Build a cross-neighbor list between two sets of positions.

Delegates to the internal secondary_collider’s create_cross_neighbor_list method.

Parameters:
  • pos_a (jax.Array) – Query positions, shape (N_A, dim).

  • pos_b (jax.Array) – Database positions, shape (N_B, dim).

  • system (System) – The configuration of the simulation (used for domain displacement).

  • cutoff (float) – Search radius.

  • max_neighbors (int) – Maximum number of neighbors to store per query point.

Returns:

A tuple containing:

  • neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.

  • overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.

Return type:

Tuple[jax.Array, jax.Array]
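A brute-force stand-in for the cross-neighbor query can help check the expected output format (indices into pos_b, -1 padding, overflow flag). cross_neighbor_list is a hypothetical helper: it ignores the periodic displacement provided by system and scans all \(N_A \times N_B\) pairs instead of delegating to a secondary collider.

```python
import numpy as np


def cross_neighbor_list(pos_a, pos_b, cutoff, max_neighbors):
    """For each query point in pos_a, indices into pos_b within cutoff, padded with -1."""
    dist = np.linalg.norm(pos_a[:, None, :] - pos_b[None, :, :], axis=-1)  # (N_A, N_B)
    within = dist <= cutoff
    overflow = bool((within.sum(axis=1) > max_neighbors).any())
    # A stable argsort of ~within lists in-range indices first, in ascending order.
    k = min(max_neighbors, pos_b.shape[0])
    order = np.argsort(~within, axis=1, kind="stable")[:, :k]
    keep = np.take_along_axis(within, order, axis=1)
    neighbor_list = np.full((pos_a.shape[0], max_neighbors), -1, dtype=np.int64)
    neighbor_list[:, :k] = np.where(keep, order, -1)
    return neighbor_list, overflow


pos_a = np.array([[0.0, 0.0]])
pos_b = np.array([[0.5, 0.0], [2.0, 0.0], [0.9, 0.0]])
nl, overflow = cross_neighbor_list(pos_a, pos_b, cutoff=1.0, max_neighbors=3)
print(nl.tolist(), overflow)   # [[0, 2, -1]] False
```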