jaxdem.colliders.cell_list
Cell List \(O(N \log N)\) collider implementation.
Classes
DynamicCellList: Implicit cell-list (spatial hashing) collider using dynamic while-loops.
- class jaxdem.colliders.cell_list.DynamicCellList(neighbor_mask: Array, cell_size: Array, *, overflow: Array = <factory>)
Bases: Collider
Implicit cell-list (spatial hashing) collider using dynamic while-loops.
This collider accelerates short-range pair interactions by partitioning the domain into a regular grid of cubic (3D) or square (2D) cells of side length cell_size. Each particle is assigned to a cell, particles are sorted by cell hash, and interactions are evaluated only against particles in the same or neighboring cells, as given by neighbor_mask.
This implementation does not pad cell storage to a fixed max_occupancy. Instead, it uses a dynamic jax.lax.while_loop to iterate over exactly the particles present in each neighboring cell.
The operation of this collider can be understood as the following nested loop:
    for particle in particles:                   # parallel
        for hash in stencil(particle):           # parallel
            while next_neighbor in cell(hash):   # sequential
                ...
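The nested loop above can be sketched in JAX as follows. This is a minimal, non-periodic illustration, not the jaxdem implementation: the helper name count_neighbors_cell_list, the row-major cell hash, and the use of jnp.searchsorted to find the first particle of each cell are assumptions made for this sketch, and it only counts neighbors within a cutoff instead of computing forces. The two outer loops are jax.vmap calls and the innermost loop is a jax.lax.while_loop whose trip count equals the cell's occupancy.

```python
import itertools
import math

import jax
import jax.numpy as jnp


def count_neighbors_cell_list(pos, cell_size, grid_shape, cutoff):
    """Count neighbors within `cutoff` of every particle using a sorted cell list.

    Sketch assumptions: open box [0, cell_size * grid_shape), Python scalars
    for `cell_size` and `cutoff`, and a tuple of ints for `grid_shape`.
    """
    n, dim = pos.shape
    search_range = math.ceil(cutoff / cell_size)
    # Stencil of integer cell offsets, shape (M, dim).
    offsets = jnp.array(
        list(itertools.product(range(-search_range, search_range + 1), repeat=dim))
    )
    # Row-major strides map integer cell coordinates to a scalar cell hash.
    strides, acc = [], 1
    for g in reversed(grid_shape):
        strides.append(acc)
        acc *= g
    strides = jnp.array(strides[::-1])
    grid = jnp.array(grid_shape)

    # Assign each particle to a cell, then sort particles by cell hash.
    coords = jnp.clip(jnp.floor(pos / cell_size).astype(jnp.int32), 0, grid - 1)
    hashes = jnp.sum(coords * strides, axis=-1)
    order = jnp.argsort(hashes)
    s_pos, s_coords, s_hash = pos[order], coords[order], hashes[order]

    def per_particle(i):  # outer loop: parallel over particles
        def per_offset(off):  # middle loop: parallel over stencil cells
            nc = s_coords[i] + off
            valid = jnp.all((nc >= 0) & (nc < grid))  # skip cells outside the box
            h = jnp.sum(nc * strides)
            start = jnp.searchsorted(s_hash, h, side="left")

            def cond(carry):  # inner loop: sequential walk through cell h
                j, _ = carry
                return valid & (j < n) & (s_hash[jnp.minimum(j, n - 1)] == h)

            def body(carry):
                j, count = carry
                d = s_pos[j] - s_pos[i]
                hit = (jnp.sum(d * d) < cutoff**2) & (j != i)
                return j + 1, count + hit.astype(jnp.int32)

            _, count = jax.lax.while_loop(cond, body, (start, jnp.int32(0)))
            return count

        return jnp.sum(jax.vmap(per_offset)(offsets))

    sorted_counts = jax.vmap(per_particle)(jnp.arange(n))
    # Scatter results back to the original particle order.
    return jnp.zeros(n, jnp.int32).at[order].set(sorted_counts)


key = jax.random.PRNGKey(0)
pos = jax.random.uniform(key, (200, 2), minval=0.0, maxval=10.0)
counts = count_neighbors_cell_list(pos, cell_size=1.0, grid_shape=(10, 10), cutoff=1.0)
```

Because each inner while_loop stops at the first particle whose hash differs from the target cell, the work per stencil cell tracks that cell's actual occupancy rather than a padded maximum.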
Because the innermost loop is evaluated sequentially, the computational cost is driven by the average cell occupancy rather than the maximum possible occupancy. This makes the total theoretical cost:
\[O(N \cdot \text{neighbor\_mask\_size} \cdot \langle K \rangle)\]
where \(\langle K \rangle\) is the average cell occupancy. To see how this scales, consider each cost component:
- Stencil size:
The stencil size depends on the ratio between the cell size (\(L\)) and the radius of the largest particle (\(r_{max}\)).
\[\text{neighbor\_mask\_size} = \left( 2\left\lceil \frac{2r_{max}}{L} \right\rceil + 1 \right)^{dim}\]
- Average occupancy:
The average number of particles that occupy a cell depends on the cell volume and the macroscopic number density (\(\rho\)):
\[\langle K \rangle = \rho L^{dim}\]
To express this in terms of the local volume fraction \(\phi\) (the ratio of volume actually occupied by particles to the total cell volume) and our normalized cell size \(L^\prime = L/r_{max}\), we use the average particle volume \(\langle V \rangle\):
\[\langle K \rangle = \phi \frac{L^{dim}}{\langle V \rangle} = \phi \frac{(L^\prime r_{max})^{dim}}{\langle V \rangle}\]
Knowing that the volume of the largest particle is \(V_{max} = k_v r_{max}^{dim}\) (where \(k_v\) is the geometric volume factor, such as \(4\pi/3\) in 3D or \(\pi\) in 2D), we obtain the final theoretical cost:
\[\text{cost} \approx N \left( 2\left\lceil \frac{2}{L^\prime} \right\rceil + 1 \right)^{dim} \left( \frac{\phi}{k_v} \frac{V_{max}}{\langle V \rangle} (L^\prime)^{dim} \right)\]
- The Polydispersity Advantage:
In the static cell list, the cost scales with the ratio of the largest to the smallest particle volume (\(V_{max}/V_{min} \propto \alpha^{dim}\), where \(\alpha = r_{max}/r_{min}\)). In this dynamic list, it scales with the ratio of the largest to the average particle volume, \(V_{max}/\langle V \rangle\). Since \(\langle V \rangle \ge V_{min}\), this factor never exceeds \(V_{max}/V_{min}\), so the severe \(O(\alpha^{dim})\) padding penalty of the static list is reduced, and largely offset when a substantial fraction of the particles are large.
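The cost estimate can be evaluated numerically. The sketch below is only a transcription of the formulas above into plain Python; the function names are hypothetical, and it counts pair checks only, so per-cell overhead (locating each cell, loop setup) and compilation cost are not modeled.

```python
import math


def stencil_size(l_prime, dim):
    """Cells per stencil: (2 * ceil(2 / L') + 1) ** dim, with L' = L / r_max."""
    return (2 * math.ceil(2.0 / l_prime) + 1) ** dim


def mean_occupancy(l_prime, dim, phi, vmax_over_vmean, k_v):
    """<K> = (phi / k_v) * (V_max / <V>) * L' ** dim."""
    return (phi / k_v) * vmax_over_vmean * l_prime**dim


def pair_checks_per_particle(l_prime, dim, phi, vmax_over_vmean, k_v):
    """cost / N: stencil size times average occupancy (pair checks only)."""
    return stencil_size(l_prime, dim) * mean_occupancy(
        l_prime, dim, phi, vmax_over_vmean, k_v
    )


def volume_ratios(radii, dim):
    """Return (V_max / V_min, V_max / <V>); k_v cancels, so r ** dim suffices."""
    vols = [r**dim for r in radii]
    v_max = max(vols)
    return v_max / min(vols), v_max / (sum(vols) / len(vols))


# 50/50 bidisperse mixture with size ratio alpha = 2 in 3D.
static_factor, dynamic_factor = volume_ratios([1.0] * 50 + [2.0] * 50, dim=3)
```

For this mixture, the static-list factor \(V_{max}/V_{min}\) is 8, while the dynamic-list factor \(V_{max}/\langle V \rangle\) is \(16/9 \approx 1.78\).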
Constructor Parameters
cell_size: Linear size of the grid cells. A larger cell size reduces the neighbor stencil size but increases cell occupancy (longer sequential loops). A smaller cell size reduces occupancy but enlarges the stencil, whose size grows as \((2\lceil 2r_{max}/L \rceil + 1)^{dim}\), which increases compilation overhead. If None, defaults to \(2 r_{max}\) for systems with low polydispersity (\(\alpha < 2.5\)), or \(0.5 r_{max}\) for highly polydisperse systems.
search_range: Neighborhood range in cell units, i.e. how many cells are searched along each dimension. If None, it is computed automatically to guarantee that all potential contacts within \(2 r_{max}\) are visited. Setting it higher enlarges the search stencil.
box_size: Bounding dimensions of the physical domain. This is only needed when the physical box is small compared with the cell size, to ensure that under periodic boundary conditions each axis has at least 2 * search_range + 1 cells.
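The documented defaults can be illustrated with a short sketch. This is not jaxdem's Create code; the helper names are hypothetical, and the rules are reconstructed only from the parameter descriptions above: the \(\alpha < 2.5\) threshold for cell_size, a search_range that covers contacts within \(2 r_{max}\), a neighbor_mask holding every integer offset in the stencil, and the periodic minimum of 2 * search_range + 1 cells per axis.

```python
import itertools
import math

import jax.numpy as jnp


def default_cell_size(radii):
    """Hypothetical reconstruction of the documented cell_size default."""
    r_max, r_min = max(radii), min(radii)
    alpha = r_max / r_min
    return 2.0 * r_max if alpha < 2.5 else 0.5 * r_max


def default_search_range(cell_size, r_max):
    """Smallest integer range whose stencil reaches every contact within 2 * r_max."""
    return math.ceil(2.0 * r_max / cell_size)


def make_neighbor_mask(search_range, dim):
    """All integer offsets in [-search_range, search_range] ** dim, shape (M, dim)."""
    rng = range(-search_range, search_range + 1)
    return jnp.array(list(itertools.product(rng, repeat=dim)), dtype=jnp.int32)


def min_cells_per_axis(search_range):
    """Periodic minimum, so that wrapped stencil offsets never alias the same cell twice."""
    return 2 * search_range + 1


# Monodisperse (alpha = 1) versus strongly polydisperse (alpha = 4) radii.
mono_cell = default_cell_size([0.5] * 10)
poly_cell = default_cell_size([0.25, 1.0])
```

With these rules, the monodisperse case yields cell_size = 1.0 with search_range = 1 (a 3 x 3 stencil in 2D), whereas the polydisperse case yields cell_size = 0.5 with search_range = 4 (a 9 x 9 stencil in 2D).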
This collider is suitable for large systems with low to moderate polydispersity (\(\alpha < 2.5\)) and medium to high packing fractions. Highly polydisperse systems (\(\alpha \ge 3.0\)) or systems containing rigid clumps with large internal overlaps will reduce performance significantly. This is because overlaps artificially inflate the local cell occupancy \(\langle K \rangle\) far beyond the macroscopic physical volume fraction \(\phi\), leading to longer sequential loops and reduced GPU thread efficiency.
Complexity
Time: \(O(N)\) to \(O(N \log N)\) for sorting (the state stays close to sorted from frame to frame), plus \(O(N \cdot M \cdot \langle K \rangle)\) for neighbor probing, where \(M\) = neighbor_mask_size and \(\langle K \rangle\) = average occupancy.
Memory: \(O(N)\).
Notes
Batching with vmap: If you use jax.vmap to evaluate multiple simulation environments simultaneously, be aware of JAX's SIMD execution model. Because the innermost while loop executes sequentially, it must keep running for every environment in the batch until the environment with the highest local cell occupancy finishes its iterations. Consequently, the cost of a batched execution is bottlenecked by the single worst-case occupancy across the entire batch.
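A toy example, independent of jaxdem, shows this behavior. When jax.vmap batches a jax.lax.while_loop whose predicate differs per element, the batched loop keeps iterating while any element's predicate is still true, and elements that have finished keep their carry unchanged. The results stay correct per element, but every element pays for the longest trip count. Here walk_cell is a hypothetical stand-in for the inner cell loop.

```python
import jax
import jax.numpy as jnp


def walk_cell(occupancy):
    """Sum 1 + 2 + ... + occupancy with a data-dependent number of iterations."""

    def cond(carry):
        k, _ = carry
        return k < occupancy

    def body(carry):
        k, acc = carry
        return k + 1, acc + (k + 1)

    _, acc = jax.lax.while_loop(cond, body, (jnp.int32(0), jnp.int32(0)))
    return acc


# Three "environments" with very different worst-case cell occupancies.
occupancies = jnp.array([1, 2, 50], dtype=jnp.int32)
# Under vmap, the loop runs max(occupancies) = 50 times for all three elements.
totals = jax.vmap(walk_cell)(occupancies)
```

The per-element totals are 1, 3 and 1275, exactly as in unbatched execution; only the amount of work changes.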
- neighbor_mask: Array
Integer offsets defining the neighbor stencil, shape (M, dim).
- cell_size: Array
Linear size of a grid cell (scalar).
- classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, box_size: ArrayLike | None = None) → Self
Creates a DynamicCellList instance based on the reference state.
- Parameters:
state (State) – Reference state containing positions and radii.
cell_size (float, optional) – Grid cell size.
search_range (int, optional) – Number of neighboring cells to search.
box_size (ArrayLike, optional) – Bounding dimensions of physical box. Only needed when the box size is small compared with the cell size.
- Returns:
A configured DynamicCellList instance.
- Return type:
Self
- static compute_force(state: State, system: System) → tuple[State, System]
Computes pairwise contact forces and torques using DynamicCellList.
- static compute_potential_energy(state: State, system: System) → jax.Array
Computes the total non-bonded potential energy of the system.
- static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]
Creates a neighbor list of shape (N, max_neighbors) using DynamicCellList.
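- Parameters:
state (State) – Simulation state containing positions and radii.
system (System) – The configuration of the simulation.
cutoff (float) – Verlet search cutoff radius.
max_neighbors (int) – Static size of neighbor buffer per particle.
- Returns:
Sorted state, system, neighbor list, and overflow flag.
- Return type:
tuple[State, System, jax.Array, jax.Array]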
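When validating create_neighbor_list on small systems, a brute-force reference with the same (N, max_neighbors) layout can help. The helper below is hypothetical and not part of jaxdem: it assumes open boundaries, a cutoff measured between particle centers, and -1 as the padding value for empty slots (the padding convention jaxdem uses is not stated on this page). Because create_neighbor_list returns a sorted state, indices should be compared against the particle ordering of that returned state.

```python
import jax
import jax.numpy as jnp


def brute_force_neighbor_list(pos, cutoff, max_neighbors, fill_value=-1):
    """O(N^2) reference: neighbor indices of shape (N, max_neighbors) and an overflow flag."""
    n = pos.shape[0]
    d = pos[:, None, :] - pos[None, :, :]
    # Pairs within the cutoff, excluding each particle itself.
    within = (jnp.sum(d * d, axis=-1) < cutoff**2) & ~jnp.eye(n, dtype=bool)

    def row(mask):
        # First `max_neighbors` matching indices (ascending), padded with fill_value.
        (idx,) = jnp.nonzero(mask, size=max_neighbors, fill_value=fill_value)
        return idx

    neighbors = jax.vmap(row)(within)
    # Overflow: some particle has more neighbors than the buffer can hold.
    overflow = jnp.any(jnp.sum(within, axis=1) > max_neighbors)
    return neighbors, overflow


pos = jnp.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0], [3.4, 0.0]])
neighbors, overflow = brute_force_neighbor_list(pos, cutoff=1.0, max_neighbors=2)
```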
- static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]
Creates a cross-neighbor list between pos_a (query) and pos_b (database).
- Parameters:
pos_a (jax.Array) – Query positions, shape (N_A, dim).
pos_b (jax.Array) – Database positions, shape (N_B, dim).
system (System) – The configuration of the simulation.
cutoff (float) – Verlet search cutoff radius.
max_neighbors (int) – Static size of neighbor buffer per particle.
- Returns:
Cross-neighbor list of shape (N_A, max_neighbors) and overflow flag.
- Return type:
Tuple[jax.Array, jax.Array]
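A brute-force counterpart for create_cross_neighbor_list can serve as a test oracle in the same way. This hypothetical helper assumes open boundaries and -1 padding. Because pos_a and pos_b are distinct point sets, no self-pair is excluded. Non-matching database indices are mapped to N_B before sorting so that matching indices come first in ascending order.

```python
import jax.numpy as jnp


def brute_force_cross_neighbor_list(pos_a, pos_b, cutoff, max_neighbors, fill_value=-1):
    """For each query point in pos_a, indices of pos_b within `cutoff`, shape (N_A, max_neighbors)."""
    n_b = pos_b.shape[0]
    d = pos_a[:, None, :] - pos_b[None, :, :]  # (N_A, N_B, dim)
    within = jnp.sum(d * d, axis=-1) < cutoff**2
    # Sort key: matching database index, or n_b so that non-matches sort last.
    keys = jnp.where(within, jnp.arange(n_b), n_b)
    # Pad columns if the buffer is wider than the database.
    keys = jnp.pad(keys, ((0, 0), (0, max(0, max_neighbors - n_b))), constant_values=n_b)
    first = jnp.sort(keys, axis=1)[:, :max_neighbors]
    neighbors = jnp.where(first == n_b, fill_value, first)
    overflow = jnp.any(jnp.sum(within, axis=1) > max_neighbors)
    return neighbors, overflow


pos_a = jnp.array([[0.0, 0.0], [5.0, 5.0]])
pos_b = jnp.array([[0.2, 0.0], [10.0, 10.0], [0.0, 0.3], [5.0, 5.5]])
cross, cross_overflow = brute_force_cross_neighbor_list(pos_a, pos_b, cutoff=1.0, max_neighbors=2)
```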