Source code for jaxdem.colliders.cell_list

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Cell List :math:`O(N log N)` collider implementation."""

from __future__ import annotations

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from dataclasses import dataclass, field, replace
from typing import Tuple, Optional, TYPE_CHECKING, cast
from functools import partial

try:  # Python 3.11+
    from typing import Self  # type: ignore[attr-defined]
except ImportError:  # pragma: no cover
    from typing_extensions import Self

from . import Collider
from ..utils.linalg import cross

if TYPE_CHECKING:  # pragma: no cover
    from ..state import State
    from ..system import System


# occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)
# max_occupancy = jnp.max(occupancy)
# overflow = overflow | (max_occupancy > cell_capacity)


[docs] @Collider.register("CellList") @jax.tree_util.register_dataclass @dataclass(slots=True) class CellList(Collider): r""" Implicit cell-list (spatial hashing) collider. This collider accelerates short-range pair interactions by partitioning the domain into a regular grid of cubic/square cells of side length ``cell_size``. Each particle is assigned to a cell, particles are sorted by cell hash, and interactions are evaluated only against particles in the same or neighboring cells given by ``neighbor_mask``. The cell list is *implicit* because we never store per-cell particle lists explicitly; instead, we exploit the sorted hashes and fixed ``max_occupancy`` to probe neighbors in-place. Complexity ---------- - Time: :math:`O(N \log N)` from sorting, plus :math:`O(N M K)` for neighbor probing (M = number of neighbor cells, K = ``max_occupancy``). - Memory: :math:`O(N)`. Notes ----- - ``max_occupancy`` is an upper bound on particles per cell. If a cell contains more than this many particles, some interactions might be missed (you should choose ``cell_size`` and ``max_occupancy`` so this does not happen). """ neighbor_mask: jax.Array """ Integer offsets defining the neighbor stencil. Shape is ``(M, dim)``, where each row is a displacement in cell coordinates. For ``search_range=1`` in 2D this is the 3×3 Moore neighborhood (M=9); in 3D this is the 3×3×3 neighborhood (M=27). """ cell_size: jax.Array """ Linear size of a grid cell (scalar). """ max_occupancy: int = field(metadata={"static": True}) """ Maximum number of particles assumed to occupy a single cell. The algorithm probes exactly ``max_occupancy`` entries starting from the first particle in a neighbor cell. This should be set high enough that real cells rarely exceed it; otherwise contacts/energy will be undercounted. """
[docs] @classmethod def Create( cls, state: "State", cell_size: Optional[ArrayLike] = None, search_range: Optional[ArrayLike] = None, max_occupancy: Optional[ArrayLike] = None, ) -> Self: r""" Creates a CellList collider with robust defaults. Defaults are chosen to avoid missing any contacts while keeping the neighbor stencil and assumed cell occupancy as small as possible given available information from ``state``. For this we assume no overlap between spheres. The cost of computing forces for one particle is determined by the number of neighboring cells to check and the occupancy of each cell. This cost can be estimated as: .. math:: \text{cost} = (2R + 1)^{dim} \cdot \text{max_occupancy} \\ \text{cost} = (2R + 1)^{dim} \cdot \left(\left\lceil \frac{L^{dim}}{V_{min}} \right\rceil +2 \right) where :math:`R` is the search radius, :math:`L` is the cell size, and :math:`V_{min}` is the volume of the smallest element. We assume :math:`V_{min}` to be the volume of the smallest sphere, without accounting for the packing fraction, to provide a conservative upper bound. The search radius :math:`R` is computed as: .. math:: R = \left\lceil \frac{2 r_{max}}{L} \right\rceil By default, we choose the options that yield the lowest computational cost: :math:`L = 2 \cdot r_{max}` if :math:`\alpha < 2.5`, else :math:`L = r_{max}/2`. The complexity of searching neighbors is :math:`O(N)`, where the choice of cell size and :math:`R` attempts to minimize the constant factor. The constant factor grows with polydispersity (:math:`\alpha`) as :math:`O(\alpha^{dim})` with :math:`\alpha = r_{max}/r_{min}`. The cost for sorting and binary search remains :math:`O(N \log N)`. Parameters ---------- state : State Reference state used to determine spatial dimension and default parameters. cell_size : float, optional Cell edge length. If None, defaults to a value optimized for the radius distribution. search_range : int, optional Neighbor range in cell units. If None, the smallest safe value is computed such that :math:`\text{search\_range} \cdot L \geq \text{cutoff}`. max_occupancy : int, optional Assumed maximum particles per cell. If None, estimated from a conservative packing upper bound using the smallest radius. Returns ------- CellList Configured collider instance. """ min_rad = jnp.min(state.rad) max_rad = jnp.max(state.rad) alpha = max_rad / min_rad if cell_size is None: cell_size = 2.0 * max_rad if alpha < 2.5: cell_size = 2 * max_rad else: cell_size = max_rad / 2 if search_range is None: search_range = jnp.ceil(2 * max_rad / cell_size).astype(int) search_range = jnp.maximum(1, search_range) search_range = jnp.array(search_range, dtype=int) if max_occupancy is None: box_vol = cell_size**state.dim smallest_sphere_vol = jnp.array(0.0, dtype=float) if state.dim == 3: smallest_sphere_vol = (4.0 / 3.0) * jnp.pi * min_rad**3 / 0.9 elif state.dim == 2: smallest_sphere_vol = jnp.pi * min_rad**2 max_occupancy = jnp.ceil(box_vol / smallest_sphere_vol) + 2 r = jnp.arange(-search_range, search_range + 1, dtype=int) mesh = jnp.meshgrid(*([r] * state.dim), indexing="ij") neighbor_mask = jnp.stack([m.ravel() for m in mesh], axis=1) return cls( neighbor_mask=neighbor_mask.astype(int), cell_size=1.02 * jnp.asarray(cell_size, dtype=float), max_occupancy=int(max_occupancy), # type: ignore[arg-type] )
[docs] @staticmethod @partial(jax.jit, donate_argnames=("state", "system")) @partial(jax.named_call, name="CellList.compute_force") def compute_force(state: "State", system: "System") -> Tuple["State", "System"]: r""" Computes the total force acting on each particle using an implicit cell list :math:`O(N log N)`. This method sums the force contributions from all particle pairs (i, j) as computed by the ``system.force_model`` and updates the particle forces. Parameters ---------- state : State The current state of the simulation. system : System The configuration of the simulation. Returns ------- Tuple[State, System] A tuple containing the updated ``State`` object with computed forces and the unmodified ``System`` object. """ collider = cast(CellList, system.collider) iota = jax.lax.iota(dtype=int, size=state.N) MAX_OCCUPANCY = collider.max_occupancy pos = state.pos pos_p = state.q.rotate(state.q, state.pos_p) # to lab # 1. Determine Grid Dimensions # shape: (dim,) if system.domain.periodic: grid_dims = jnp.floor(system.domain.box_size / collider.cell_size).astype( int ) else: grid_dims = jnp.ceil(system.domain.box_size / collider.cell_size).astype( int ) # Compute strides (weights) for flattening 2D/3D indices to 1D hash # [1, nx, nx*ny, ...] strides = jnp.concatenate( [jnp.array([1], dtype=int), jnp.cumprod(grid_dims[:-1])] ) # 2. Calculate Particle Cell Indices cell_ids = jnp.floor((pos - system.domain.anchor) / collider.cell_size).astype( int ) # Wrap indices for hashing purposes if periodic # system.domain.periodic is a static variable. This is a compile time if if system.domain.periodic: cell_ids -= grid_dims * jnp.floor(cell_ids / grid_dims).astype(int) # 3. Spatial Hashing # shape (N,) particle_hash = jnp.dot(cell_ids, strides) # 4. Sort hashes and state particle_hash, perm = jax.lax.sort([particle_hash, iota], num_keys=1) state = jax.tree.map(lambda x: x[perm], state) cell_ids = cell_ids[perm] # 5. Precompute Neighbor Cell Hashes for every particle # (N, M, dim) = (N, 1, dim) + (1, M, dim) # M is number of neighbor cells (e.g., 27) current_cell = cell_ids[:, None, :] + collider.neighbor_mask if system.domain.periodic: current_cell -= grid_dims * jnp.floor(current_cell / grid_dims).astype(int) # shape (N,M) cell_hashes = jnp.dot(current_cell, strides) def per_particle( i: jax.Array, pos_pi: jax.Array, my_cell_id: jax.Array, cell_hash: jax.Array ) -> Tuple[jax.Array, jax.Array]: def per_neighbor_cell( current_cell_hash: jax.Array, ) -> Tuple[jax.Array, jax.Array]: # 1. Find Start Indices # Find where each neighbor hash starts in the sorted particle list. # 'searchsorted' returns the insertion index. # We do this inside the vmap to save memory (N, M) -> (N,) start_idx = jnp.searchsorted( particle_hash, current_cell_hash, side="left", method="scan_unrolled", ) def body_fun(offset: jax.Array) -> Tuple[jax.Array, jax.Array]: k = start_idx + offset safe_k = jnp.minimum(k, state.N - 1) valid = ( (k < state.N) * (particle_hash[safe_k] == current_cell_hash) * (state.ID[safe_k] != state.ID[i]) ) result = system.force_model.force(i, safe_k, state, system) forces, torques = jax.tree.map(lambda x: valid * x, result) torques += cross(pos_pi, forces) return forces, torques # VMAP over the fixed number of contacts result = jax.vmap(body_fun)(jax.lax.iota(size=MAX_OCCUPANCY, dtype=int)) return jax.tree.map(lambda x: x.sum(axis=0), result) # VMAP over neighbor cells result = jax.vmap(per_neighbor_cell)(cell_hash) return jax.tree.map(lambda x: x.sum(axis=0), result) # VMAP over all particles total_force, total_torque = jax.vmap(per_particle)( iota, pos_p, cell_ids, cell_hashes ) total_torque = jax.ops.segment_sum(total_torque, state.ID, num_segments=state.N) total_force = jax.ops.segment_sum(total_force, state.ID, num_segments=state.N) state.force += total_force[state.ID] state.torque += total_torque[state.ID] return state, system
[docs] @staticmethod @jax.jit @partial(jax.named_call, name="CellList.compute_potential_energy") def compute_potential_energy(state: "State", system: "System") -> jax.Array: r""" Computes the total force acting on each particle using an implicit cell list :math:`O(N log N)`. This method sums the force contributions from all particle pairs (i, j) as computed by the ``system.force_model`` and updates the particle forces. Parameters ---------- state : State The current state of the simulation. system : System The configuration of the simulation. Returns ------- Tuple[State, System] A tuple containing the updated ``State`` object with computed forces and the unmodified ``System`` object. """ collider = cast(CellList, system.collider) iota = jax.lax.iota(dtype=int, size=state.N) MAX_OCCUPANCY = collider.max_occupancy pos = state.pos # 1. Determine Grid Dimensions # shape: (dim,) if system.domain.periodic: grid_dims = jnp.floor(system.domain.box_size / collider.cell_size).astype( int ) else: grid_dims = jnp.ceil(system.domain.box_size / collider.cell_size).astype( int ) # Compute strides (weights) for flattening 2D/3D indices to 1D hash # [1, nx, nx*ny, ...] strides = jnp.concatenate( [jnp.array([1], dtype=int), jnp.cumprod(grid_dims[:-1])] ) # 2. Calculate Particle Cell Indices cell_ids = jnp.floor((pos - system.domain.anchor) / collider.cell_size).astype( int ) # Wrap indices for hashing purposes if periodic # system.domain.periodic is a static variable. This is a compile time if if system.domain.periodic: cell_ids -= grid_dims * jnp.floor(cell_ids / grid_dims).astype(int) # 3. Spatial Hashing # shape (N,) particle_hash = jnp.dot(cell_ids, strides) # 4. Sort hashes and state particle_hash, perm = jax.lax.sort([particle_hash, iota], num_keys=1) state = jax.tree.map(lambda x: x[perm], state) cell_ids = cell_ids[perm] # 5. Precompute Neighbor Cell Hashes for every particle # (N, M, dim) = (N, 1, dim) + (1, M, dim) # M is number of neighbor cells (e.g., 27) current_cell = cell_ids[:, None, :] + collider.neighbor_mask if system.domain.periodic: current_cell -= grid_dims * jnp.floor(current_cell / grid_dims).astype(int) # shape (N,M) cell_hashes = jnp.dot(current_cell, strides) def per_particle( i: jax.Array, my_cell_id: jax.Array, cell_hash: jax.Array ) -> jax.Array: def per_neighbor_cell( current_cell_hash: jax.Array, ) -> jax.Array: # 1. Find Start Indices # Find where each neighbor hash starts in the sorted particle list. # 'searchsorted' returns the insertion index. # We do this inside the vmap to save memory (N, M) -> (N,) start_idx = jnp.searchsorted( particle_hash, current_cell_hash, side="left", method="scan_unrolled", ) def body_fun(offset: jax.Array) -> jax.Array: k = start_idx + offset safe_k = jnp.minimum(k, state.N - 1) e_ij = system.force_model.energy(i, safe_k, state, system) valid = ( (k < state.N) * (particle_hash[safe_k] == current_cell_hash) * (state.ID[safe_k] != state.ID[i]) ) e_ij *= valid return 0.5 * e_ij # VMAP over the fixed number of contacts return jax.vmap(body_fun)( jax.lax.iota(size=MAX_OCCUPANCY, dtype=int) ).sum() # VMAP over neighbor cells return jax.vmap(per_neighbor_cell)(cell_hash).sum() # VMAP over all particles return jax.vmap(per_particle)(iota, cell_ids, cell_hashes)
# @staticmethod # @partial(jax.jit, inline=True) # def find_neighbors(state: "State", system: "System") -> jax.Array: # """ # Finds neighbors for ALL particles using a single global vectorized search. # Optimized Strategy: # ------------------- # 1. Calculate all (N, M) neighbor cell hashes at once. # 2. Perform ONE global `searchsorted` on the (N*M) query points. # 3. Broadcast the results to generate the (N, M*K) neighbor matrix. # Returns # ------- # jnp.ndarray # Matrix of neighbor indices with shape ``(N, M * MAX_OCCUPANCY)``. # """ # collider = cast(CellList, system.collider) # N = state.N # MAX_OCCUPANCY = collider.max_occupancy # # --- 1. Grid & Hash Setup --- # # Shape: (dim,) # grid_dims = jnp.ceil(system.domain.box_size / collider.cell_size).astype(int) # strides = jnp.concatenate( # [jnp.array([1], dtype=int), jnp.cumprod(grid_dims[:-1])] # ) # # Calculate cell IDs for all particles: (N, dim) # cell_ids = jnp.floor( # (state.pos - system.domain.anchor) / collider.cell_size # ).astype(int) # if system.domain.periodic: # cell_ids -= grid_dims * jnp.floor(cell_ids / grid_dims).astype(int) # # Particle hashes (sorted by caller assumption): (N,) # particle_hash = jnp.dot(cell_ids, strides) # # --- 2. Generate All Neighbor Queries Globally --- # # Broadcast add to get all neighbor cell coords for all particles # # (N, 1, dim) + (M, dim) -> (N, M, dim) # neighbor_cells = cell_ids[:, None, :] + collider.neighbor_mask # if system.domain.periodic: # neighbor_cells -= grid_dims * jnp.floor(neighbor_cells / grid_dims).astype( # int # ) # # Compute hashes for all N*M neighbor cells # # (N, M, dim) dot (dim,) -> (N, M) # neighbor_hashes = jnp.dot(neighbor_cells, strides) # # --- 3. Global Binary Search ("Preallocation" step) --- # # Instead of vmapping searchsorted, we pass the entire (N, M) array. # # JAX will execute this as a single kernel. # # Result `start_indices` has shape (N, M) # start_indices = jnp.searchsorted( # particle_hash, neighbor_hashes, side="left", method="scan_unrolled" # ) # # --- 4. Expand to Max Occupancy --- # # We need to generate indices [start, start+1, ..., start+K] # # Reshape start_indices for broadcasting: (N, M, 1) # # Add offset (K,): Result shape (N, M, K) # offset = jax.lax.iota(int, MAX_OCCUPANCY) # candidate_indices = start_indices[:, :, None] + offset # # --- 5. Validate & Mask (The Matrix Construction) --- # # A. Clip indices to avoid OOB reads during validation # safe_indices = jnp.minimum(candidate_indices, N - 1) # # B. Check 1: Is the candidate index within the particle list size? # # C. Check 2: Does the candidate actually belong to the target neighbor cell? # # (This handles cells that are not full) # # D. Check 3: Is the candidate NOT the particle itself? # # Expand particle_hash to map against candidates: (N, M, K) matches (N, M) queries # # Expand iota for self-check: (N, 1, 1) # iota_N = jax.lax.iota(int, N)[:, None, None] # is_valid = ( # (candidate_indices < N) # * (particle_hash[safe_indices] == neighbor_hashes[:, :, None]) # * (candidate_indices != iota_N) # ) # # Apply mask: Set invalid to -1 # # Shape: (N, M, K) # neighbors = jnp.where(is_valid, candidate_indices, -1) # # --- 6. Flatten --- # # Combine M neighbors cells * K occupancy -> (N, M*K) # return neighbors.reshape(N, -1)
[docs] @Collider.register("DynamicCellList") @jax.tree_util.register_dataclass @dataclass(slots=True) class DynamicCellList(Collider): r""" Implicit cell-list (spatial hashing) collider using dynamic while-loops. This collider accelerates short-range pair interactions by partitioning the domain into a regular grid. Unlike the standard CellList, this implementation uses a dynamic ``jax.lax.while_loop`` to probe neighbor cells, which can be more efficient in systems with highly non-uniform particle distributions. Complexity ---------- - Time: :math:`O(N \log N)` from sorting, plus :math:`O(N M \langle K \rangle)` for neighbor probing, where :math:`\langle K \rangle` is the average cell occupancy. - Memory: :math:`O(N)`. """ neighbor_mask: jax.Array """Integer offsets defining the neighbor stencil (M, dim).""" cell_size: jax.Array """Linear size of a grid cell (scalar).""" max_occupancy: int = field(metadata={"static": True}) """Maximum number of particles assumed to occupy a single cell (loop safety limit)."""
[docs] @classmethod def Create( cls, state: "State", cell_size: Optional[ArrayLike] = None, search_range: Optional[ArrayLike] = None, max_occupancy: Optional[ArrayLike] = None, ) -> Self: r""" Creates a DynamicCellList collider with robust defaults. """ min_rad = jnp.min(state.rad) max_rad = jnp.max(state.rad) alpha = max_rad / min_rad if cell_size is None: cell_size = 2.0 * max_rad if alpha < 2.5: cell_size = 2 * max_rad else: cell_size = max_rad / 2 if search_range is None: search_range = jnp.ceil(2 * max_rad / cell_size).astype(int) search_range = jnp.maximum(1, search_range) search_range = jnp.array(search_range, dtype=int) if max_occupancy is None: box_vol = cell_size**state.dim smallest_sphere_vol = jnp.array(0.0, dtype=float) if state.dim == 3: smallest_sphere_vol = (4.0 / 3.0) * jnp.pi * min_rad**3 / 0.9 elif state.dim == 2: smallest_sphere_vol = jnp.pi * min_rad**2 max_occupancy = jnp.ceil(box_vol / smallest_sphere_vol) + 2 r = jnp.arange(-search_range, search_range + 1, dtype=int) mesh = jnp.meshgrid(*([r] * state.dim), indexing="ij") neighbor_mask = jnp.stack([m.ravel() for m in mesh], axis=1) return cls( neighbor_mask=neighbor_mask.astype(int), cell_size=1.02 * jnp.asarray(cell_size, dtype=float), max_occupancy=int(max_occupancy), # type: ignore[arg-type] )
[docs] @staticmethod @partial(jax.jit, donate_argnames=("state", "system")) @partial(jax.named_call, name="DynamicCellList.compute_force") def compute_force(state: "State", system: "System") -> Tuple["State", "System"]: collider = cast(DynamicCellList, system.collider) iota = jax.lax.iota(dtype=int, size=state.N) MAX_OCCUPANCY = collider.max_occupancy pos = state.pos pos_p = state.q.rotate(state.q, state.pos_p) # to lab # 1. Determine Grid Dimensions # shape: (dim,) if system.domain.periodic: grid_dims = jnp.floor(system.domain.box_size / collider.cell_size).astype( int ) else: grid_dims = jnp.ceil(system.domain.box_size / collider.cell_size).astype( int ) # Compute strides (weights) for flattening 2D/3D indices to 1D hash # [1, nx, nx*ny, ...] strides = jnp.concatenate( [jnp.array([1], dtype=int), jnp.cumprod(grid_dims[:-1])] ) # 2. Calculate Particle Cell Indices cell_ids = jnp.floor((pos - system.domain.anchor) / collider.cell_size).astype( int ) # Wrap indices for hashing purposes if periodic # system.domain.periodic is a static variable. This is a compile time if if system.domain.periodic: cell_ids -= grid_dims * jnp.floor(cell_ids / grid_dims).astype(int) # 3. Spatial Hashing # shape (N,) particle_hash = jnp.dot(cell_ids, strides) # 4. Sort hashes and state particle_hash, perm = jax.lax.sort([particle_hash, iota], num_keys=1) state = jax.tree.map(lambda x: x[perm], state) cell_ids = cell_ids[perm] # 5. Precompute Neighbor Cell Hashes for every particle # (N, M, dim) = (N, 1, dim) + (1, M, dim) # M is number of neighbor cells (e.g., 27) current_cell = cell_ids[:, None, :] + collider.neighbor_mask if system.domain.periodic: current_cell -= grid_dims * jnp.floor(current_cell / grid_dims).astype(int) # shape (N,M) cell_hashes = jnp.dot(current_cell, strides) def per_particle( i: jax.Array, pos_pi: jax.Array, my_cell_id: jax.Array, cell_hash: jax.Array ) -> Tuple[jax.Array, jax.Array]: def per_neighbor_cell( current_cell_hash: jax.Array, ) -> Tuple[jax.Array, jax.Array]: start_idx = jnp.searchsorted( particle_hash, current_cell_hash, side="left", method="scan_unrolled", ) def cond_fun(val: Tuple[int, jax.Array, jax.Array]) -> jax.Array: k, _, _ = val return (k < state.N) * (current_cell_hash == particle_hash[k]) def body_fun( val: Tuple[int, jax.Array, jax.Array] ) -> Tuple[int, jax.Array, jax.Array]: k, acc_f, acc_t = val valid = state.ID[k] != state.ID[i] result = system.force_model.force(i, k, state, system) forces, torques = jax.tree.map(lambda x: valid * x, result) torques += cross(pos_pi, forces) return k + 1, acc_f + forces, acc_t + torques init_val = ( cast(int, start_idx), jnp.zeros(state.dim), jnp.zeros(state.dim), ) _, final_f, final_t = jax.lax.while_loop(cond_fun, body_fun, init_val) return final_f, final_t # VMAP over neighbor cells result = jax.vmap(per_neighbor_cell)(cell_hash) return jax.tree.map(lambda x: x.sum(axis=0), result) # VMAP over all particles total_force, total_torque = jax.vmap(per_particle)( iota, pos_p, cell_ids, cell_hashes ) total_torque = jax.ops.segment_sum(total_torque, state.ID, num_segments=state.N) total_force = jax.ops.segment_sum(total_force, state.ID, num_segments=state.N) state.force += total_force[state.ID] state.torque += total_torque[state.ID] return state, system
[docs] @staticmethod @jax.jit @partial(jax.named_call, name="DynamicCellList.compute_potential_energy") def compute_potential_energy(state: "State", system: "System") -> jax.Array: collider = cast(DynamicCellList, system.collider) iota = jax.lax.iota(dtype=int, size=state.N) MAX_OCCUPANCY = collider.max_occupancy pos = state.pos # 1. Determine Grid Dimensions # shape: (dim,) if system.domain.periodic: grid_dims = jnp.floor(system.domain.box_size / collider.cell_size).astype( int ) else: grid_dims = jnp.ceil(system.domain.box_size / collider.cell_size).astype( int ) # Compute strides (weights) for flattening 2D/3D indices to 1D hash # [1, nx, nx*ny, ...] strides = jnp.concatenate( [jnp.array([1], dtype=int), jnp.cumprod(grid_dims[:-1])] ) # 2. Calculate Particle Cell Indices cell_ids = jnp.floor((pos - system.domain.anchor) / collider.cell_size).astype( int ) # Wrap indices for hashing purposes if periodic # system.domain.periodic is a static variable. This is a compile time if if system.domain.periodic: cell_ids -= grid_dims * jnp.floor(cell_ids / grid_dims).astype(int) # 3. Spatial Hashing # shape (N,) particle_hash = jnp.dot(cell_ids, strides) # 4. Sort hashes and state particle_hash, perm = jax.lax.sort([particle_hash, iota], num_keys=1) state = jax.tree.map(lambda x: x[perm], state) cell_ids = cell_ids[perm] # 5. Precompute Neighbor Cell Hashes for every particle # (N, M, dim) = (N, 1, dim) + (1, M, dim) # M is number of neighbor cells (e.g., 27) current_cell = cell_ids[:, None, :] + collider.neighbor_mask if system.domain.periodic: current_cell -= grid_dims * jnp.floor(current_cell / grid_dims).astype(int) # shape (N,M) cell_hashes = jnp.dot(current_cell, strides) def per_particle( i: jax.Array, my_cell_id: jax.Array, cell_hash: jax.Array ) -> jax.Array: def per_neighbor_cell( current_cell_hash: jax.Array, ) -> jax.Array: # 1. Find Start Indices # Find where each neighbor hash starts in the sorted particle list. # 'searchsorted' returns the insertion index. # We do this inside the vmap to save memory (N, M) -> (N,) start_idx = jnp.searchsorted( particle_hash, current_cell_hash, side="left", method="scan_unrolled", ) def cond_fun(val: Tuple[int, jax.Array]) -> jax.Array: k, _ = val return (k < state.N) * (current_cell_hash == particle_hash[k]) def body_fun(val: Tuple[int, jax.Array]) -> Tuple[int, jax.Array]: k, acc_e = val valid = state.ID[k] != state.ID[i] e_ij = system.force_model.energy(i, k, state, system) return k + 1, acc_e + (0.5 * e_ij * valid) _, final_e = jax.lax.while_loop( cond_fun, body_fun, (cast(int, start_idx), jnp.array(0.0, dtype=float)), ) return final_e # VMAP over neighbor cells return jax.vmap(per_neighbor_cell)(cell_hash).sum() # VMAP over all particles return jax.vmap(per_particle)(iota, cell_ids, cell_hashes)
[docs] @Collider.register("MaterializedCellList") @jax.tree_util.register_dataclass @dataclass(slots=True) class MaterializedCellList(Collider): r""" Ultra-Optimized Explicitly materialized cell list collider. Uses "dummy particle" padding and Wide Vectorization to process entire neighbor blocks in single operations, maximizing GPU parallel throughput. """ neighbor_mask: jax.Array cell_size: jax.Array max_occupancy: int = field(metadata={"static": True}) grid_dims: Tuple[int, ...] = field(metadata={"static": True}) strides: jax.Array = field(metadata={"static": True}) num_cells: int = field(metadata={"static": True})
[docs] @classmethod def Create( cls, state: "State", box_size: ArrayLike, cell_size: Optional[ArrayLike] = None, search_range: Optional[ArrayLike] = None, max_occupancy: Optional[ArrayLike] = None, periodic: bool = True, ) -> Self: min_rad = jnp.min(state.rad) max_rad = jnp.max(state.rad) alpha = max_rad / min_rad if cell_size is None: cell_size = 2.0 * max_rad if alpha < 2.5 else max_rad / 2.0 if search_range is None: search_range = jnp.ceil(2 * max_rad / cell_size).astype(int) search_range = jnp.maximum(1, search_range) search_range = jnp.array(search_range, dtype=int) if max_occupancy is None: box_vol = cell_size**state.dim smallest_sphere_vol = ( (4.0 / 3.0) * jnp.pi * min_rad**3 / 0.9 if state.dim == 3 else jnp.pi * min_rad**2 ) max_occupancy = jnp.ceil(box_vol / smallest_sphere_vol) + 2 if periodic: grid_dims_val = jnp.floor(jnp.asarray(box_size) / cell_size).astype(int) else: grid_dims_val = jnp.ceil(jnp.asarray(box_size) / cell_size).astype(int) grid_dims = tuple(map(int, grid_dims_val)) num_cells = 1 for d in grid_dims: num_cells *= d strides = jnp.concatenate( [jnp.array([1], dtype=int), jnp.cumprod(jnp.array(grid_dims[:-1]))] ) r = jnp.arange(-search_range, search_range + 1, dtype=int) mesh = jnp.meshgrid(*([r] * state.dim), indexing="ij") neighbor_mask = jnp.stack([m.ravel() for m in mesh], axis=1) return cls( neighbor_mask=neighbor_mask.astype(int), cell_size=1.02 * jnp.asarray(cell_size, dtype=float), max_occupancy=int(max_occupancy), grid_dims=grid_dims, strides=strides, num_cells=int(num_cells), )
[docs] @staticmethod @partial(jax.jit, donate_argnames=("state", "system")) @partial(jax.named_call, name="MaterializedCellList.compute_force") def compute_force(state: "State", system: "System") -> Tuple["State", "System"]: collider = cast(MaterializedCellList, system.collider) N, MAX_OCC = state.N, collider.max_occupancy grid_dims = jnp.array(collider.grid_dims) strides = collider.strides iota = jax.lax.iota(dtype=int, size=N) pos_p = state.q.rotate(state.q, state.pos_p) # 1. Spatial Hash cell_indices = jnp.floor( (state.pos - system.domain.anchor) / collider.cell_size ).astype(int) if system.domain.periodic: cell_indices %= grid_dims else: cell_indices = jnp.clip(cell_indices, 0, grid_dims - 1) particle_hash = jnp.dot(cell_indices, strides) # 2. Sort particle_hash, perm = jax.lax.sort([particle_hash, iota], num_keys=1) # 3. Cumsum for slots is_first = jnp.concatenate( [jnp.array([True]), (particle_hash[1:] != particle_hash[:-1])] ) slots = ( jax.lax.associative_scan( lambda a, b: (jnp.where(b[1], b[0], a[0] + b[0]), a[1] | b[1]), (jnp.ones(N, int), is_first), )[0] - 1 ) # Build reordered state + Dummy Particle at index N def add_dummy(x): return jnp.concatenate([x[perm], jnp.expand_dims(x[perm[0]], 0)], axis=0) padded_state = jax.tree.map(add_dummy, state) padded_state = replace( padded_state, ID=padded_state.ID.at[N].set(-1), pos_c=padded_state.pos_c.at[N].set(jnp.array([1e10] * state.dim)), rad=padded_state.rad.at[N].set(0.0), ) padded_pos_p = jnp.concatenate([pos_p[perm], pos_p[perm[:1]]], axis=0) # 4. Building Cell Matrix with Dummy Padding (N = dummy index) valid_slot = slots < MAX_OCC flat_idx = particle_hash * MAX_OCC + jnp.where(valid_slot, slots, 0) cell_matrix_flat = jnp.full(collider.num_cells * MAX_OCC, N, dtype=int) cell_matrix = ( cell_matrix_flat.at[flat_idx] .set(jnp.where(valid_slot, jax.lax.iota(int, N), N)) .reshape(-1, MAX_OCC) ) # 5. Stencil Deduplication neighbor_offsets = collider.neighbor_mask if system.domain.periodic: wrapped_offsets = neighbor_offsets % grid_dims off_hashes = jnp.dot(wrapped_offsets, strides) first_occurrence = jnp.argmax( off_hashes[None, :] == off_hashes[:, None], axis=1 ) is_primary = jnp.arange(len(off_hashes)) == first_occurrence else: is_primary = jnp.ones(len(neighbor_offsets), dtype=bool) # 6. Precompute Neighbor Cell Hashes sorted_cell_ids = cell_indices[perm] neighbor_cells = sorted_cell_ids[:, None, :] + neighbor_offsets[None, :, :] if system.domain.periodic: neighbor_cells %= grid_dims else: neighbor_cells = jnp.clip(neighbor_cells, 0, grid_dims - 1) neighbor_hashes = jnp.dot(neighbor_cells, strides) # Build flat neighbor list and mask for wide vectorization flat_neighbors = cell_matrix[neighbor_hashes].reshape(N, -1) # (N, M*MAX_OCC) flat_active = jnp.broadcast_to( is_primary[:, None], (len(is_primary), MAX_OCC) ).reshape(-1) def per_particle(i, pos_pi, neighbors): def body_fun(idx, k): valid = flat_active[idx] & (padded_state.ID[k] != padded_state.ID[i]) res_f, res_t = system.force_model.force(i, k, padded_state, system) return valid * res_f, valid * (res_t + cross(pos_pi, res_f)) fs, ts = jax.vmap(body_fun)(jnp.arange(len(flat_active)), neighbors) return fs.sum(axis=0), ts.sum(axis=0) total_f, total_t = jax.vmap(per_particle)( iota, padded_pos_p[:N], flat_neighbors ) state.force = state.force.at[perm].add(total_f) state.torque = state.torque.at[perm].add(total_t) return state, system
[docs] @staticmethod @jax.jit def compute_potential_energy(state: "State", system: "System") -> jax.Array: collider = cast(MaterializedCellList, system.collider) N, MAX_OCC = state.N, collider.max_occupancy strides = collider.strides iota = jax.lax.iota(int, N) cell_indices = jnp.floor( (state.pos - system.domain.anchor) / collider.cell_size ).astype(int) if system.domain.periodic: cell_indices %= collider.grid_dims particle_hash, perm = jax.lax.sort( [jnp.dot(cell_indices, strides), iota], num_keys=1 ) is_first = jnp.concatenate( [jnp.array([True]), (particle_hash[1:] != particle_hash[:-1])] ) slots = ( jax.lax.associative_scan( lambda a, b: (jnp.where(b[1], b[0], a[0] + b[0]), a[1] | b[1]), (jnp.ones(N, int), is_first), )[0] - 1 ) cell_matrix = ( jnp.full(collider.num_cells * MAX_OCC, N, dtype=int) .at[particle_hash * MAX_OCC + jnp.where(slots < MAX_OCC, slots, 0)] .set(jnp.where(slots < MAX_OCC, iota, N)) .reshape(-1, MAX_OCC) ) def add_dummy(x): return jnp.concatenate([x[perm], jnp.expand_dims(x[perm[0]], 0)], axis=0) padded_state = jax.tree.map(add_dummy, state) padded_state = replace( padded_state, ID=padded_state.ID.at[N].set(-1), pos_c=padded_state.pos_c.at[N].set(jnp.array([1e10] * state.dim)), rad=padded_state.rad.at[N].set(0.0), ) neighbor_hashes = jnp.dot( (cell_indices[perm][:, None, :] + collider.neighbor_mask[None, :, :]) % collider.grid_dims if system.domain.periodic else jnp.clip( cell_indices[perm][:, None, :] + collider.neighbor_mask[None, :, :], 0, jnp.array(collider.grid_dims) - 1, ), strides, ) def per_particle(i, hash_list): def per_cell(h): occupants = cell_matrix[h] def body(k): valid = padded_state.ID[k] != padded_state.ID[i] return ( 0.5 * valid * system.force_model.energy(i, k, padded_state, system) ) return jax.vmap(body)(occupants).sum() return jax.vmap(per_cell)(hash_list).sum() energies = jax.vmap(per_particle)(iota, neighbor_hashes) return jnp.zeros(N).at[perm].set(energies)
[docs] @Collider.register("NeighborList") @jax.tree_util.register_dataclass @dataclass(slots=True) class NeighborList(Collider): r""" Neighbor List (Verlet List) collider following jax-md architectural patterns. This implementation uses a persistent neighbor list state, handles buffer overflows, and uses periodic-aware displacement for rebuild triggering. """ idx: jax.Array """Neighbor indices of shape (N, max_neighbors). -1 or N indicates padding.""" prev_pos: jax.Array """Positions at the time of the last neighbor list rebuild.""" did_buffer_overflow: jax.Array """Boolean scalar indicating if the neighbor list was too small to hold all pairs.""" update_threshold: float """Verlet skin distance. Rebuilds occur when max displacement > update_threshold / 2.""" max_neighbors: int = field(metadata={"static": True}) """Maximum number of neighbors stored per particle.""" # Implicit Cell List properties for efficient rebuilds cell_size: jax.Array neighbor_mask: jax.Array max_occupancy: int = field(metadata={"static": True}) grid_dims: Tuple[int, ...] = field(metadata={"static": True}) strides: jax.Array = field(metadata={"static": True})
[docs] @classmethod def Create( cls, state: "State", box_size: ArrayLike, max_neighbors: int = 64, update_threshold: float = 0.05, cell_size: Optional[ArrayLike] = None, max_occupancy: Optional[ArrayLike] = None, periodic: bool = True, ) -> Self: max_rad = jnp.max(state.rad) # Rebuild search radius must account for skin rebuild_radius = 2.0 * max_rad + update_threshold # Determine cell size for rebuild if cell_size is None: cell_size = rebuild_radius cell_size = jnp.asarray(cell_size, dtype=float) search_range = 1 # Standard for 1st-neighbor cell list rebuild r = jnp.arange(-search_range, search_range + 1, dtype=int) mesh = jnp.meshgrid(*([r] * state.dim), indexing="ij") neighbor_mask = jnp.stack([m.ravel() for m in mesh], axis=1) if max_occupancy is None: min_rad = jnp.min(state.rad) box_vol = cell_size**state.dim smallest_sphere_vol = ( (4.0 / 3.0) * jnp.pi * min_rad**3 / 0.9 if state.dim == 3 else jnp.pi * min_rad**2 ) max_occupancy = jnp.ceil(box_vol / smallest_sphere_vol) + 2 if periodic: grid_dims_val = jnp.floor(jnp.asarray(box_size) / cell_size).astype(int) else: grid_dims_val = jnp.ceil(jnp.asarray(box_size) / cell_size).astype(int) # Ensure grid dims are at least 1 grid_dims_val = jnp.maximum(1, grid_dims_val) grid_dims = tuple(map(int, grid_dims_val)) strides = jnp.concatenate( [jnp.array([1], dtype=int), jnp.cumprod(jnp.array(grid_dims[:-1]))] ) return cls( idx=jnp.full((state.N, max_neighbors), -1, dtype=int), prev_pos=state.pos, did_buffer_overflow=jnp.array(False), update_threshold=float(update_threshold), max_neighbors=int(max_neighbors), cell_size=cell_size, neighbor_mask=neighbor_mask.astype(int), max_occupancy=int(max_occupancy), grid_dims=grid_dims, strides=strides, )
@staticmethod def _rebuild_list( state: "State", system: "System", collider: "NeighborList" ) -> Tuple[jax.Array, jax.Array]: """Rebuilds the neighbor list indices using an implicit cell list traversal.""" N = state.N MAX_OCC = collider.max_occupancy MAX_NEIGH = collider.max_neighbors grid_dims = jnp.array(collider.grid_dims) strides = collider.strides iota = jax.lax.iota(int, N) # 1. Spatial Hash & Sort cell_indices = jnp.floor( (state.pos - system.domain.anchor) / collider.cell_size ).astype(int) if system.domain.periodic: cell_indices %= grid_dims particle_hash = jnp.dot(cell_indices, strides) p_hash_sorted, perm = jax.lax.sort([particle_hash, iota], num_keys=1) # 2. Neighborhood Probes neighbor_cell_indices = ( cell_indices[:, None, :] + collider.neighbor_mask[None, :, :] ) if system.domain.periodic: neighbor_cell_indices %= grid_dims neighbor_hashes = jnp.dot(neighbor_cell_indices, strides) # 3. Candidate Gathering start_indices = jnp.searchsorted(p_hash_sorted, neighbor_hashes, side="left") cand_offsets = jax.lax.iota(int, MAX_OCC) cand_ranks = start_indices[:, :, None] + cand_offsets[None, None, :] safe_ranks = jnp.minimum(cand_ranks, N - 1) valid_cand = (cand_ranks < N) & ( p_hash_sorted[safe_ranks] == neighbor_hashes[:, :, None] ) candidates = jnp.where(valid_cand, perm[safe_ranks], -1).reshape(N, -1) # 4. Filtering: pairwise distance check def filter_particle(i_idx, pos_i, rad_i, cand_list): pos_js = state.pos[jnp.where(cand_list >= 0, cand_list, 0)] rad_js = state.rad[jnp.where(cand_list >= 0, cand_list, 0)] rij = system.domain.displacement(pos_i, pos_js, system) dist_sq = jnp.sum(rij * rij, axis=-1) # Local cutoff for rebuild correctness including skin local_cutoff_sq = (rad_i + rad_js + collider.update_threshold) ** 2 survivors = ( (cand_list >= 0) & (cand_list != i_idx) & (dist_sq < local_cutoff_sq) ) num_survivors = jnp.sum(survivors) overflow = num_survivors > MAX_NEIGH # Efficient packing: move survivors to front packed = jnp.sort(jnp.where(survivors, cand_list, -1), descending=True)[ :MAX_NEIGH ] return packed, overflow idx, overflows = jax.vmap(filter_particle)( iota, state.pos, state.rad, candidates ) return idx, jnp.any(overflows)
[docs] @staticmethod @partial(jax.jit, donate_argnames=("state", "system")) @partial(jax.named_call, name="NeighborList.compute_force") def compute_force(state: "State", system: "System") -> Tuple["State", "System"]: collider = cast(NeighborList, system.collider) # 1. Decide Rebuild (periodic-aware displacement check) disp_since_last = system.domain.displacement( state.pos, collider.prev_pos, system ) max_disp_sq = jnp.max(jnp.sum(disp_since_last**2, axis=-1)) should_rebuild = max_disp_sq > (0.5 * collider.update_threshold) ** 2 def rebuild_fun(_): idx, overflow = NeighborList._rebuild_list(state, system, collider) return replace( collider, idx=idx, prev_pos=state.pos, did_buffer_overflow=overflow ) new_collider = jax.lax.cond( should_rebuild, rebuild_fun, lambda _: collider, None ) system = replace(system, collider=new_collider) # 2. Lab-frame rotated contact point vector pos_p_rotated = state.q.rotate(state.q, state.pos_p) # 3. Branch-free interaction with Dummy Padding N = state.N def add_dummy(x): return jnp.concatenate([x, jnp.expand_dims(x[0], 0)], axis=0) padded_state = jax.tree.map(add_dummy, state) padded_state = replace( padded_state, ID=padded_state.ID.at[N].set(-1), pos_c=padded_state.pos_c.at[N].set(jnp.array([1e10] * state.dim)), rad=padded_state.rad.at[N].set(0.0), ) # Map indices to dummy particle if -1 active_list = jnp.where(new_collider.idx < 0, N, new_collider.idx) def per_particle(i, neighbors, pos_pi): def interact(j): valid = padded_state.ID[j] != state.ID[i] res_f, res_t = system.force_model.force(i, j, padded_state, system) # Correct torque calculation with rotated local vector return valid * res_f, valid * (res_t + cross(pos_pi, res_f)) fs, ts = jax.vmap(interact)(neighbors) return fs.sum(axis=0), ts.sum(axis=0) total_f, total_t = jax.vmap(per_particle)( jax.lax.iota(int, N), active_list, pos_p_rotated ) state.force += total_f state.torque += total_t return state, system
[docs] @staticmethod @jax.jit def compute_potential_energy(state: "State", system: "System") -> jax.Array: collider = cast(NeighborList, system.collider) N = state.N def add_dummy(x): return jnp.concatenate([x, jnp.expand_dims(x[0], 0)], axis=0) padded_state = jax.tree.map(add_dummy, state) padded_state = replace( padded_state, ID=padded_state.ID.at[N].set(-1), pos_c=padded_state.pos_c.at[N].set(jnp.array([1e10] * state.dim)), rad=padded_state.rad.at[N].set(0.0), ) active_list = jnp.where(collider.idx < 0, N, collider.idx) def per_particle(i, neighbors): def interact(j): valid = padded_state.ID[j] != state.ID[i] return ( 0.5 * valid * system.force_model.energy(i, j, padded_state, system) ) return jax.vmap(interact)(neighbors).sum() return jax.vmap(per_particle)(jax.lax.iota(int, N), active_list)
__all__ = ["CellList", "DynamicCellList", "MaterializedCellList", "NeighborList"]