jaxdem.colliders.cell_list
Cell List \(O(N \log N)\) collider implementation.
Classes
| DynamicCellList | Implicit cell-list (spatial hashing) collider using dynamic while-loops. |
| StaticCellList | Implicit cell-list (spatial hashing) collider. |
- class jaxdem.colliders.cell_list.DynamicCellList(neighbor_mask: Array, cell_size: Array)
Bases: Collider
Implicit cell-list (spatial hashing) collider using dynamic while-loops.
This collider accelerates short-range pair interactions by partitioning the domain into a regular grid of cubic/square cells of side length cell_size. Each particle is assigned to a cell, particles are sorted by cell hash, and interactions are evaluated only against particles in the same or neighboring cells given by neighbor_mask.
Unlike the static cell list, this implementation does not pad the per-cell search to a fixed max_occupancy. Instead, it uses a dynamic jax.lax.while_loop to iterate over exactly the particles present in each neighboring cell.
The operation of this collider can be understood as the following nested loop:
    for particle in particles:                  # parallel
        for hash in stencil(particle):          # parallel
            while next_neighbor in cell(hash):  # sequential
                ...
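The nested loop above can be sketched in plain NumPy to make the hash, sort and probe structure concrete. This is an illustrative, sequential sketch, not jaxdem's implementation. The helper `dynamic_cell_list_pairs` is hypothetical. It assumes a non-periodic domain with non-negative coordinates and `search_range * cell_size >= cutoff`. It also runs the two outer loops serially, whereas a JAX version would run them in parallel.

```python
import itertools

import numpy as np


def dynamic_cell_list_pairs(pos, cell_size, cutoff, search_range=1):
    """Hypothetical helper: all pairs (i, j), i < j, with |r_i - r_j| < cutoff.

    Assumes a non-periodic domain, non-negative coordinates and
    search_range * cell_size >= cutoff.
    """
    n, dim = pos.shape
    # Integer cell coordinates, shifted by search_range so probed cells are never negative.
    cells = np.floor(pos / cell_size).astype(np.int64) + search_range
    dims = tuple(int(d) for d in cells.max(axis=0) + search_range + 1)
    hashes = np.ravel_multi_index(tuple(cells.T), dims)  # one integer key per cell
    order = np.argsort(hashes, kind="stable")             # sort particles by cell hash
    sorted_hash = hashes[order]
    stencil = list(itertools.product(range(-search_range, search_range + 1), repeat=dim))

    pairs = []
    for i in range(n):                      # parallel over particles in a JAX version
        for offset in stencil:              # parallel over stencil cells in a JAX version
            h = np.ravel_multi_index(tuple(cells[i] + np.array(offset)), dims)
            k = np.searchsorted(sorted_hash, h)  # first sorted slot belonging to cell h
            # Sequential part: visit exactly the particles stored in cell h, no padding.
            while k < n and sorted_hash[k] == h:
                j = int(order[k])
                if i < j and np.linalg.norm(pos[i] - pos[j]) < cutoff:
                    pairs.append((i, j))
                k += 1
    return sorted(pairs)
```

Comparing the result with an \(O(N^2)\) double loop on small random inputs is a quick way to check a sketch like this. The inner `while` visits each occupied cell's particles exactly once.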
Because the innermost loop is evaluated sequentially, the computational cost is driven by the average cell occupancy rather than the maximum possible occupancy. This makes the total theoretical cost:
\[O(N \cdot \text{neighbor\_mask\_size} \cdot \langle K \rangle)\]
where \(\langle K \rangle\) is the average cell occupancy. To understand how this scales, let's analyze the cost components:
- Stencil size:
The stencil size depends on the ratio between the cell size (\(L\)) and the radius of the largest particle (\(r_{max}\)).
\[\text{neighbor\_mask\_size} = \left( 2\left\lceil \frac{2r_{max}}{L} \right\rceil + 1 \right)^{dim}\]
- Average occupancy:
The average number of particles that occupy a cell depends only on the cell volume and the macroscopic number density (\(\rho\)), completely independent of the smallest particle size.
\[\langle K \rangle = \rho L^{dim}\]
To express this in terms of the local volume fraction \(\phi\) (the ratio of volume actually occupied by particles to the total cell volume) and our normalized cell size \(L^\prime = L/r_{max}\), we use the average particle volume \(\langle V \rangle\):
\[\langle K \rangle = \phi \frac{L^{dim}}{\langle V \rangle} = \phi \frac{(L^\prime r_{max})^{dim}}{\langle V \rangle}\]
Knowing that the volume of the largest particle is \(V_{max} = k_v r_{max}^{dim}\) (where \(k_v\) is the geometric volume factor, such as \(4\pi/3\) in 3D or \(\pi\) in 2D), we can substitute \(r_{max}^{dim} = V_{max}/k_v\) and find the final theoretical cost:
\[\text{cost} \approx N \left( 2\left\lceil \frac{2}{L^\prime} \right\rceil + 1 \right)^{dim} \left( \frac{\phi}{k_v} \frac{V_{max}}{\langle V \rangle} (L^\prime)^{dim} \right)\]
- The Polydispersity Advantage:
In the static cell list, cost scales with the ratio of the largest to the smallest particle volume (\(V_{max}/V_{min} \propto \alpha^{dim}\), where \(\alpha = r_{max}/r_{min}\)). In this dynamic list, the cost scales with the ratio of the largest to the average particle volume (\(V_{max}/\langle V \rangle\)). Since \(\langle V \rangle \geq V_{min}\), this ratio never exceeds the static one, and it is usually much smaller. A small fraction of tiny particles barely changes the average volume, so the severe \(O(\alpha^{dim})\) padding penalty is largely avoided. The dynamic loop only iterates over particles that actually exist.
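To see the size of the effect, both cost models can be evaluated numerically. The sketch below only transcribes the dynamic formula above, and the static formula from the StaticCellList description, into Python. The function names and the example inputs are illustrative assumptions: 1000 spheres of radius 1 plus 10 of radius 0.1, \(\phi = 0.6\), \(L^\prime = 2\), in 3D.

```python
import math

import numpy as np


def _k_v(dim):
    # Geometric volume factor: pi in 2D, 4*pi/3 in 3D.
    return math.pi if dim == 2 else 4.0 * math.pi / 3.0


def _stencil(l_prime, dim):
    # Number of probed cells, (2 * ceil(2 / L') + 1) ** dim.
    return (2 * math.ceil(2.0 / l_prime) + 1) ** dim


def dynamic_cost(radii, phi, l_prime, dim):
    """Per-particle cost: stencil * (phi / k_v) * (V_max / <V>) * L'^dim."""
    volumes = _k_v(dim) * np.asarray(radii, dtype=float) ** dim
    mean_k = phi / _k_v(dim) * (volumes.max() / volumes.mean()) * l_prime**dim
    return _stencil(l_prime, dim) * float(mean_k)


def static_cost(radii, phi, l_prime, dim):
    """Per-particle cost: stencil * ceil(lambda + 3 sqrt(lambda)), lambda from r_min."""
    radii = np.asarray(radii, dtype=float)
    alpha = radii.max() / radii.min()
    lam = phi / _k_v(dim) * (alpha * l_prime) ** dim
    return _stencil(l_prime, dim) * math.ceil(lam + 3.0 * math.sqrt(lam))


# Example: 1000 large spheres plus 10 small ones (alpha = 10) in 3D.
radii = np.concatenate([np.full(1000, 1.0), np.full(10, 0.1)])
print(dynamic_cost(radii, phi=0.6, l_prime=2.0, dim=3))
print(static_cost(radii, phi=0.6, l_prime=2.0, dim=3))
```

With these inputs the dynamic estimate stays close to its monodisperse value. The static estimate is about a thousand times larger, because \(\lambda\) is computed from \(r_{min}\).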
This collider is ideal for highly polydisperse systems, sparse systems (low packing fractions), or systems with rigid clumps that create massive local density spikes, as it completely avoids the memory bloat and wasted gather operations caused by array padding.
Complexity
Time: \(O(N)\) to \(O(N \log N)\) for sorting (closer to \(O(N)\) in practice, because the state is nearly sorted from the previous frame), plus \(O(N \cdot M \cdot \langle K \rangle)\) for neighbor probing (\(M\) = neighbor_mask_size, \(\langle K \rangle\) = average occupancy).
Memory: \(O(N)\).
- neighbor_mask
Integer offsets defining the neighbor stencil (M, dim).
- Type:
jax.Array
- cell_size
Linear size of a grid cell (scalar).
- Type:
jax.Array
Notes
Batching with vmap: If you use jax.vmap to evaluate multiple simulation environments simultaneously, be aware of JAX's SIMD execution model. Because the innermost while loop executes sequentially, the batched loop must keep running for all environments until the environment with the highest local cell occupancy finishes its iterations. Consequently, the computational cost of a batched execution is bottlenecked by the single worst-case occupancy across the entire batch.
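The lockstep behavior can be quantified without JAX. Under `jax.vmap`, a batched `jax.lax.while_loop` keeps stepping until its condition is false for every batch member, and members that have already finished are only masked out. The sketch below compares the iterations each environment needs with the iterations the batch actually executes. Both the helper name and the occupancy numbers are hypothetical.

```python
import numpy as np


def batched_while_cost(iterations_per_env):
    """Iterations executed by a vmapped while-loop versus iterations actually needed.

    iterations_per_env[b] is the number of inner-loop steps environment b needs
    (e.g. its worst local cell occupancy). Under vmap all lanes step together.
    """
    needed = np.asarray(iterations_per_env)
    steps = int(needed.max())               # the loop stops only when every lane is done
    executed = steps * needed.size          # total lane-steps, including masked idle ones
    wasted_fraction = 1.0 - needed.sum() / executed
    return steps, float(wasted_fraction)


print(batched_while_cost([3, 4, 3, 40]))  # (40, 0.6875)
```

In this example, one dense environment turns about 69% of the executed lane-steps into idle work.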
- neighbor_mask: Array
Integer offsets defining the neighbor stencil (M, dim).
- cell_size: Array
Linear size of a grid cell (scalar).
- classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, box_size: ArrayLike | None = None) → Self
Creates a CellList collider with robust defaults.
Defaults are chosen to prevent missing contacts while keeping the neighbor stencil and the expected cell occupancy as small as possible given the available information from state.
Because this collider uses a dynamic while-loop, its optimal parameters differ slightly from those of the static implementation. The optimal cell size parameter, denoted here as the dimensionless ratio \(L^\prime = L/r_{max}\), is primarily driven by balancing the stencil size overhead against the sequential loop cost.
Standard density: The optimal cell size is typically \(L^\prime = 2\) (the diameter of the largest particle). This reduces the search stencil to the immediate 3×3×3 block of 27 cells in 3D (the particle's own cell and its 26 neighbors), which is usually the most efficient balance for compiled JAX code.
High polydispersity ( \(\alpha \gg 1\) ): Unlike the static cell list, the dynamic cell list handles high polydispersity gracefully. We generally maintain \(L^\prime = 2\) unless the size ratio is extreme, as shrinking the cell size drastically inflates the neighbor stencil, which harms the parallelized outer loops.
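Evaluating the stencil formula from the cost analysis, \((2\lceil 2/L^\prime \rceil + 1)^{dim}\), for a few cell sizes shows how quickly the stencil grows as the cell shrinks. This is a direct transcription of that formula, not library code.

```python
import math


def stencil_size(l_prime, dim):
    """Cells probed per particle: (2 * ceil(2 / L') + 1) ** dim."""
    return (2 * math.ceil(2.0 / l_prime) + 1) ** dim


for l_prime in (2.0, 1.0, 0.5):
    print(f"L' = {l_prime}: {stencil_size(l_prime, 2)} cells in 2D, {stencil_size(l_prime, 3)} in 3D")
# L' = 2.0: 9 cells in 2D, 27 in 3D
# L' = 1.0: 25 cells in 2D, 125 in 3D
# L' = 0.5: 81 cells in 2D, 729 in 3D
```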
By default, if cell_size is not provided, this method will infer an optimal value based on the radius distribution in the reference state.
- Parameters:
state (State) – Reference state used to determine spatial dimension and default parameters.
cell_size (float, optional) – Cell edge length. If None, defaults to \(2 r_{max}\) for systems with low polydispersity (\(\alpha < 2.5\)), or \(0.5 r_{max}\) for highly polydisperse systems to balance stencil overhead.
search_range (int, optional) – Neighbor range in cell units. If None, the smallest safe value is computed such that \(\text{search\_range} \cdot L \geq \text{cutoff}\).
box_size (jax.Array, optional) – Size of the periodic box, used to ensure there are at least 2 * search_range + 1 cells per axis. If None, these checks are skipped.
- Returns:
Configured collider instance.
- Return type:
CellList
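The default-selection rules listed above can be sketched as follows. The following are assumptions. The contact cutoff is taken to be \(2 r_{max}\). The box check simply raises an error, since how jaxdem reacts to a box that is too small is not described here. `default_dynamic_params` is a hypothetical helper, not part of the API.

```python
import math

import numpy as np


def default_dynamic_params(radii, box_size=None, cell_size=None, search_range=None):
    """Hypothetical sketch of the documented default rules for DynamicCellList.Create."""
    radii = np.asarray(radii, dtype=float)
    r_max, r_min = float(radii.max()), float(radii.min())
    alpha = r_max / r_min
    if cell_size is None:
        # 2 r_max for low polydispersity (alpha < 2.5), otherwise 0.5 r_max.
        cell_size = 2.0 * r_max if alpha < 2.5 else 0.5 * r_max
    cutoff = 2.0 * r_max  # assumption: contact distance of the two largest particles
    if search_range is None:
        # Smallest integer with search_range * cell_size >= cutoff.
        search_range = math.ceil(cutoff / cell_size)
    if box_size is not None:
        # A periodic box needs at least 2 * search_range + 1 cells per axis.
        cells_per_axis = np.floor(np.asarray(box_size, dtype=float) / cell_size)
        if np.any(cells_per_axis < 2 * search_range + 1):
            raise ValueError("box too small for the requested stencil")
    return float(cell_size), int(search_range)
```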
- static compute_force(state: State, system: System) → tuple[State, System]
Computes the total force acting on each particle using an implicit cell list (\(O(N \log N)\)). This method sums the force contributions from all particle pairs (i, j) as computed by the system.force_model and updates the particle forces.
- static compute_potential_energy(state: State, system: System) → jax.Array
Computes the potential energy of each particle using an implicit cell list (\(O(N \log N)\)). This method sums the energy contributions from all particle pairs (i, j) as computed by the system.force_model.
- static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]
Computes the list of neighbors for each particle. The list has shape (N, max_neighbors). If a particle has fewer than max_neighbors neighbors, its row is padded with -1. The indices in the list refer to the particle order of the returned, sorted state.
Because of the cell-list construction, no neighbors farther than cell_size * (1 + search_range) can be found, where search_range is the number of cells probed in each direction. If cutoff is greater than this value, the returned list may be incomplete.
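- Parameters:
state (State) – The simulation state. The returned state is sorted by cell hash.
system (System) – The configuration of the simulation.
cutoff (float) – Search radius.
max_neighbors (int) – Maximum number of neighbors to store per particle.
- Returns:
The sorted state, the system, the neighbor list, and a boolean flag for overflow.
- Return type:
tuple[State, System, jax.Array, jax.Array]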
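The output convention (rows padded with -1, plus an overflow flag) can be reproduced with a brute-force \(O(N^2)\) reference, which is handy for checking small systems. This sketch is hypothetical. `brute_force_neighbor_list` is not part of jaxdem, it treats neighbors as strictly closer than `cutoff` (an assumption), and it keeps the input particle order instead of sorting the state.

```python
import numpy as np


def brute_force_neighbor_list(pos, cutoff, max_neighbors):
    """O(N^2) reference with the same output convention: -1 padding plus an overflow flag."""
    pos = np.asarray(pos, dtype=float)
    n = pos.shape[0]
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    within = (dist < cutoff) & ~np.eye(n, dtype=bool)      # exclude self-pairs
    neighbor_list = np.full((n, max_neighbors), -1, dtype=np.int64)
    for i in range(n):
        idx = np.flatnonzero(within[i])[:max_neighbors]      # extra neighbors are dropped
        neighbor_list[i, : idx.size] = idx
    overflow = bool((within.sum(axis=1) > max_neighbors).any())
    return neighbor_list, overflow


pos = [[0.0, 0.0], [0.5, 0.0], [1.0, 0.0], [3.0, 0.0]]
print(brute_force_neighbor_list(pos, cutoff=0.6, max_neighbors=2))
```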
- static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]
Build a cross-neighbor list between two sets of positions using a dynamic cell list.
For each point in pos_a, finds all neighbors from pos_b within the given cutoff distance. The pos_b array is hashed and sorted into cells, and the neighbor stencil of each query point in pos_a is used to probe the sorted pos_b hashes with a dynamic jax.lax.while_loop.
No neighbors further than cell_size * (1 + search_range) can be found due to the nature of the cell list.
- Parameters:
pos_a (jax.Array) – Query positions, shape (N_A, dim).
pos_b (jax.Array) – Database positions, shape (N_B, dim).
system (System) – The configuration of the simulation.
cutoff (float) – Search radius.
max_neighbors (int) – Maximum number of neighbors to store per query point.
- Returns:
A tuple containing:
neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.
overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.
- Return type:
Tuple[jax.Array, jax.Array]
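The cross-list procedure (hash and sort `pos_b`, then probe the stencil of each `pos_a` point) can be sketched in NumPy. The sketch makes several assumptions. It assumes a non-periodic domain, non-negative coordinates and `search_range * cell_size >= cutoff`. It also takes explicit `cell_size` and `search_range` arguments, whereas the documented method receives a `system`. `cross_neighbor_list` is a hypothetical name.

```python
import itertools

import numpy as np


def cross_neighbor_list(pos_a, pos_b, cell_size, cutoff, max_neighbors, search_range=1):
    """Hypothetical sketch: neighbors in pos_b of every point of pos_a, -1 padded."""
    dim = pos_b.shape[1]
    # Cell coordinates shifted by search_range so that probed cells are never negative.
    cell_a = np.floor(pos_a / cell_size).astype(np.int64) + search_range
    cell_b = np.floor(pos_b / cell_size).astype(np.int64) + search_range
    top = np.maximum(cell_a.max(axis=0), cell_b.max(axis=0))
    dims = tuple(int(d) for d in top + search_range + 1)
    hash_b = np.ravel_multi_index(tuple(cell_b.T), dims)
    order = np.argsort(hash_b, kind="stable")   # database points sorted by cell hash
    sorted_hash = hash_b[order]
    stencil = np.array(list(itertools.product(range(-search_range, search_range + 1), repeat=dim)))

    neighbor_list = np.full((pos_a.shape[0], max_neighbors), -1, dtype=np.int64)
    overflow = False
    for i, cell in enumerate(cell_a):
        found = []
        for offset in stencil:
            h = np.ravel_multi_index(tuple(cell + offset), dims)
            k = np.searchsorted(sorted_hash, h)          # first pos_b entry in cell h
            while k < sorted_hash.size and sorted_hash[k] == h:
                j = int(order[k])
                if np.linalg.norm(pos_a[i] - pos_b[j]) < cutoff:
                    found.append(j)
                k += 1
        overflow = overflow or len(found) > max_neighbors
        found = sorted(found)[:max_neighbors]            # keep at most max_neighbors
        neighbor_list[i, : len(found)] = found
    return neighbor_list, overflow
```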
- class jaxdem.colliders.cell_list.StaticCellList(neighbor_mask: Array, cell_size: Array, max_occupancy: int)
Bases: Collider
Implicit cell-list (spatial hashing) collider.
This collider accelerates short-range pair interactions by partitioning the domain into a regular grid of cubic/square cells of side length cell_size. Each particle is assigned to a cell, particles are sorted by cell hash, and interactions are evaluated only against particles in the same or neighboring cells given by neighbor_mask. The cell list is implicit because we never store per-cell particle lists explicitly; instead, we exploit the sorted hashes and the fixed max_occupancy to probe neighbors in place.
The operation of this collider can be understood as the following nested loop:
    for particle in particles:
        for hash in stencil(particle):
            for i in range(max_occupancy):
                ...
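Because the stencil size and `max_occupancy` are fixed, the nested loop above can be flattened into one gather over an (N, M, max_occupancy) array of candidate slots, masked by cell hash. The NumPy sketch below illustrates that idea under simplifying assumptions (non-periodic domain, non-negative coordinates). `static_probe_counts` is a hypothetical helper that counts neighbors instead of computing forces.

```python
import itertools

import numpy as np


def static_probe_counts(pos, cell_size, cutoff, max_occupancy, search_range=1):
    """Hypothetical sketch: neighbor counts via a flattened (N, M, max_occupancy) probe."""
    n, dim = pos.shape
    cells = np.floor(pos / cell_size).astype(np.int64) + search_range  # shifted, non-negative
    dims = tuple(int(d) for d in cells.max(axis=0) + search_range + 1)
    hashes = np.ravel_multi_index(tuple(cells.T), dims)
    order = np.argsort(hashes, kind="stable")
    sorted_hash, sorted_pos = hashes[order], pos[order]
    stencil = np.array(list(itertools.product(range(-search_range, search_range + 1), repeat=dim)))

    # (N, M): hash of every stencil cell of every particle.
    probe_cells = cells[:, None, :] + stencil[None, :, :]
    probe_hash = np.ravel_multi_index(tuple(np.moveaxis(probe_cells, -1, 0)), dims)
    start = np.searchsorted(sorted_hash, probe_hash)      # first slot of each probed cell
    # (N, M, K): the fixed block of max_occupancy slots probed in each cell.
    slots = start[..., None] + np.arange(max_occupancy)
    valid = slots < n
    slots = np.minimum(slots, n - 1)                      # clamp so the gather stays in bounds
    valid &= sorted_hash[slots] == probe_hash[..., None]  # slot still belongs to the cell
    dist = np.linalg.norm(sorted_pos[slots] - pos[:, None, None, :], axis=-1)
    not_self = order[slots] != np.arange(n)[:, None, None]
    return (valid & not_self & (dist < cutoff)).sum(axis=(1, 2))
```

With a max_occupancy that is too small, the counts can only go down. This is the undercounting risk that comes with a fixed max_occupancy.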
The loop is flattened and everything happens in parallel. We can do this because the stencil size and max_occupancy are static values. This means that the cost of the calculations done with this cell list is:
\[O(N \cdot \text{neighbor\_mask\_size} \cdot \text{max\_occupancy})\]
plus the cost of hashing \(O(N)\) and sorting \(O(N \log N)\). Because we sort every time step, the system stays nearly sorted, which brings the practical sorting cost closer to \(O(N)\). This fixed max_occupancy makes the cell list ideal for certain types of systems but poorly suited to others. To understand the difference, let's analyze the cost components:
- Stencil size:
The stencil size depends on the ratio between the cell size (\(L\)) and the radius of the largest particle (\(r_{max}\)).
\[\text{neighbor\_mask\_size} = \left( 2\left\lceil \frac{2r_{max}}{L} \right\rceil + 1 \right)^{dim}\]
- Max occupancy:
The maximum number of particles that can occupy a cell depends on the cell volume, the density (represented as local volume fraction \(\phi\)), and the volume of the smallest particle (\(V_{min}\)). To prevent missing contacts during local density fluctuations, we estimate the expected occupancy (\(\lambda\)) and add a safety margin of 3 standard deviations, taking \(\sqrt{\lambda}\) as the standard deviation (as for a Poisson-distributed count):
\[\text{max\_occupancy} = \left\lceil \lambda + 3\sqrt{\lambda} \right\rceil \quad \text{where} \quad \lambda = \phi \frac{L^{dim}}{V_{min}}\]
Here, \(\phi\) is the local volume fraction, representing the ratio of the volume actually occupied by particles to the total volume of the cell. For the densest packing of equal spheres, \(\phi \approx 0.74\), but in systems with highly overlapping internal spheres (like rigid clumps), \(\phi\) can be much larger.
Putting this together and expressing the cell size in units of the maximum radius \(L^\prime = L/r_{max}\), we find the total theoretical cost to be:
\[\text{cost} \approx N \left( 2\left\lceil \frac{2}{L^\prime} \right\rceil + 1 \right)^{dim} \left\lceil \phi \frac{L^{dim}}{V_{min}} + 3\sqrt{\phi \frac{L^{dim}}{V_{min}}} \right\rceil\]
Then, if we define polydispersity as the ratio between the largest and smallest particle radii, \(\alpha = r_{max}/r_{min}\), we can express \(\lambda\) solely in terms of these dimensionless parameters. Knowing that the volume of the smallest particle is \(V_{min} = k_v r_{min}^{dim}\) (where \(k_v\) is the geometric volume factor, such as \(4\pi/3\) in 3D or \(\pi\) in 2D), we can write \(\lambda\) as:
\[\lambda = \frac{\phi}{k_v} (\alpha L^\prime)^{dim}\]
Substituting this back, we find the final cost function:
\[\text{cost} \approx N \left( 2\left\lceil \frac{2}{L^\prime} \right\rceil + 1 \right)^{dim} \left\lceil \frac{\phi}{k_v} (\alpha L^\prime)^{dim} + 3\sqrt{\frac{\phi}{k_v}} (\alpha L^\prime)^{dim/2} \right\rceil\]
- The Polydispersity Penalty:
Because we must safely bound the maximum possible number of particles in a cell, we base \(\lambda\) on the smallest particle. As shown in the equation above, the required array padding grows dramatically as \(O(\alpha^{dim})\), causing performance to degrade in highly polydisperse systems.
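Evaluating \(\lceil \lambda + 3\sqrt{\lambda} \rceil\) with \(\lambda = \frac{\phi}{k_v}(\alpha L^\prime)^{dim}\) for a few size ratios makes the \(O(\alpha^{dim})\) growth concrete. The snippet transcribes the formula above. The values \(\phi = 1\), \(L^\prime = 2\) and the choice of 3D are examples only.

```python
import math


def max_occupancy(alpha, l_prime=2.0, phi=1.0, dim=3):
    """ceil(lambda + 3 sqrt(lambda)) with lambda = phi / k_v * (alpha * L')**dim."""
    k_v = math.pi if dim == 2 else 4.0 * math.pi / 3.0
    lam = phi / k_v * (alpha * l_prime) ** dim
    return math.ceil(lam + 3.0 * math.sqrt(lam))


for alpha in (1.0, 2.0, 5.0, 10.0):
    print(f"alpha = {alpha:>4}: max_occupancy = {max_occupancy(alpha)}")
```

Going from \(\alpha = 1\) to \(\alpha = 10\) raises the padding from 7 to 2041 probed slots per stencil cell.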
This collider is ideal for systems of spheres with low polydispersity and no large overlaps. In those cases, it can be even faster than the dynamic cell list. However, it is not recommended for systems with clumps (where internal overlaps cause extreme local \(\phi\)) or with high polydispersity, as both drastically inflate the required max_occupancy padding.
Complexity
Time: \(O(N)\) to \(O(N \log N)\) for sorting (closer to \(O(N)\) in practice, because the state is nearly sorted from the previous frame), plus \(O(N \cdot M \cdot K)\) for neighbor probing (\(M\) = neighbor_mask_size, \(K\) = max_occupancy).
Memory: \(O(N)\).
- cell_size
Linear size of a grid cell (scalar).
- Type:
jax.Array
- max_occupancy
Maximum number of particles assumed to occupy a single cell. The algorithm probes exactly max_occupancy entries starting from the first particle in a neighbor cell. This should be set high enough that real cells rarely exceed it; otherwise, contacts and energy will be undercounted.
- Type:
int
Notes
max_occupancy is an upper bound on particles per cell. If a cell contains more than this many particles, some interactions might be missed. Choose cell_size carefully and make sure the local density never pushes a cell above the expected max_occupancy.
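A quick way to check a configuration against this bound is to count the particles in each cell directly. The sketch below is a hypothetical diagnostic, not part of the documented API. It bins positions with the same `cell_size` and ignores periodic wrapping.

```python
import numpy as np


def occupancy_exceeded(pos, cell_size, max_occupancy):
    """Return (largest cell occupancy, whether it exceeds max_occupancy)."""
    cells = np.floor(np.asarray(pos, dtype=float) / cell_size).astype(np.int64)
    # Each unique row of integer cell coordinates is one occupied cell.
    _, counts = np.unique(cells, axis=0, return_counts=True)
    worst = int(counts.max())
    return worst, worst > max_occupancy
```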
- neighbor_mask: Array
Integer offsets defining the neighbor stencil.
Shape is (M, dim), where each row is a displacement in cell coordinates. For search_range=1 in 2D this is the 3×3 Moore neighborhood (M=9); in 3D this is the 3×3×3 neighborhood (M=27).
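A stencil of this shape can be generated with `itertools.product`. This is one possible construction; the row order here is an arbitrary choice and may differ from jaxdem's. It reproduces the M = 9 and M = 27 counts quoted above.

```python
import itertools

import numpy as np


def make_neighbor_mask(search_range, dim):
    """All integer offsets in [-search_range, search_range]^dim, shape (M, dim)."""
    r = range(-search_range, search_range + 1)
    return np.array(list(itertools.product(r, repeat=dim)), dtype=np.int32)


print(make_neighbor_mask(1, 2).shape)  # (9, 2)
print(make_neighbor_mask(1, 3).shape)  # (27, 3)
```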
- cell_size: Array
Linear size of a grid cell (scalar).
- max_occupancy: int
Maximum number of particles assumed to occupy a single cell.
The algorithm probes exactly max_occupancy entries starting from the first particle in a neighbor cell. This should be set high enough that real cells rarely exceed it; otherwise contacts/energy will be undercounted.
- classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, box_size: ArrayLike | None = None, max_occupancy: int | None = None) → Self
Creates a StaticCellList collider with robust defaults.
Defaults are chosen to prevent missing contacts while keeping the neighbor stencil and assumed cell occupancy as small as possible given the available information from state.
The optimal cell size parameter, denoted here as the dimensionless ratio \(L^\prime = L/r_{max}\), depends heavily on the system's volume fraction (\(\phi\)) and polydispersity (\(\alpha = r_{max}/r_{min}\)):
Standard density ( \(\phi \le 1\) ): The optimal cell size is typically \(L^\prime = 2\) (the diameter of the largest particle). This minimizes the search stencil to just the immediate neighboring cells.
Extreme local density ( \(\phi \gg 1\) ): For systems with heavy internal overlaps (like rigid clumps), the optimal cell size shrinks. Values like \(L^\prime = 0.5\) or \(L^\prime = 0.25\) often yield better performance by reducing the massive array padding penalty, even at the cost of a larger search stencil.
High polydispersity ( \(\alpha \gg 1\) ): High polydispersity severely degrades performance regardless of density, because the fixed occupancy arrays must always be padded to hold as many of the smallest particles as can fit in one cell.
By default, if cell_size or max_occupancy are not provided, this method infers optimal safe values based on the radius distribution in the reference state, assuming a maximum local volume fraction of \(\phi = 1\). If your system contains clumps with internal overlaps where local \(\phi > 1\), you must override these defaults.
- Parameters:
state (State) – Reference state used to determine spatial dimension and default parameters.
cell_size (float, optional) – Cell edge length. If None, defaults to \(2 r_{max}\) for systems with low polydispersity (\(\alpha < 2.5\)), or \(0.5 r_{max}\) for highly polydisperse systems to mitigate the \(O(\alpha^{dim})\) array padding cost.
search_range (int, optional) – Neighbor range in cell units. If None, the smallest safe value is computed such that \(\text{search\_range} \cdot L \geq \text{cutoff}\).
box_size (jax.Array, optional) – Size of the periodic box, used to ensure there are at least 2 * search_range + 1 cells per axis. If None, these checks are skipped.
max_occupancy (int, optional) – Assumed maximum particles per cell. If None, estimated using the statistical model \(\lambda + 3\sqrt{\lambda}\), assuming a conservative worst-case local volume fraction of \(\phi = 1\) and the volume of the smallest particle.
- Returns:
Configured collider instance.
- Return type:
CellList
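The default cell_size and max_occupancy rules described above can be reproduced from a list of radii. This sketch assumes \(\phi = 1\) as documented, takes \(V_{min}\) from the smallest radius, and uses the \(\alpha < 2.5\) cell-size rule from the parameter list. `default_max_occupancy` is a hypothetical name.

```python
import math

import numpy as np


def default_max_occupancy(radii, dim=3, cell_size=None, phi=1.0):
    """Hypothetical sketch: (cell_size, ceil(lambda + 3 sqrt(lambda))) from a radius list."""
    radii = np.asarray(radii, dtype=float)
    r_max, r_min = float(radii.max()), float(radii.min())
    if cell_size is None:
        cell_size = 2.0 * r_max if r_max / r_min < 2.5 else 0.5 * r_max
    k_v = math.pi if dim == 2 else 4.0 * math.pi / 3.0
    v_min = k_v * r_min**dim                  # volume of the smallest particle
    lam = phi * cell_size**dim / v_min        # expected particles per cell
    return float(cell_size), math.ceil(lam + 3.0 * math.sqrt(lam))
```

For \(\alpha = 4\), the \(0.5 r_{max}\) cell gives the same estimate (7) as the monodisperse case, because \(L/r_{min} = 0.5 \cdot 4 = 2\) in both cases.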
- static compute_force(state: State, system: System) → tuple[State, System]
Computes the total force acting on each particle using an implicit cell list (\(O(N \log N)\)). This method sums the force contributions from all particle pairs (i, j) as computed by the system.force_model and updates the particle forces.
- static compute_potential_energy(state: State, system: System) → jax.Array
Computes the potential energy of each particle using an implicit cell list (\(O(N \log N)\)). This method sums the energy contributions from all particle pairs (i, j) as computed by the system.force_model.
- static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]
Computes the list of neighbors for each particle. The list has shape (N, max_neighbors). If a particle has fewer than max_neighbors neighbors, its row is padded with -1. The indices in the list refer to the particle order of the returned, sorted state.
Because of the cell-list construction, no neighbors farther than cell_size * (1 + search_range) can be found, where search_range is the number of cells probed in each direction. If cutoff is greater than this value, the returned list may be incomplete. If a cell contains more particles than max_occupancy, some neighbors may also be missed.
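- Parameters:
state (State) – The simulation state. The returned state is sorted by cell hash.
system (System) – The configuration of the simulation.
cutoff (float) – Search radius.
max_neighbors (int) – Maximum number of neighbors to store per particle.
- Returns:
The sorted state, the system, the neighbor list, and a boolean flag for overflow.
- Return type:
tuple[State, System, jax.Array, jax.Array]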
- static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]
Build a cross-neighbor list between two sets of positions using an implicit cell list.
For each point in pos_a, finds all neighbors from pos_b within the given cutoff distance. The pos_b array is hashed and sorted into cells, and the neighbor stencil of each query point in pos_a is used to probe the sorted pos_b hashes with a fixed-size unrolled loop.
No neighbors further than cell_size * (1 + search_range) can be found due to the nature of the cell list. If a cell contains more particles than max_occupancy, some neighbors may be missed.
- Parameters:
pos_a (jax.Array) – Query positions, shape (N_A, dim).
pos_b (jax.Array) – Database positions, shape (N_B, dim).
system (System) – The configuration of the simulation.
cutoff (float) – Search radius.
max_neighbors (int) – Maximum number of neighbors to store per query point.
- Returns:
A tuple containing:
neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.
overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.
- Return type:
Tuple[jax.Array, jax.Array]