jaxdem.colliders.sweep_and_prune

Sweep and prune \(O(N \log N)\) collider implementations.

Classes

DummyState(pos, rad, clump_id, bond_id, ...)

SweepAndPrune(K, *, overflow)

PCA-aligned 1D slab partitioning Sweep and Prune with shifted multi-pass approach.

class jaxdem.colliders.sweep_and_prune.DummyState(pos: 'jax.Array', rad: 'jax.Array', clump_id: 'jax.Array', bond_id: 'jax.Array', unique_id: 'jax.Array', N: 'int', dim: 'int')

Bases: object

pos: Array
rad: Array
clump_id: Array
bond_id: Array
unique_id: Array
N: int
dim: int
class jaxdem.colliders.sweep_and_prune.SweepAndPrune(K: int, *, overflow: Array = <factory>)

Bases: Collider

PCA-aligned 1D slab partitioning Sweep and Prune with shifted multi-pass approach.

This collider implements a variant of the Sweep and Prune algorithm based on the paper:

Real-time Collision Culling of a Million Bodies on Graphics Processing Units, Liu et al. (2011)

Mathematical Formalism & Slab Offsets

In \(dim\) dimensions, space is partitioned into slabs of width \(bin\_size = 4 r_{max}\) along the axes perpendicular to the sweeping direction. Pairs that straddle a slab boundary must not be missed, so the algorithm performs \(2^{dim-1}\) parallel passes. In each pass, every perpendicular coordinate is shifted by either \(0\) or \(2 r_{max}\), and each pass uses a different combination of these shifts.
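A minimal sketch of the shift scheme follows. It assumes the perpendicular coordinates are already available as an array `x_perp` of shape (N, dim-1). The helpers `pass_shifts` and `slab_index` are hypothetical illustrations, not jaxdem functions. The sketch shows why two shifts per axis suffice. With offsets of \(0\) and \(2 r_{max}\), the slab boundaries of the two shifts alternate every \(2 r_{max}\). Two coordinates within \(2 r_{max}\) of each other therefore share a slab under at least one shift.

```python
import itertools

import jax.numpy as jnp


def pass_shifts(dim: int, r_max: float) -> jnp.ndarray:
    """Hypothetical helper: one perpendicular shift vector per pass, shape (2**(dim-1), dim-1)."""
    # Each of the dim - 1 perpendicular axes is shifted by either 0 or 2 * r_max.
    combos = list(itertools.product((0.0, 2.0 * r_max), repeat=dim - 1))
    return jnp.asarray(combos, dtype=jnp.float32).reshape(2 ** (dim - 1), dim - 1)


def slab_index(x_perp: jnp.ndarray, shift: jnp.ndarray, r_max: float) -> jnp.ndarray:
    """Hypothetical helper: integer slab coordinates (width 4 * r_max) for one pass."""
    return jnp.floor((x_perp + shift) / (4.0 * r_max)).astype(jnp.int32)


# Two particles 0.2 apart on a perpendicular axis, straddling the boundary at x = 4.
x_perp = jnp.array([[3.9], [4.1]])
shifts = pass_shifts(dim=2, r_max=1.0)  # rows: [0.0] and [2.0]
for s in shifts:
    print(slab_index(x_perp, s, r_max=1.0).ravel())
# pass 0 -> [0 1] (split across slabs), pass 1 -> [1 1] (same slab)
```

Enumerating every combination of the two shifts over all perpendicular axes guarantees one pass in which each nearby pair shares a slab on every axis at once.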

In each pass \(p\), a particle’s perpendicular coordinates are mapped to a 1D cell index \(\text{HASH}_p\). To sweep all slabs in parallel without dynamic coordinate queries, the coordinate along the principal (sweeping) axis, \(x_{proj}\), is offset as follows:

\[x_{proj, shifted} = x_{proj} + \text{HASH}_p \cdot L_{proj}\]

where \(L_{proj} = L_{box} + 2 \cdot \text{cutoff} + 1.0\) is a spacing buffer that lays the slabs end to end along a single concatenated line without overlap. Particles are then sorted along this line, which groups all particles in the same slab together.
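The hashing and concatenation step can be sketched as follows. The sketch assumes non-negative integer slab coordinates with `n_slabs` slabs per perpendicular axis and a projected coordinate `x_proj` in \([0, L_{box})\). `flat_hash` and `sorted_order` are hypothetical helpers, and the library may compute \(\text{HASH}_p\) differently.

```python
import jax.numpy as jnp


def flat_hash(slab_ij: jnp.ndarray, n_slabs: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: row-major flattening of (N, dim-1) slab coordinates to one id."""
    # Strides for row-major order, e.g. [n1, 1] for two perpendicular axes.
    strides = jnp.cumprod(jnp.concatenate([jnp.array([1]), n_slabs[::-1][:-1]]))[::-1]
    return (slab_ij * strides).sum(axis=1)


def sorted_order(x_proj, slab_hash, L_box: float, cutoff: float):
    """Hypothetical helper: sort particles by x_proj + HASH_p * L_proj."""
    # Stride between consecutive slabs; longer than any x_proj range (assumed in [0, L_box)).
    L_proj = L_box + 2.0 * cutoff + 1.0
    key = x_proj + slab_hash.astype(x_proj.dtype) * L_proj
    order = jnp.argsort(key)
    return order, key[order]


order, keys = sorted_order(
    jnp.array([0.5, 2.0, 0.1, 1.5]), jnp.array([1, 0, 1, 0]), L_box=3.0, cutoff=0.5
)
print(order)  # slab-0 particles (3, 1) come before slab-1 particles (2, 0)
```

Each key adds \(\text{HASH}_p \cdot L_{proj}\), so a large slab count makes the keys large. In float32 this coarsens the resolution of \(x_{proj}\). A lexicographic sort on the pair (hash, \(x_{proj}\)), for example with `jnp.lexsort`, is one way to avoid that trade-off.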

For each particle \(i\), a static window of \(2K\) neighboring indices in the sorted array is queried (\(K\) to the left and \(K\) to the right). A candidate pair is evaluated only if both particles lie in the same slab and the pair passes a canonical deduplication check. This check prevents the same pair from being computed again in a different pass.
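A minimal single-pass sketch of the windowed query follows. It assumes the particles are already sorted by their concatenated-line key, and `reach` (a hypothetical name) is the contact distance along the sweep axis, e.g. \(2 r_{max}\). The cross-pass canonical deduplication check is not reproduced here. The overflow rule is also an assumption, not necessarily the library's rule: the window is flagged when its outermost slot still holds a hit. Because same-slab particles are contiguous and sorted, such a hit means further hits may lie beyond the window.

```python
import jax.numpy as jnp


def window_candidates(sorted_key, sorted_hash, K: int, reach: float):
    """Hypothetical helper: candidates within +/-K sorted positions (one pass).

    Returns:
      idx      (N, 2K) sorted positions of candidates, clipped to [0, N-1]
      mask     (N, 2K) True where the candidate is in range, in the same slab,
               and within `reach` along the sweep axis
      overflow scalar bool, True if an outermost window slot is still a hit
    """
    n = sorted_key.shape[0]
    offsets = jnp.concatenate([jnp.arange(-K, 0), jnp.arange(1, K + 1)])  # (2K,)
    pos = jnp.arange(n)[:, None] + offsets[None, :]                       # (N, 2K)
    in_range = (pos >= 0) & (pos < n)
    idx = jnp.clip(pos, 0, n - 1)
    same_slab = sorted_hash[idx] == sorted_hash[:, None]
    close = jnp.abs(sorted_key[idx] - sorted_key[:, None]) <= reach
    mask = in_range & same_slab & close
    # Hits in the first or last column suggest K may be too small.
    overflow = jnp.any(mask[:, 0] | mask[:, -1])
    return idx, mask, overflow


key = jnp.array([1.5, 2.0, 5.1, 5.5])  # already sorted concatenated-line keys
slab = jnp.array([0, 0, 1, 1])
idx, mask, overflow = window_candidates(key, slab, K=2, reach=1.0)
print(int(mask.sum()), bool(overflow))  # 4 False: pairs (0,1) and (2,3), each seen from both sides
```

Every pair appears once from each side of the window. A force or energy reduction over this table must account for that double count.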

Runtime and Cost Analysis

Because the window size \(K\) is a static (compile-time) parameter, the number of candidate checks per particle is fixed. The total pair evaluation cost scales as:

\[\text{cost} \approx N \cdot 2^{dim-1} \cdot 2K\]
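As an illustrative calculation, take a three-dimensional system with \(N = 10^6\) particles and the default \(K = 8\). These values are chosen for the example and are not a benchmark. They give

\[\text{cost} \approx 10^6 \cdot 2^{3-1} \cdot 2 \cdot 8 = 6.4 \times 10^7\]

candidate checks per step, regardless of packing. A brute-force all-pairs search would instead examine roughly \(N^2/2 = 5 \times 10^{11}\) pairs.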
  • Advantages: Unlike with cell lists, the number of distance evaluations per particle is constant and independent of particle size, polydispersity, packing fraction, or cell size.

    Sorting complexity is \(O(2^{dim-1} \cdot N \log N)\) per time step, which is highly efficient as the state remains mostly sorted.

  • Polydispersity Penalty and Window Overflow: The collider misses no contacts only if the window size \(K\) is large enough to cover every overlapping particle in the projected slab. Let \(W = 4r_{max}\) be the slab width. The expected number of particles in a search volume of length \(2r_{max}\) along the sweep axis and cross-section \(W^{dim-1}\) is:

    \[\lambda = \rho (2 \cdot W^{dim-1} \cdot r_{max}) = \frac{2 \cdot 4^{dim-1}}{k_v} \cdot \phi \cdot \frac{V_{max}}{\langle V \rangle}\]

    where \(\phi\) is the volume fraction, \(k_v\) is the geometric shape factor, and \(V_{max} / \langle V \rangle\) is the ratio of largest to average particle volume. For highly polydisperse systems where \(\alpha = r_{max}/r_{min} \gg 1\), the volume ratio scales as \(O(\alpha^{dim})\). This requires a large window size \(K\) to guarantee correctness, which increases the constant overhead of the search loops.

    However, because space is partitioned into slabs, the required window \(K\) does not grow with system size, unlike in the standard Sweep and Prune algorithm.
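The estimate for \(\lambda\) can be evaluated directly. This sketch assumes spherical particles, so \(V = k_v r^{dim}\) with \(k_v\) the volume of the unit ball. The packing fraction \(\phi = 0.64\) and the 50:50 bidisperse mixture with \(\alpha = 2\) are illustrative choices. The function names are hypothetical. \(\lambda\) is only a mean, and local fluctuations in clustered regions exceed it, so \(K\) needs some headroom above \(\lambda\).

```python
import math


def sphere_shape_factor(dim: int) -> float:
    """k_v such that V = k_v * r**dim for a dim-dimensional ball (assumes spheres)."""
    return math.pi ** (dim / 2) / math.gamma(dim / 2 + 1)


def expected_window_load(dim: int, phi: float, vol_ratio: float, k_v: float) -> float:
    """lambda = 2 * 4**(dim-1) / k_v * phi * V_max / <V>  (the formula above)."""
    return 2.0 * 4.0 ** (dim - 1) / k_v * phi * vol_ratio


k_sphere = sphere_shape_factor(3)  # 4*pi/3
# Monodisperse spheres: V_max / <V> = 1.
print(expected_window_load(3, 0.64, 1.0, k_sphere))     # about 4.89
# 50:50 mixture of radii r and 2r: V_max / <V> = 8 / 4.5 = 16/9.
print(expected_window_load(3, 0.64, 16 / 9, k_sphere))  # about 8.69
```

Under these assumptions, doubling the size ratio already pushes the mean load above the default \(K = 8\). This illustrates the polydispersity penalty described above.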

Constructor Parameters

  • K: The static search window size (number of sorted neighbors to check in each direction). A larger K reduces the risk of candidate window overflow in clustered regions but increases execution overhead. Default is 8.

This collider is suitable for systems with moderate to high polydispersity. It also handles systems with overlapping rigid clumps, at the cost of a larger \(K\).

Complexity

  • Time: \(O(2^{dim-1} \cdot N \log N)\) sorting, plus \(O(2^{dim} \cdot N \cdot K)\) traversal.

  • Memory: \(O(2^{dim} \cdot N \cdot K)\) to store candidate indices and masks.

K: int

The static search window radius (number of sorted neighbors to check in each direction).

classmethod Create(state: State, K: int | None = None) → Self

Creates a SweepAndPrune instance based on the reference state.

Parameters:
  • state (State) – Reference state containing positions and radii.

  • K (int, optional) – Static search window radius size.

Returns:

A configured SweepAndPrune instance.

Return type:

SweepAndPrune

static compute_force(state: State, system: System) → tuple[State, System]

Computes pairwise contact forces and torques using the sweep-and-prune candidate search.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated state and unmodified system.

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the total non-bonded potential energy of the system.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

Scalar potential energy.

Return type:

jax.Array
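A minimal sketch of reducing a total pair energy from a windowed candidate table follows. It assumes a hypothetical Hookean overlap law \(U_{ij} = \tfrac{1}{2} k \max(0, r_i + r_j - d_{ij})^2\). The actual interaction law comes from the System's force model, which this sketch does not reproduce. It also assumes every pair appears twice in the table (i→j and j→i), as in a symmetric \(\pm K\) window, hence the final factor of one half.

```python
import jax.numpy as jnp


def pair_energy_from_candidates(pos, rad, cand_idx, cand_mask, stiffness=1.0):
    """Hypothetical helper: total Hookean overlap energy from an (N, M) candidate table."""
    # Distances from each particle to each of its candidates: shape (N, M).
    d = jnp.linalg.norm(pos[cand_idx] - pos[:, None, :], axis=-1)
    overlap = jnp.maximum(rad[:, None] + rad[cand_idx] - d, 0.0)
    u = 0.5 * stiffness * overlap**2
    # Masked-out slots contribute nothing; 0.5 undoes the i->j / j->i double count.
    return 0.5 * jnp.sum(jnp.where(cand_mask, u, 0.0))


pos = jnp.array([[0.0, 0.0], [1.5, 0.0]])
rad = jnp.array([1.0, 1.0])
cand_idx = jnp.array([[1], [0]])
cand_mask = jnp.array([[True], [True]])
print(pair_energy_from_candidates(pos, rad, cand_idx, cand_mask))  # 0.125 (overlap 0.5)
```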

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Creates a neighbor list of shape (N, max_neighbors) using the sweep-and-prune search.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Verlet search cutoff radius.

  • max_neighbors (int) – Static size of neighbor buffer per particle.

Returns:

Unmodified state, system, neighbor list, and overflow flag.

Return type:

Tuple[State, System, jax.Array, jax.Array]
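Packing a variable number of hits into a static (N, max_neighbors) buffer with an overflow flag can be sketched as below. The sketch assumes `cand_idx` already holds particle ids and that duplicates from different passes have been removed. It also assumes the sentinel -1 marks empty slots. `compact_neighbors` is a hypothetical helper. Sorting on unique keys (valid slots keep their column index, invalid slots are pushed past the window width) keeps the result independent of sort stability.

```python
import jax.numpy as jnp


def compact_neighbors(cand_idx, cand_mask, max_neighbors: int):
    """Hypothetical helper: pack masked candidates into (N, max_neighbors), pad with -1."""
    _, w = cand_idx.shape
    cols = jnp.arange(w)[None, :]
    # Unique keys: valid slots first (in column order), invalid slots afterwards.
    sort_key = jnp.where(cand_mask, cols, w + cols)
    order = jnp.argsort(sort_key, axis=1)
    values = jnp.where(cand_mask, cand_idx, -1)
    packed = jnp.take_along_axis(values, order, axis=1)
    if max_neighbors >= w:
        packed = jnp.pad(packed, ((0, 0), (0, max_neighbors - w)), constant_values=-1)
    else:
        packed = packed[:, :max_neighbors]
    # Overflow if any row has more valid candidates than the buffer can hold.
    overflow = jnp.any(cand_mask.sum(axis=1) > max_neighbors)
    return packed, overflow


cand_idx = jnp.array([[5, 7, 9], [2, 3, 4]])
cand_mask = jnp.array([[False, True, True], [True, False, False]])
print(compact_neighbors(cand_idx, cand_mask, 2))  # rows [7, 9] and [2, -1], overflow False
```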

static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]

Creates a cross-neighbor list between pos_a (query) and pos_b (database).

Parameters:
  • pos_a (jax.Array) – Query positions, shape (N_A, dim).

  • pos_b (jax.Array) – Database positions, shape (N_B, dim).

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Verlet search cutoff radius.

  • max_neighbors (int) – Static size of neighbor buffer per particle.

Returns:

Cross-neighbor list of shape (N_A, max_neighbors) and overflow flag.

Return type:

Tuple[jax.Array, jax.Array]
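One way a cross search between a query set and a database could reuse the same sorted-window idea is sketched below. The sketch assumes both sets were already mapped to concatenated-line keys with the same \(L_{proj}\) and to slab hashes from the same pass. `cross_window` and `reach` are hypothetical names. Only the database is sorted, and `jnp.searchsorted` locates each query's insertion point. The window then holds \(K\) entries below that point and \(K\) at or above it. Overflow detection and cross-pass deduplication are omitted.

```python
import jax.numpy as jnp


def cross_window(key_a, hash_a, key_b, hash_b, K: int, reach: float):
    """Hypothetical helper: for each query in A, B candidates within +/-K sorted slots.

    Returns B indices (in B's original order) and a validity mask, both (N_A, 2K).
    """
    order_b = jnp.argsort(key_b)
    sk_b, sh_b = key_b[order_b], hash_b[order_b]
    n_b = key_b.shape[0]
    # Insertion point of each query key among the sorted database keys.
    start = jnp.searchsorted(sk_b, key_a)                              # (N_A,)
    offsets = jnp.concatenate([jnp.arange(-K, 0), jnp.arange(0, K)])  # (2K,)
    pos = start[:, None] + offsets[None, :]
    in_range = (pos >= 0) & (pos < n_b)
    p = jnp.clip(pos, 0, n_b - 1)
    same_slab = sh_b[p] == hash_a[:, None]
    close = jnp.abs(sk_b[p] - key_a[:, None]) <= reach
    return order_b[p], in_range & same_slab & close


key_b = jnp.array([5.5, 1.5, 5.1, 2.0])
hash_b = jnp.array([1, 0, 1, 0])
idx, mask = cross_window(jnp.array([1.8, 5.3]), jnp.array([0, 1]), key_b, hash_b, K=1, reach=0.5)
print(idx)  # query 0 -> B particles [1, 3]; query 1 -> B particles [2, 0]
```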