jaxdem.colliders

Collision-detection interfaces and implementations.

Classes

Collider()

The base interface for defining how contact detection and force computations are performed in a simulation.

class jaxdem.colliders.Collider

Bases: Factory, ABC

The base interface for defining how contact detection and force computations are performed in a simulation.

Concrete subclasses of Collider implement the specific algorithms for calculating the interactions.

Notes

Self-interaction (i.e., calling the force/energy computation for i=j) is allowed, and the underlying force_model is responsible for correctly handling or ignoring this case.

Example

To define a custom collider, inherit from Collider, register it and implement its abstract methods:

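>>> from dataclasses import dataclass
>>> import jax
>>> from jaxdem.colliders import Collider
>>> @Collider.register("CustomCollider")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class CustomCollider(Collider):
...     ...  # implement compute_force and compute_potential_energy here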

Then, instantiate it:

>>> collider = Collider.create("CustomCollider", **custom_collider_kw)

static compute_force(state: State, system: System) → Tuple[State, System]

Abstract method to compute the total force acting on each particle in the simulation.

Implementations should calculate inter-particle forces and torques based on the current state and system configuration, then update the force and torque attributes of the state object with the resulting total force and torque for each particle.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object (with computed forces) and the System object.

Return type:

Tuple[State, System]

Note

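  • This method donates state and system: JAX may reuse their buffers for the outputs, so rebind the returned objects and do not access the inputs after the call.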
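In JAX, donation is requested through the donate_argnums (or donate_argnames) argument of jax.jit: XLA may write the outputs into the donated input buffers, after which the inputs must not be used. A minimal sketch of the resulting calling pattern, using a hypothetical add_gravity function in place of a real collider method (on backends that cannot donate, JAX only emits a warning):

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, donate_argnums=(0,))
def add_gravity(force, mass):
    """Hypothetical stand-in for a donating method such as compute_force.

    `force` is donated, so XLA may store the result in its buffer.
    """
    g = jnp.array([0.0, -9.81], dtype=force.dtype)
    return force + mass[:, None] * g


force = jnp.zeros((3, 2))
mass = jnp.array([1.0, 2.0, 3.0])
# Rebind the name to the output; the old `force` buffer may now be invalid.
force = add_gravity(force, mass)
```

The same pattern applies to colliders: write state, system = system.collider.compute_force(state, system) rather than keeping references to the old objects.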

static compute_potential_energy(state: State, system: System) → jax.Array

Abstract method to compute the potential energy of the system, resolved per particle.

Implementations should calculate, for each particle, the sum of all potential-energy contributions involving that particle, based on the current state and system configuration.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A one-dimensional JAX array containing the total potential energy of each particle.

Return type:

jax.Array

Example

>>> potential_energy = system.collider.compute_potential_energy(state, system)
>>> print(potential_energy.shape)  # (N,): one value per particle
>>> print(f"Total potential energy: {float(potential_energy.sum()):.4f}")

class jaxdem.colliders.NaiveSimulator

Bases: Collider

Implementation that computes forces and potential energies using a naive \(O(N^2)\) all-pairs interaction loop.

Notes

Due to its \(O(N^2)\) complexity, NaiveSimulator is suitable only for systems with a relatively small number of particles; larger systems should use a collider based on spatial partitioning. For small systems (roughly fewer than 1k to 5k spheres, depending on the GPU), however, it is typically the fastest option.
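The all-pairs evaluation can be sketched with two nested jax.vmap calls that build an (N, N, dim) array of pair forces and sum it over partners. The linear-spring pair_force below is a hypothetical stand-in for system.force_model, not jaxdem's API. It also shows one way a force model can neutralize the self pair (i = j), whose zero separation would otherwise produce a NaN direction; detecting it by zero separation is a simplification that would also zero out two coincident particles.

```python
import jax
import jax.numpy as jnp


def pair_force(pos_i, pos_j, rad_i, rad_j, k=1.0):
    """Hypothetical linear-spring contact force on particle i due to j."""
    rij = pos_i - pos_j
    r2 = jnp.dot(rij, rij)
    is_self = r2 == 0.0  # i == j (or coincident particles)
    r = jnp.sqrt(jnp.where(is_self, 1.0, r2))  # avoid 0/0 for the self pair
    overlap = jnp.maximum(rad_i + rad_j - r, 0.0)
    return jnp.where(is_self, 0.0, k * overlap / r) * rij


def naive_forces(pos, rad):
    """Total force on each particle from an explicit all-pairs O(N^2) sum."""

    def row(p_i, r_i):
        # Force on one particle i from every particle j: shape (N, dim).
        return jax.vmap(pair_force, in_axes=(None, 0, None, 0))(p_i, pos, r_i, rad)

    per_pair = jax.vmap(row)(pos, rad)  # (N, N, dim)
    return per_pair.sum(axis=1)  # (N, dim)
```

Materializing all N² pairs is what confines this approach to small systems: both memory and time grow quadratically with N.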

static compute_force(state: State, system: System) → Tuple[State, System]

Computes the total force acting on each particle using a naive \(O(N^2)\) all-pairs loop.

This method sums the force contributions from all particle pairs (i, j) as computed by the system.force_model and updates the particle forces.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object with computed forces and the unmodified System object.

Return type:

Tuple[State, System]

Note

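  • This method donates state and system: JAX may reuse their buffers for the outputs, so rebind the returned objects and do not access the inputs after the call.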

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the potential energy associated with each particle using a naive \(O(N^2)\) all-pairs loop.

This method iterates over all particle pairs (i, j) and sums the potential energy contributions computed by the system.force_model.
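A matching all-pairs sketch for per-particle energies, under the assumption (not stated in the source) that each pair's energy is split equally between its two particles, so the per-particle values sum to the total. The harmonic pair_energy is hypothetical and corresponds to the energy 0.5 · k · overlap² of a linear spring.

```python
import jax
import jax.numpy as jnp


def pair_energy(pos_i, pos_j, rad_i, rad_j, k=1.0):
    """Hypothetical harmonic contact energy 0.5 * k * overlap**2 (zero for i == j)."""
    rij = pos_i - pos_j
    r2 = jnp.dot(rij, rij)
    is_self = r2 == 0.0
    r = jnp.sqrt(jnp.where(is_self, 1.0, r2))  # finite value and gradient at i == j
    overlap = jnp.maximum(rad_i + rad_j - r, 0.0)
    return jnp.where(is_self, 0.0, 0.5 * k * overlap**2)


def naive_energy_per_particle(pos, rad):
    """Per-particle energy from an all-pairs sum, half of each pair to each partner."""

    def row(p_i, r_i):
        # Energy of every pair (i, j) for one i: shape (N,).
        return jax.vmap(pair_energy, in_axes=(None, 0, None, 0))(p_i, pos, r_i, rad)

    e_ij = jax.vmap(row)(pos, rad)  # (N, N); each pair appears as (i, j) and (j, i)
    return 0.5 * e_ij.sum(axis=1)  # (N,)
```

Summing the result gives the total energy, and the negative jax.grad of that total with respect to positions reproduces the spring forces of the same pair model, a convenient consistency check between energy and force routines.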

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

One-dimensional array containing the total potential energy contribution for each particle.

Return type:

jax.Array

Note

  • This method donates state and system, so do not access them after the call.

class jaxdem.colliders.CellList(neighbor_mask: Array, cell_size: Array, max_occupancy: int)

Bases: Collider

Implicit cell-list (spatial hashing) collider.

This collider accelerates short-range pair interactions by partitioning the domain into a regular grid of cubic/square cells of side length cell_size. Each particle is assigned to a cell, particles are sorted by cell hash, and interactions are evaluated only against particles in the same or neighboring cells given by neighbor_mask. The cell list is implicit because we never store per-cell particle lists explicitly; instead, we exploit the sorted hashes and fixed max_occupancy to probe neighbors in-place.
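A minimal sketch of the hashing and sorting step, assuming a bounded box with grid_dims cells per axis and a row-major hash. sort_by_cell_hash, grid_dims and this hash layout are illustrative assumptions; jaxdem's own hash may differ (for example, to handle periodic boundaries).

```python
import jax.numpy as jnp


def sort_by_cell_hash(pos, cell_size, grid_dims):
    """Assign particles to grid cells and sort them by a row-major cell hash.

    pos: (N, dim) positions, assumed to lie in [0, grid_dims * cell_size).
    grid_dims: (dim,) integer array with the number of cells along each axis.
    """
    cell = jnp.floor(pos / cell_size).astype(jnp.int32)  # integer cell coordinates
    # Row-major strides: (ny, 1) in 2D, (ny * nz, nz, 1) in 3D.
    strides = jnp.concatenate([jnp.cumprod(grid_dims[::-1])[::-1][1:], jnp.array([1])])
    cell_hash = (cell * strides).sum(axis=-1)  # (N,)
    order = jnp.argsort(cell_hash)
    return order, cell_hash[order]


pos = jnp.array([[0.5, 0.5], [2.5, 0.5], [0.5, 1.5], [1.5, 1.5]])
# order = [0, 2, 3, 1], sorted hashes = [0, 1, 3, 4]
order, sorted_hash = sort_by_cell_hash(pos, 1.0, jnp.array([3, 2]))
```

After sorting, all particles of a cell occupy a contiguous run of the sorted array, which is what allows a cell to be located by binary search instead of through stored per-cell lists.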

Complexity

  • Time: \(O(N \log N)\) from sorting, plus \(O(N M K)\) for neighbor probing (M = number of neighbor cells, K = max_occupancy).

  • Memory: \(O(N)\).

Notes

  • max_occupancy is an upper bound on the number of particles per cell. If a cell contains more particles than this, some interactions are missed, so choose cell_size and max_occupancy such that this cannot happen.

neighbor_mask: jax.Array

Integer offsets defining the neighbor stencil.

Shape is (M, dim), where each row is a displacement in cell coordinates. For search_range=1 in 2D this is the 3×3 Moore neighborhood (M=9); in 3D this is the 3×3×3 neighborhood (M=27).
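A minimal sketch of how such a stencil can be generated for a given search_range, with M = (2 · search_range + 1)^dim rows; make_neighbor_mask is a hypothetical helper, not part of jaxdem's documented API.

```python
import itertools

import jax.numpy as jnp


def make_neighbor_mask(dim, search_range=1):
    """All integer cell offsets in [-search_range, search_range]^dim, shape (M, dim)."""
    offsets = range(-search_range, search_range + 1)
    # itertools.product enumerates every combination of per-axis offsets.
    return jnp.array(list(itertools.product(offsets, repeat=dim)), dtype=jnp.int32)


print(make_neighbor_mask(2).shape)  # (9, 2): the 3x3 Moore neighborhood
```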

cell_size: jax.Array

Linear size of a grid cell (scalar).

max_occupancy: int

Maximum number of particles assumed to occupy a single cell.

The algorithm probes exactly max_occupancy entries starting from the first particle in a neighbor cell. This should be set high enough that real cells rarely exceed it; otherwise contacts/energy will be undercounted.

classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, max_occupancy: ArrayLike | None = None) → Self

Creates a CellList collider with robust defaults.

Defaults are chosen to avoid missing any contacts while keeping the neighbor stencil and the assumed cell occupancy as small as possible, given the information available from state. These estimates assume that spheres do not overlap.

The cost of computing forces for one particle is determined by the number of neighboring cells to check and the occupancy of each cell. This cost can be estimated as:

\[\text{cost} = (2R + 1)^{dim} \cdot \text{max\_occupancy} = (2R + 1)^{dim} \cdot \left(\left\lceil \frac{L^{dim}}{V_{min}} \right\rceil + 2 \right)\]

where \(R\) is the search radius, \(L\) is the cell size, and \(V_{min}\) is the volume of the smallest element. We assume \(V_{min}\) to be the volume of the smallest sphere, without accounting for the packing fraction, to provide a conservative upper bound. The search radius \(R\) is computed as:

\[R = \left\lceil \frac{2 r_{max}}{L} \right\rceil\]

By default, we choose the option that yields the lowest computational cost: \(L = 2 r_{max}\) if \(\alpha < 2.5\), else \(L = r_{max}/2\), where \(\alpha = r_{max}/r_{min}\) is the polydispersity ratio.

The complexity of searching neighbors is \(O(N)\), where the choice of cell size and \(R\) attempts to minimize the constant factor. The constant factor grows with polydispersity (\(\alpha\)) as \(O(\alpha^{dim})\) with \(\alpha = r_{max}/r_{min}\). The cost for sorting and binary search remains \(O(N \log N)\).
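The cost estimate can be evaluated directly. The sketch below transcribes the cost and search-radius formulas for spheres (disks in 2D), taking \(V_{min}\) as the volume of the smallest sphere; cellist_cost is a hypothetical helper and does not attempt to reproduce jaxdem's logic for choosing between the two default cell sizes.

```python
import math


def cellist_cost(cell_size, r_min, r_max, dim):
    """Per-particle cost (2R + 1)^dim * (ceil(L^dim / V_min) + 2)."""
    R = math.ceil(2.0 * r_max / cell_size)  # search radius in cells
    if dim == 2:
        v_min = math.pi * r_min**2  # area of the smallest disk
    else:
        v_min = 4.0 / 3.0 * math.pi * r_min**3  # volume of the smallest sphere
    max_occupancy = math.ceil(cell_size**dim / v_min) + 2
    return (2 * R + 1) ** dim * max_occupancy


print(cellist_cost(2.0, 1.0, 1.0, dim=3))  # 108
print(cellist_cost(0.5, 1.0, 1.0, dim=3))  # 2187
```

For a monodisperse 3D system with r = 1, \(L = 2 r_{max}\) gives \(R = 1\) and a cost of 27 · (⌈8 / (4π/3)⌉ + 2) = 27 · 4 = 108, whereas \(L = r_{max}/2\) gives \(R = 4\) and a cost of 729 · 3 = 2187.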

Parameters:
  • state (State) – Reference state used to determine spatial dimension and default parameters.

  • cell_size (float, optional) – Cell edge length. If None, defaults to a value optimized for the radius distribution.

  • search_range (int, optional) – Neighbor range in cell units. If None, the smallest safe value is computed such that \(\text{search\_range} \cdot L \geq \text{cutoff}\).

  • max_occupancy (int, optional) – Assumed maximum particles per cell. If None, estimated from a conservative packing upper bound using the smallest radius.

Returns:

Configured collider instance.

Return type:

CellList

static compute_force(state: State, system: System) → Tuple[State, System]

Computes the total force acting on each particle using an implicit cell list (\(O(N \log N)\)). This method sums the force contributions from all particle pairs (i, j) as computed by the system.force_model and updates the particle forces.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object with computed forces and the unmodified System object.

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the potential energy associated with each particle using an implicit cell list (\(O(N \log N)\)). This method sums the potential energy contributions from all particle pairs (i, j) as computed by the system.force_model.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

One-dimensional array containing the total potential energy contribution for each particle.

Return type:

jax.Array

class jaxdem.colliders.DynamicCellList(neighbor_mask: Array, cell_size: Array, max_occupancy: int)

Bases: Collider

Implicit cell-list (spatial hashing) collider using dynamic while-loops.

This collider accelerates short-range pair interactions by partitioning the domain into a regular grid. Unlike the standard CellList, this implementation uses a dynamic jax.lax.while_loop to probe neighbor cells, which can be more efficient in systems with highly non-uniform particle distributions.

Complexity

  • Time: \(O(N \log N)\) from sorting, plus \(O(N M \langle K \rangle)\) for neighbor probing, where \(\langle K \rangle\) is the average cell occupancy.

  • Memory: \(O(N)\).

neighbor_mask: jax.Array

Integer offsets defining the neighbor stencil, with shape (M, dim).

cell_size: jax.Array

Linear size of a grid cell (scalar).

max_occupancy: int

Maximum number of particles assumed to occupy a single cell (loop safety limit).
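A sketch of the dynamic probe, assuming the same hash-sorted layout as CellList: jax.lax.while_loop walks forward from the first particle of the cell and stops as soon as the hash changes, so the trip count follows the cell's actual occupancy and max_occupancy acts only as a safety cap. count_in_cell is a hypothetical helper that merely counts particles; a real collider would accumulate pair forces in the loop body.

```python
import jax
import jax.numpy as jnp


def count_in_cell(sorted_hash, target_hash, max_occupancy):
    """Walk one cell of a hash-sorted array with lax.while_loop.

    Stops when the hash changes, the array ends, or max_occupancy particles
    have been visited (the safety cap). Returns the number visited.
    """
    n = sorted_hash.shape[0]
    start = jnp.searchsorted(sorted_hash, target_hash, side="left")

    def cond(carry):
        i, count = carry
        in_cell = (i < n) & (sorted_hash[jnp.minimum(i, n - 1)] == target_hash)
        return in_cell & (count < max_occupancy)

    def body(carry):
        i, count = carry
        # A real collider would accumulate the pair interaction with particle i here.
        return i + 1, count + 1

    _, count = jax.lax.while_loop(cond, body, (start, jnp.int32(0)))
    return count
```

One design consequence: under jax.vmap, a batched while_loop keeps iterating until the condition is false for every batch element, so a vectorized sweep over many cells pays for the fullest cell in the batch rather than for each cell individually.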

classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, max_occupancy: ArrayLike | None = None) → Self

Creates a DynamicCellList collider with robust defaults.

static compute_force(state: State, system: System) → Tuple[State, System]
static compute_potential_energy(state: State, system: System) → jax.Array

class jaxdem.colliders.MaterializedCellList(neighbor_mask: Array, cell_size: Array, max_occupancy: int, grid_dims: Tuple[int, ...], strides: Array, num_cells: int)

Bases: Collider

Ultra-optimized, explicitly materialized cell-list collider.

Uses “dummy particle” padding and wide vectorization to process entire neighbor blocks in single operations, maximizing GPU parallel throughput.
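A minimal sketch of an explicitly materialized cell table with dummy padding: every cell gets exactly max_occupancy slots, empty slots hold the dummy index N, and each particle is scattered into its slot by its rank within the cell. materialize_cells is a hypothetical helper; the table layout and the choice of N as the dummy index are assumptions, not jaxdem's documented internals. Appending one extra dummy row to the per-particle arrays keeps gathers through the table in bounds, and that row can be chosen (for example, far away or with zero radius) so it contributes no force, which lets whole neighbor blocks be processed without branching.

```python
import jax.numpy as jnp


def materialize_cells(cell_hash, num_cells, max_occupancy):
    """Build a (num_cells, max_occupancy) table of particle indices.

    Empty slots hold the dummy index N (one past the last particle), and
    particles beyond max_occupancy in a cell are dropped.
    """
    n = cell_hash.shape[0]
    order = jnp.argsort(cell_hash)  # stable: ties keep original order
    sorted_hash = cell_hash[order]
    # Rank within the cell = position minus index of the cell's first particle.
    first = jnp.searchsorted(sorted_hash, sorted_hash, side="left")
    rank = jnp.arange(n) - first
    table = jnp.full((num_cells, max_occupancy), n, dtype=jnp.int32)
    # Scatter each particle into (its cell, its rank); ranks >= max_occupancy
    # are out of bounds and discarded by mode="drop".
    return table.at[sorted_hash, rank].set(order.astype(jnp.int32), mode="drop")
```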

neighbor_mask: jax.Array
cell_size: jax.Array
max_occupancy: int
grid_dims: Tuple[int, ...]
strides: jax.Array
num_cells: int
classmethod Create(state: State, box_size: ArrayLike, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, max_occupancy: ArrayLike | None = None, periodic: bool = True) → Self
static compute_force(state: State, system: System) → Tuple[State, System]
static compute_potential_energy(state: State, system: System) → jax.Array

class jaxdem.colliders.NeighborList(idx: Array, prev_pos: Array, did_buffer_overflow: Array, update_threshold: float, max_neighbors: int, cell_size: Array, neighbor_mask: Array, max_occupancy: int, grid_dims: Tuple[int, ...], strides: Array)

Bases: Collider

Neighbor List (Verlet List) collider following jax-md architectural patterns.

This implementation uses a persistent neighbor list state, handles buffer overflows, and uses periodic-aware displacement for rebuild triggering.

idx: jax.Array

Neighbor indices of shape (N, max_neighbors); padding entries are marked with -1 or N.

prev_pos: jax.Array

Positions at the time of the last neighbor list rebuild.

did_buffer_overflow: jax.Array

Boolean scalar indicating if the neighbor list was too small to hold all pairs.

update_threshold: float

Verlet skin distance. The list is rebuilt when the maximum particle displacement since the last rebuild exceeds update_threshold / 2.
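The rebuild test can be sketched as below. The factor 1/2 is the usual Verlet-skin argument: if the list was built with radius cutoff + skin and no particle has moved more than half the skin, no pair separation can have shrunk by more than the full skin, so every pair now inside the cutoff is still in the list. needs_rebuild is a hypothetical helper, and the minimum-image wrap assumes an orthorhombic periodic box with side lengths box_size.

```python
import jax.numpy as jnp


def needs_rebuild(pos, prev_pos, box_size, update_threshold, periodic=True):
    """True if any particle has moved more than update_threshold / 2 since the last rebuild."""
    disp = pos - prev_pos
    if periodic:
        # Minimum-image convention: map each component into [-L/2, L/2].
        disp = disp - box_size * jnp.round(disp / box_size)
    max_disp = jnp.sqrt(jnp.max(jnp.sum(disp**2, axis=-1)))
    return max_disp > 0.5 * update_threshold
```

Without the periodic wrap, a particle that crosses the boundary (for example from x = 0.1 to x = 9.95 in a box of side 10) would appear to have moved 9.85 instead of 0.15 and trigger a spurious rebuild.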

max_neighbors: int

Maximum number of neighbors stored per particle.
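A brute-force sketch of how a padded (N, max_neighbors) list and its overflow flag can be produced. The O(N²) distance matrix is a deliberate simplification: the cell_size, grid_dims and strides attributes suggest that jaxdem finds candidates through a cell grid instead. build_neighbor_list is hypothetical, uses N (one of the two padding conventions mentioned for idx) as the padding value, and the radius passed as cutoff would typically already include the skin.

```python
import jax.numpy as jnp


def build_neighbor_list(pos, cutoff, max_neighbors):
    """Brute-force sketch of a padded (N, max_neighbors) neighbor list.

    Padding entries hold N. Returns (idx, did_buffer_overflow).
    """
    n = pos.shape[0]
    d2 = jnp.sum((pos[:, None, :] - pos[None, :, :]) ** 2, axis=-1)  # (N, N)
    is_nbr = (d2 < cutoff**2) & ~jnp.eye(n, dtype=bool)  # exclude i == j
    # Non-neighbors get key N, so sorting moves real neighbors to the front.
    key = jnp.where(is_nbr, jnp.arange(n)[None, :], n)
    if max_neighbors > n:  # ensure at least max_neighbors columns
        key = jnp.pad(key, ((0, 0), (0, max_neighbors - n)), constant_values=n)
    idx = jnp.sort(key, axis=1)[:, :max_neighbors]
    did_buffer_overflow = jnp.any(is_nbr.sum(axis=1) > max_neighbors)
    return idx, did_buffer_overflow
```

When did_buffer_overflow is set, the stored list omits some pairs. Since max_neighbors fixes the array shape, the usual remedy in jax-md-style code is to rebuild with a larger buffer outside of jit.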

cell_size: jax.Array
neighbor_mask: jax.Array
max_occupancy: int
grid_dims: Tuple[int, ...]
strides: jax.Array
classmethod Create(state: State, box_size: ArrayLike, max_neighbors: int = 64, update_threshold: float = 0.05, cell_size: ArrayLike | None = None, max_occupancy: ArrayLike | None = None, periodic: bool = True) → Self
static compute_force(state: State, system: System) → Tuple[State, System]
static compute_potential_energy(state: State, system: System) → jax.Array

Modules

cell_list

Cell List \(O(N \log N)\) collider implementation.

naive

Naive \(O(N^2)\) collider implementation.

sweep_and_prune

Sweep and prune \(O(N \log N)\) collider implementation.