jaxdem.colliders

Collision-detection interfaces and implementations.

Classes

Collider()

The base interface for defining how contact detection and force computations are performed in a simulation.

class jaxdem.colliders.Collider

Bases: Factory, ABC

The base interface for defining how contact detection and force computations are performed in a simulation.

Concrete subclasses of Collider implement the specific algorithms for calculating the interactions.

Notes

Self-interaction (i.e., calling the force/energy computation for i=j) is allowed, and the underlying force_model is responsible for correctly handling or ignoring this case.

Example

To define a custom collider, inherit from Collider, register it and implement its abstract methods:

>>> import jax
>>> from dataclasses import dataclass
>>> from jaxdem.colliders import Collider
>>> @Collider.register("CustomCollider")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class CustomCollider(Collider):
...     ...

Then, instantiate it:

>>> jaxdem.Collider.create("CustomCollider", **custom_collider_kw)
static compute_force(state: State, system: System) → Tuple[State, System]

Abstract method to compute the total force acting on each particle in the simulation.

Implementations should calculate inter-particle forces and torques based on the current state and system configuration, then update the force and torque attributes of the state object with the resulting total force and torque for each particle.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object (with computed forces) and the System object.

Return type:

Tuple[State, System]

Note

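  • This method donates state and system: their buffers may be reused for the outputs, so use the returned State and System rather than the arguments after the call.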

static compute_potential_energy(state: State, system: System) → jax.Array

Abstract method to compute the potential energy associated with each particle.

Implementations should sum, for each particle, all potential energy contributions present in the system based on the current state and system configuration.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A one-dimensional JAX array of shape (N,) containing the potential energy of each particle.

Return type:

jax.Array

Example

>>> potential_energy = system.collider.compute_potential_energy(state, system)
>>> print(potential_energy.shape)  # (N,)
>>> print(f"Total potential energy: {float(potential_energy.sum()):.4f}")
static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → Tuple[State, System, jax.Array, jax.Array]

Build a neighbor list for the current collider.

This is primarily used by neighbor-list-based algorithms and diagnostics. Implementations should match the cell-list semantics:

  • Returns a neighbor list of shape (N, max_neighbors) padded with -1.

  • Neighbor indices must refer to the returned (possibly sorted) state.

  • Also returns an overflow boolean flag (True if any particle exceeded max_neighbors neighbors within the cutoff).
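The return convention listed above can be illustrated with a minimal NumPy sketch. This is not jaxdem's implementation: the helper name `naive_neighbor_list` is hypothetical, periodic boundaries are ignored, and excluding a particle from its own neighbor row is an assumption made for the illustration.

```python
import numpy as np


def naive_neighbor_list(pos, cutoff, max_neighbors):
    """All-pairs neighbor list of shape (N, max_neighbors), padded with -1."""
    n = pos.shape[0]
    # Pairwise distances between all particles (no periodic wrapping).
    diff = pos[:, None, :] - pos[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    # Assumption: a neighbor is any *other* particle within the cutoff.
    is_neighbor = (dist <= cutoff) & ~np.eye(n, dtype=bool)
    counts = is_neighbor.sum(axis=1)
    # True if any particle has more neighbors than the buffer can hold.
    overflow = bool((counts > max_neighbors).any())
    neighbor_list = np.full((n, max_neighbors), -1, dtype=np.int64)
    for i in range(n):
        idx = np.flatnonzero(is_neighbor[i])[:max_neighbors]
        neighbor_list[i, : idx.size] = idx
    return neighbor_list, overflow


pos = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
nl, overflow = naive_neighbor_list(pos, cutoff=1.1, max_neighbors=2)
print(nl)        # rows padded with -1 where fewer than 2 neighbors exist
print(overflow)  # False: no particle has more than 2 neighbors
```

Calling the same helper with max_neighbors=1 sets the overflow flag, because particle 0 has two neighbors within the cutoff.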

class jaxdem.colliders.NaiveSimulator

Bases: Collider

Implementation that computes forces and potential energies using a naive \(O(N^2)\) all-pairs interaction loop.

Notes

Due to its \(O(N^2)\) complexity, NaiveSimulator is only suitable for simulations with a relatively small number of particles; larger systems should use a more efficient spatial-partitioning collider. For small systems (fewer than roughly 1,000 to 5,000 spheres, depending on the GPU), however, this collider should be the fastest option.
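As a rough illustration of the all-pairs approach, the NumPy sketch below evaluates a pair energy for every (i, j) and reduces it to one value per particle. The harmonic repulsion `pair_energy` is a hypothetical stand-in for system.force_model, and assigning half of each pair energy to each particle is an assumption, not necessarily jaxdem's convention.

```python
import numpy as np


def pair_energy(r, rad_i, rad_j, k=1.0):
    """Hypothetical harmonic repulsion: 0.5 * k * overlap**2, zero if not touching."""
    overlap = np.maximum(rad_i + rad_j - r, 0.0)
    return 0.5 * k * overlap**2


def all_pairs_energy(pos, rad):
    """Per-particle potential energy from an O(N^2) evaluation of all pairs."""
    n = pos.shape[0]
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    e = pair_energy(dist, rad[:, None], rad[None, :])
    # Ignore the i == j entries in this sketch.
    e[np.arange(n), np.arange(n)] = 0.0
    # Assumption: each particle receives half of every pair energy it shares.
    return 0.5 * e.sum(axis=1)


pos = np.array([[0.0, 0.0], [0.8, 0.0], [3.0, 0.0]])
rad = np.array([0.5, 0.5, 0.5])
energy = all_pairs_energy(pos, rad)
print(energy.shape)  # (3,)
print(energy.sum())  # total energy of the single overlapping pair
```

The (N, N) distance matrix is what makes both time and memory grow quadratically with the number of particles.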

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the potential energy associated with each particle using a naive \(O(N^2)\) all-pairs loop.

This method iterates over all particle pairs (i, j) and sums the potential energy contributions computed by the system.force_model.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

One-dimensional array containing the total potential energy contribution for each particle.

Return type:

jax.Array

Note

  • This method donates state and system

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → Tuple[State, System, jax.Array, jax.Array]

Naive \(O(N^2)\) neighbor list build.

Matches the cell-list neighbor-list API: returns (state, system, neighbor_list, overflow) where neighbor indices refer to the returned state (unsorted for naive).

static compute_force(state: State, system: System) → Tuple[State, System]

Computes the total force acting on each particle using a naive \(O(N^2)\) all-pairs loop.

This method sums the force contributions from all particle pairs (i, j) as computed by the system.force_model and updates the particle forces.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object with computed forces and the unmodified System object.

Return type:

Tuple[State, System]

Note

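  • This method donates state and system: their buffers may be reused for the outputs, so use the returned State and System rather than the arguments after the call.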

class jaxdem.colliders.StaticCellList(neighbor_mask: Array, cell_size: Array, max_occupancy: int)

Bases: Collider

Implicit cell-list (spatial hashing) collider.

This collider accelerates short-range pair interactions by partitioning the domain into a regular grid of cubic/square cells of side length cell_size. Each particle is assigned to a cell, particles are sorted by cell hash, and interactions are evaluated only against particles in the same or neighboring cells given by neighbor_mask. The cell list is implicit because we never store per-cell particle lists explicitly; instead, we exploit the sorted hashes and fixed max_occupancy to probe neighbors in-place.

This collider is best suited to systems of spheres with low polydispersity and no large overlaps, where it might be even faster than the default cell list. It is not recommended for systems with clumps or large overlaps, where it might skip some contacts, or for strongly polydisperse systems, where its performance degrades.

Complexity

  • Time: between \(O(N)\) and \(O(N \log N)\) for sorting (the state is close to sorted every frame), plus \(O(N M K)\) for neighbor probing, where M is the number of neighbor cells and K is max_occupancy.

  • Memory: \(O(N)\).

Notes

  • max_occupancy is an upper bound on particles per cell. If a cell contains more than this many particles, some interactions might be missed (you should choose cell_size and max_occupancy so this does not happen).
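The sort-and-probe idea described above can be sketched in NumPy as follows. This is an illustration only: the helper names, the row-major hash, and the non-periodic grid are assumptions and do not reproduce jaxdem's internals.

```python
import numpy as np


def sort_by_cell(pos, cell_size, grid_shape):
    """Return the particle order sorted by cell hash, plus the sorted hashes."""
    # Integer cell coordinates of each particle.
    cell = np.floor(pos / cell_size).astype(np.int64)
    # Assumption: a row-major index of the cell coordinates serves as the hash.
    cell_hash = np.ravel_multi_index(tuple(cell.T), grid_shape)
    order = np.argsort(cell_hash, kind="stable")
    return order, cell_hash[order]


def probe_cell(sorted_hash, target_hash, max_occupancy):
    """Sorted slots of the particles found in one cell, probing a fixed window."""
    # Binary search for the first slot whose hash is >= target_hash.
    start = np.searchsorted(sorted_hash, target_hash, side="left")
    # Always inspect exactly max_occupancy consecutive slots.
    slots = start + np.arange(max_occupancy)
    valid = slots < sorted_hash.size
    slots = np.where(valid, slots, 0)  # clamp out-of-range slots
    hit = valid & (sorted_hash[slots] == target_hash)
    return slots[hit]


pos = np.array([[0.5, 0.5], [0.2, 0.9], [1.5, 0.5], [3.5, 3.5], [0.7, 0.1]])
order, sorted_hash = sort_by_cell(pos, cell_size=1.0, grid_shape=(4, 4))
print(sorted_hash)                    # three particles share cell hash 0
print(probe_cell(sorted_hash, 0, 2))  # only 2 of the 3 occupants are found
print(probe_cell(sorted_hash, 0, 3))  # all 3 occupants are found
```

The second probe shows the failure mode named in the note: with max_occupancy=2, the third particle in the cell is never visited, so its interactions would be missed.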

neighbor_mask: Array

Integer offsets defining the neighbor stencil.

Shape is (M, dim), where each row is a displacement in cell coordinates. For search_range=1 in 2D this is the 3×3 Moore neighborhood (M=9); in 3D this is the 3×3×3 neighborhood (M=27).
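A stencil with the shape described above can be generated as in the sketch below. The helper `neighbor_stencil` is hypothetical and only illustrates the layout; the row ordering used by jaxdem may differ.

```python
import itertools

import numpy as np


def neighbor_stencil(search_range, dim):
    """Integer cell offsets covering [-search_range, search_range] on every axis."""
    offsets = range(-search_range, search_range + 1)
    return np.array(list(itertools.product(offsets, repeat=dim)), dtype=np.int64)


mask_2d = neighbor_stencil(1, 2)
mask_3d = neighbor_stencil(1, 3)
print(mask_2d.shape)  # (9, 2)
print(mask_3d.shape)  # (27, 3)
```

In general the number of rows is M = (2 * search_range + 1) ** dim, and the zero offset (the particle's own cell) is included.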

cell_size: Array

Linear size of a grid cell (scalar).

max_occupancy: int

Maximum number of particles assumed to occupy a single cell.

The algorithm probes exactly max_occupancy entries starting from the first particle in a neighbor cell. This should be set high enough that real cells rarely exceed it; otherwise contacts/energy will be undercounted.

classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, max_occupancy: int | None = None) → Self

Creates a StaticCellList collider with robust defaults.

Defaults are chosen to avoid missing any contacts while keeping the neighbor stencil and assumed cell occupancy as small as possible given available information from state. For this we assume no overlap between spheres.

The cost of computing forces for one particle is determined by the number of neighboring cells to check and the occupancy of each cell. This cost can be estimated as:

\[\begin{split}\text{cost} &= (2R + 1)^{dim} \cdot \text{max\_occupancy} \\ &= (2R + 1)^{dim} \cdot \left(\left\lceil \frac{L^{dim}}{V_{min}} \right\rceil + 2 \right)\end{split}\]

where \(R\) is the search radius, \(L\) is the cell size, and \(V_{min}\) is the volume of the smallest element. We assume \(V_{min}\) to be the volume of the smallest sphere, without accounting for the packing fraction, to provide a conservative upper bound. The search radius \(R\) is computed as:

\[R = \left\lceil \frac{2 r_{max}}{L} \right\rceil\]

By default, we choose the option that yields the lowest estimated cost: \(L = 2 \cdot r_{max}\) if \(\alpha < 2.5\), else \(L = r_{max}/2\), where \(\alpha = r_{max}/r_{min}\) is the polydispersity ratio.

The complexity of searching neighbors is \(O(N)\), where the choice of cell size and \(R\) attempts to minimize the constant factor. The constant factor grows with polydispersity (\(\alpha\)) as \(O(\alpha^{dim})\) with \(\alpha = r_{max}/r_{min}\). The cost for sorting and binary search remains \(O(N \log N)\).
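The cost estimate above can be evaluated directly. The sketch below implements the two formulas as written, for 2D and 3D only; the function names are hypothetical, and the numbers illustrate the estimate rather than measured performance.

```python
import math


def sphere_volume(r, dim):
    """Area of a disk for dim == 2, volume of a ball otherwise (dim == 3)."""
    return math.pi * r**2 if dim == 2 else 4.0 / 3.0 * math.pi * r**3


def estimated_cost(cell_size, r_min, r_max, dim):
    """(2R + 1)^dim * (ceil(L^dim / V_min) + 2), with R = ceil(2 r_max / L)."""
    search_range = math.ceil(2.0 * r_max / cell_size)
    max_occupancy = math.ceil(cell_size**dim / sphere_volume(r_min, dim)) + 2
    return (2 * search_range + 1) ** dim * max_occupancy


# Monodisperse disks (alpha = 1) in 2D with r = 0.5.
print(estimated_cost(1.0, 0.5, 0.5, dim=2))   # L = 2 * r_max  -> 36
print(estimated_cost(0.25, 0.5, 0.5, dim=2))  # L = r_max / 2  -> 243

# Polydisperse spheres (alpha = 4) in 3D with r_min = 0.25, r_max = 1.0.
print(estimated_cost(2.0, 0.25, 1.0, dim=3))  # L = 2 * r_max  -> 3375
print(estimated_cost(0.5, 0.25, 1.0, dim=3))  # L = r_max / 2  -> 2916
```

In these two cases the larger cell gives the lower estimate for the monodisperse system, while the smaller cell gives the lower estimate for the polydisperse 3D system.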

Parameters:
  • state (State) – Reference state used to determine spatial dimension and default parameters.

  • cell_size (float, optional) – Cell edge length. If None, defaults to a value optimized for the radius distribution.

  • search_range (int, optional) – Neighbor range in cell units. If None, the smallest safe value is computed such that \(\text{search\_range} \cdot L \geq \text{cutoff}\).

  • max_occupancy (int, optional) – Assumed maximum particles per cell. If None, estimated from a conservative packing upper bound using the smallest radius.

Returns:

Configured collider instance.

Return type:

StaticCellList

static compute_force(state: State, system: System) → Tuple[State, System]

Computes the total force acting on each particle using an implicit cell list, at \(O(N \log N)\) cost. This method sums the force contributions from all particle pairs (i, j) as computed by the system.force_model and updates the particle forces.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object with computed forces and the unmodified System object.

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the potential energy associated with each particle using an implicit cell list, at \(O(N \log N)\) cost. This method sums the energy contributions from all particle pairs (i, j) as computed by the system.force_model.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

An array containing the potential energy for each particle.

Return type:

jax.Array

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Computes the list of neighbors for each particle. The list has shape (N, max_neighbors); if a particle has fewer than max_neighbors neighbors, its row is padded with -1. The indices in the list refer to the returned sorted state.

Because of the nature of the cell list, no neighbor further away than cell_size * (1 + search_range) can be found, where search_range is the number of neighboring cells checked in each direction. If cutoff is greater than this value, the returned list may be incomplete. Likewise, if a cell contains more spheres than max_occupancy, some neighbors might be missing.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Search radius.

  • max_neighbors (int) – Maximum number of neighbors to store per particle.

Returns:

The sorted state, the system, the neighbor list, and a boolean flag for overflow.

Return type:

tuple[State, System, jax.Array, jax.Array]

class jaxdem.colliders.DynamicCellList(neighbor_mask: Array, cell_size: Array)

Bases: Collider

Implicit cell-list (spatial hashing) collider using dynamic while-loops.

This collider accelerates short-range pair interactions by partitioning the domain into a regular grid. Unlike the static cell list, this implementation uses a dynamic jax.lax.while_loop to probe neighbor cells, which can be more efficient with polydisperse systems or low packing fractions. It’s also useful for systems that have a high occupancy per cell, for example, systems with clumps.

Complexity

  • Time: between \(O(N)\) and \(O(N \log N)\) for sorting (the state is close to sorted every frame), plus \(O(N M \langle K \rangle)\) for neighbor probing, where M is the number of neighbor cells and \(\langle K \rangle\) is the average cell occupancy.

  • Memory: \(O(N)\).
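The difference from a fixed-window probe can be sketched as below. A plain Python while loop stands in for jax.lax.while_loop, and the helper name and sample hashes are hypothetical; this is not jaxdem's implementation.

```python
import numpy as np


def probe_cell_dynamic(sorted_hash, target_hash):
    """Sorted slots of all particles in one cell, however many there are."""
    # Binary search for the first slot whose hash is >= target_hash.
    slot = int(np.searchsorted(sorted_hash, target_hash, side="left"))
    members = []
    # Walk forward until the hash changes or the array ends.
    while slot < sorted_hash.size and sorted_hash[slot] == target_hash:
        members.append(slot)
        slot += 1
    return members


sorted_hash = np.array([0, 0, 0, 4, 15])
print(probe_cell_dynamic(sorted_hash, 0))  # [0, 1, 2]
print(probe_cell_dynamic(sorted_hash, 4))  # [3]
print(probe_cell_dynamic(sorted_hash, 5))  # []
```

Because the loop stops as soon as the hash changes, the work per cell follows the actual occupancy instead of a fixed upper bound, and no occupant is skipped in a crowded cell.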

neighbor_mask: Array

Integer offsets defining the neighbor stencil (M, dim).

cell_size: Array

Linear size of a grid cell (scalar).

classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, box_size: ArrayLike | None = None) → Self

Creates a DynamicCellList collider with robust defaults.

Defaults are chosen to avoid missing any contacts while keeping the neighbor stencil and assumed cell occupancy as small as possible given available information from state.

The cost of computing forces for one particle is determined by the number of neighboring cells to check and the occupancy of each cell. This cost can be estimated as:

\[\begin{split}\text{cost} &= (2R + 1)^{dim} \cdot \text{max\_occupancy} \\ &= (2R + 1)^{dim} \cdot \left(\left\lceil \frac{L^{dim}}{V_{min}} \right\rceil + 2 \right)\end{split}\]

where \(R\) is the search radius, \(L\) is the cell size, and \(V_{min}\) is the volume of the smallest element. We assume \(V_{min}\) to be the volume of the smallest sphere, without accounting for the packing fraction, to provide a conservative upper bound. The search radius \(R\) is computed as:

\[R = \left\lceil \frac{2 r_{max}}{L} \right\rceil\]

By default, we choose the option that yields the lowest estimated cost: \(L = 2 \cdot r_{max}\) if \(\alpha < 2.5\), else \(L = r_{max}/2\), where \(\alpha = r_{max}/r_{min}\) is the polydispersity ratio.

The complexity of searching neighbors is \(O(N)\), where the choice of cell size and \(R\) attempts to minimize the constant factor. The constant factor grows with polydispersity; however, the dynamic nature of the collider greatly minimizes polydispersity’s impact.

Parameters:
  • state (State) – Reference state used to determine spatial dimension and default parameters.

  • cell_size (float, optional) – Cell edge length. If None, defaults to a value optimized for the radius distribution.

  • box_size (jax.Array, optional) – Size of the periodic box, used to ensure there are at least 3 cells per axis. If None, this check is skipped, and a box with fewer than 3 cells per axis will lead to errors.

  • search_range (int, optional) – Neighbor range in cell units. If None, the smallest safe value is computed such that \(\text{search\_range} \cdot L \geq \text{cutoff}\).

Returns:

Configured collider instance.

Return type:

DynamicCellList

static compute_force(state: State, system: System) → Tuple[State, System]

Computes the total force acting on each particle using an implicit cell list, at \(O(N \log N)\) cost. This method sums the force contributions from all particle pairs (i, j) as computed by the system.force_model and updates the particle forces.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object with computed forces and the unmodified System object.

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the potential energy associated with each particle using an implicit cell list, at \(O(N \log N)\) cost. This method sums the energy contributions from all particle pairs (i, j) as computed by the system.force_model.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

An array containing the potential energy for each particle.

Return type:

jax.Array

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → Tuple[State, System, jax.Array, jax.Array]

Computes the list of neighbors for each particle. The list has shape (N, max_neighbors); if a particle has fewer than max_neighbors neighbors, its row is padded with -1. The indices in the list refer to the returned sorted state.

Because of the nature of the cell list, no neighbor further away than cell_size * (1 + search_range) can be found, where search_range is the number of neighboring cells checked in each direction. If cutoff is greater than this value, the returned list may be incomplete.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Search radius.

  • max_neighbors (int) – Maximum number of neighbors to store per particle.

Returns:

The sorted state, the system, the neighbor list, and a boolean flag for overflow.

Return type:

tuple[State, System, jax.Array, jax.Array]

class jaxdem.colliders.NeighborList(cell_list: DynamicCellList, neighbor_list: Array, old_pos: Array, n_build_times: int, cutoff: Array, skin: Array, overflow: Array, max_neighbors: int)

Bases: Collider

Verlet Neighbor List collider.

This collider caches a list of neighbors for every particle. It only rebuilds the list when particles have moved more than half the ‘skin’ distance.

Performance Note: You must provide a non-zero skin (e.g., 0.1 * radius) for this collider to be efficient. If skin=0, it rebuilds every step.
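The rebuild rule can be sketched in NumPy as below. The function name is hypothetical and periodic wrapping of displacements is ignored, so this only illustrates the criterion. The half-skin threshold works because two particles that each move less than skin / 2 change their separation by less than skin, so any pair now within cutoff was within cutoff + skin when the list was built.

```python
import numpy as np


def needs_rebuild(pos, old_pos, skin):
    """True if any particle moved more than skin / 2 since the last build."""
    displacement = np.linalg.norm(pos - old_pos, axis=-1)
    return bool(displacement.max() > 0.5 * skin)


old_pos = np.zeros((3, 2))
small_move = old_pos + np.array([[0.02, 0.0], [0.0, 0.0], [0.0, 0.0]])
large_move = old_pos + np.array([[0.03, 0.0], [0.0, 0.0], [0.0, 0.0]])

print(needs_rebuild(small_move, old_pos, skin=0.05))  # False: 0.02 <= 0.025
print(needs_rebuild(large_move, old_pos, skin=0.05))  # True: 0.03 > 0.025
print(needs_rebuild(small_move, old_pos, skin=0.0))   # True: any motion triggers it
```

The last call shows why a zero skin is inefficient: every nonzero displacement exceeds the threshold.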

cell_list

The underlying spatial partitioner used to build the list.

Type:

DynamicCellList

neighbor_list

Shape (N, max_neighbors). Contains the IDs of neighboring particles, padded with -1.

Type:

jax.Array

old_pos

Shape (N, dim). Positions of particles at the last build time.

Type:

jax.Array

n_build_times

Counter for how many times the list has been rebuilt.

Type:

int

cutoff

The interaction radius (force cutoff).

Type:

float

skin

Buffer distance. The list is built with radius = cutoff + skin and rebuilt when max_displacement > skin / 2.

Type:

float

overflow

Boolean flag indicating if the neighbor list overflowed during build.

Type:

jax.Array

max_neighbors

Static buffer size for the neighbor list.

Type:

int

cell_list: DynamicCellList
neighbor_list: Array
old_pos: Array
n_build_times: int
cutoff: Array
skin: Array
overflow: Array
max_neighbors: int

classmethod Create(state: State, cutoff: float, box_size: jax.Array | None = None, skin: float = 0.05, max_neighbors: int | None = None, number_density: float = 1.0, safety_factor: float = 1.2, cell_size: float | None = None) → Self

Creates a NeighborList collider.

Parameters:
  • state (State) – Initial simulation state.

  • cutoff (float) – The physical interaction cutoff radius.

  • box_size (jax.Array, optional) – The size of the periodic box, if used.

  • skin (float, default 0.05) – The buffer distance. Must be > 0.0 for performance.

  • max_neighbors (int, optional) – Maximum neighbors to store per particle. If not provided, it is estimated from the number_density.

  • number_density (float, default 1.0) – Number density of the state, used to estimate max_neighbors when it is not provided.

  • safety_factor (float, default 1.2) – Multiplier applied to the max_neighbors value estimated from number_density. The default was obtained empirically.

  • cell_size (float, optional) – Override for the underlying cell list size.
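The exact formula used to estimate max_neighbors is not given here. One plausible estimate, shown purely as an assumption and not as jaxdem's formula, multiplies the number density by the volume of a ball of radius cutoff + skin and scales the result by the safety factor:

```python
import math


def estimate_max_neighbors(cutoff, skin, number_density, safety_factor, dim):
    """Hypothetical estimate: expected particles within cutoff + skin, scaled up."""
    radius = cutoff + skin
    # Area of a disk for dim == 2, volume of a ball otherwise (dim == 3).
    volume = math.pi * radius**2 if dim == 2 else 4.0 / 3.0 * math.pi * radius**3
    return math.ceil(safety_factor * number_density * volume)


print(estimate_max_neighbors(1.0, 0.05, 1.0, 1.2, dim=3))  # 6
print(estimate_max_neighbors(1.0, 0.05, 1.0, 1.2, dim=2))  # 5
```

Whatever estimate is used, the overflow flag reports whether the chosen buffer was too small for the actual configuration.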

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → Tuple[State, System, jax.Array, jax.Array]

Return the cached neighbor list from this collider.

Notes

  • This method does not rebuild the neighbor list. It simply returns the last cached neighbor_list and overflow stored in system.collider.

  • The returned neighbor indices refer to the collider’s internal particle ordering at the time the cache was last updated (i.e., after the most recent rebuild inside compute_force()).

  • The cutoff and max_neighbors arguments are accepted for API compatibility but are currently ignored; the cache was built using this collider’s configured cutoff + skin and max_neighbors.

static compute_force(state: State, system: System) → Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Modules

cell_list

Cell List \(O(N \log N)\) collider implementation.

naive

Naive \(O(N^2)\) collider implementation.

neighbor_list

Neighbor List Collider implementation.

sweep_and_prune

Sweep and prune \(O(N \log N)\) collider implementation.