jaxdem.colliders
Collision-detection interfaces and implementations.
Classes
- Collider – The base interface for defining how contact detection and force computations are performed in a simulation.
- class jaxdem.colliders.Collider
Bases: Factory, ABC
The base interface for defining how contact detection and force computations are performed in a simulation.
Concrete subclasses of Collider implement the specific algorithms for calculating the interactions.
Notes
Self-interaction (i.e., calling the force/energy computation for i=j) is allowed, and the underlying force_model is responsible for correctly handling or ignoring this case.
Example
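To define a custom collider, inherit from Collider, register it, and implement its abstract methods:
>>> import jax
>>> from dataclasses import dataclass
>>> from jaxdem.colliders import Collider
>>> @Collider.register("CustomCollider")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class CustomCollider(Collider):
...     ...  # implement compute_force and compute_potential_energy here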
Then, instantiate it:
>>> jaxdem.Collider.create("CustomCollider", **custom_collider_kw)
- static compute_force(state: State, system: System) → Tuple['State', 'System']
Abstract method to compute the total force acting on each particle in the simulation.
Implementations should calculate inter-particle forces and torques based on the current state and system configuration, then update the force and torque attributes of the state object with the resulting total force and torque for each particle.
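- Parameters:
state (State) – Current state of the simulation.
system (System) – Current system configuration, including the force model.
- Returns:
A tuple containing the updated State object (with computed forces) and the System object.
- Return type:
Tuple[State, System]
Note
This method donates state and system (JAX buffer donation), so the State and System objects passed in should not be reused after the call; use the returned ones instead.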
- static compute_potential_energy(state: State, system: System) → jax.Array
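Abstract method to compute the potential energy of the system, resolved per particle.
Implementations should compute, for each particle, the sum of all potential energy contributions acting on it, based on the current state and system configuration.
- Parameters:
state (State) – Current state of the simulation.
system (System) – Current system configuration, including the force model.
- Returns:
A one-dimensional JAX array containing the total potential energy associated with each particle.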
- Return type:
jax.Array
Example
>>> potential_energy = system.collider.compute_potential_energy(state, system)
>>> print(f"Total potential energy: {float(potential_energy.sum()):.4f}")
>>> print(potential_energy.shape)  # (N,)
- class jaxdem.colliders.NaiveSimulator
Bases: Collider
Implementation that computes forces and potential energies using a naive \(O(N^2)\) all-pairs interaction loop.
Notes
Due to its \(O(N^2)\) complexity, NaiveSimulator is suitable only for simulations with a relatively small number of particles; for larger systems, a spatial-partitioning collider such as CellList should be used. For small systems, however, it should be the fastest option (below roughly 1k to 5k spheres, depending on the GPU).
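To make the \(O(N^2)\) cost concrete, here is a minimal JAX sketch of an all-pairs force sum. It is not NaiveSimulator's code: the linear-spring contact law and the names `spring_force`, `k`, and `naive_forces` are hypothetical stand-ins for system.force_model, chosen only to illustrate the pattern. In this sketch the nested jax.vmap materializes all N×N pair forces before summing, which is where the quadratic time and memory come from; the i = j case is zeroed inside the pair function, in the spirit of the note that the force model is responsible for self-interaction.

```python
import jax
import jax.numpy as jnp


def spring_force(ri, rj, rad_i, rad_j, k=1.0):
    """Hypothetical linear-spring contact force on particle i due to particle j."""
    rij = ri - rj
    dist = jnp.sqrt(jnp.sum(rij**2))
    overlap = jnp.maximum(rad_i + rad_j - dist, 0.0)
    # For i == j, dist == 0: avoid dividing by zero and return zero force,
    # i.e. this pair model "ignores" self-interaction.
    safe_dist = jnp.where(dist > 0, dist, 1.0)
    force = k * overlap * rij / safe_dist
    return jnp.where(dist > 0, force, jnp.zeros_like(force))


def naive_forces(pos, rad):
    """All-pairs O(N^2) force sum: F_i = sum_j f(i, j)."""
    # Inner vmap runs over j for a fixed i, outer vmap runs over i.
    pairwise = jax.vmap(
        jax.vmap(spring_force, in_axes=(None, 0, None, 0)),
        in_axes=(0, None, 0, None),
    )
    f_ij = pairwise(pos, pos, rad, rad)  # shape (N, N, dim): every pair materialized
    return f_ij.sum(axis=1)              # shape (N, dim)
```

For two unit spheres 1.5 apart, each feels a force of magnitude 0.5 pointing away from the other, and the forces sum to zero.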
- static compute_force(state: State, system: System) → Tuple['State', 'System']
Computes the total force acting on each particle using a naive \(O(N^2)\) all-pairs loop.
This method sums the force contributions from all particle pairs (i, j), as computed by system.force_model, and updates the particle forces.
- Parameters:
state (State) – Current state of the simulation.
system (System) – Current system configuration, including the force model.
- Returns:
A tuple containing the updated State object with computed forces and the unmodified System object.
- Return type:
Tuple[State, System]
Note
This method donates state and system.
- static compute_potential_energy(state: State, system: System) → jax.Array
Computes the potential energy associated with each particle using a naive \(O(N^2)\) all-pairs loop.
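This method iterates over all particle pairs (i, j) and sums the potential energy contributions computed by system.force_model.
- Parameters:
state (State) – Current state of the simulation.
system (System) – Current system configuration, including the force model.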
- Returns:
One-dimensional array containing the total potential energy contribution for each particle.
- Return type:
jax.Array
Note
This method donates state and system.
- class jaxdem.colliders.CellList(neighbor_mask: Array, cell_size: Array, max_occupancy: int)
Bases: Collider
Implicit cell-list (spatial hashing) collider.
This collider accelerates short-range pair interactions by partitioning the domain into a regular grid of cubic (3D) or square (2D) cells of side length cell_size. Each particle is assigned to a cell, particles are sorted by cell hash, and interactions are evaluated only against particles in the same or neighboring cells given by neighbor_mask. The cell list is implicit because per-cell particle lists are never stored explicitly; instead, the sorted hashes and the fixed max_occupancy are used to probe neighbors in place.
Complexity
Time: \(O(N \log N)\) from sorting, plus \(O(N M K)\) for neighbor probing (M = number of neighbor cells, K = max_occupancy).
Memory: \(O(N)\).
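The implicit probing scheme described above can be sketched in a few lines of JAX. This is an illustrative reimplementation under stated assumptions, not CellList's source: the domain is assumed non-periodic and to span `grid_dims` cells starting at the origin, and the helper name `implicit_cell_list_counts` is hypothetical. Instead of forces it only counts neighbors within `cutoff`, which keeps the example short while exercising each step: hash, sort, `jnp.searchsorted` to find the first particle of each neighbor cell, then a fixed probe of `max_occupancy` slots.

```python
import jax
import jax.numpy as jnp


def implicit_cell_list_counts(pos, cell_size, grid_dims, neighbor_mask, max_occupancy, cutoff):
    """Count neighbors within `cutoff` via an implicit (sorted-hash) cell list.

    Assumes a non-periodic domain covering `grid_dims` cells from the origin.
    """
    n = pos.shape[0]
    grid_dims = jnp.asarray(grid_dims)
    # Strides that flatten integer cell coordinates into a single hash.
    strides = jnp.concatenate([jnp.array([1]), jnp.cumprod(grid_dims)[:-1]])
    cell = jnp.floor(pos / cell_size).astype(jnp.int32)    # (N, dim)
    cell_hash = jnp.sum(cell * strides, axis=-1)           # (N,)
    order = jnp.argsort(cell_hash)                         # particle ids sorted by cell
    sorted_hash = cell_hash[order]

    def count_one(i):
        ncell = cell[i] + neighbor_mask                    # (M, dim) neighbor cells
        inside = jnp.all((ncell >= 0) & (ncell < grid_dims), axis=1)
        nhash = jnp.sum(ncell * strides, axis=-1)          # (M,)
        # First slot of each neighbor cell in the sorted order.
        start = jnp.searchsorted(sorted_hash, nhash, side="left")
        slots = start[:, None] + jnp.arange(max_occupancy)  # (M, K) fixed probe
        slots_c = jnp.minimum(slots, n - 1)                # keep gathers in bounds
        j = order[slots_c]
        # A slot is a real candidate only if it still belongs to that cell.
        valid = inside[:, None] & (slots < n) & (sorted_hash[slots_c] == nhash[:, None])
        d = jnp.linalg.norm(pos[j] - pos[i], axis=-1)
        return jnp.sum(valid & (j != i) & (d < cutoff))

    return jax.vmap(count_one)(jnp.arange(n))
```

The fixed-size probe keeps all shapes static (M × K candidates per particle), at the price of silently skipping particles beyond the K-th in a crowded cell. The `inside` mask matters because an out-of-grid offset can hash to the same value as a real cell.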
Notes
max_occupancy is an upper bound on particles per cell. If a cell contains more particles than this, some interactions are missed, so choose cell_size and max_occupancy such that this does not happen.
- neighbor_mask: jax.Array
Integer offsets defining the neighbor stencil.
Shape is (M, dim), where each row is a displacement in cell coordinates. For search_range=1 in 2D this is the 3×3 Moore neighborhood (M=9); in 3D it is the 3×3×3 neighborhood (M=27).
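A stencil with this layout can be generated with itertools.product. The helper `make_neighbor_mask` below is hypothetical and not necessarily how jaxdem builds the mask; it only shows the expected layout of \((2R+1)^{dim}\) rows, each an integer offset in \([-R, R]^{dim}\).

```python
import itertools

import jax.numpy as jnp


def make_neighbor_mask(search_range: int, dim: int) -> jnp.ndarray:
    """Hypothetical helper: all integer cell offsets in [-R, R]^dim, one per row."""
    offsets = range(-search_range, search_range + 1)
    rows = list(itertools.product(offsets, repeat=dim))
    return jnp.array(rows, dtype=jnp.int32)  # shape ((2R + 1)**dim, dim)
```

With search_range=1 this yields the 9 offsets of the 2D Moore neighborhood or the 27 offsets of its 3D counterpart, including the zero offset for the particle's own cell.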
- cell_size: jax.Array
Linear size of a grid cell (scalar).
- max_occupancy: int
Maximum number of particles assumed to occupy a single cell.
The algorithm probes exactly max_occupancy entries starting from the first particle in a neighbor cell. Set it high enough that real cells rarely exceed it; otherwise contacts and energy will be undercounted.
- classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, max_occupancy: ArrayLike | None = None) → Self
Creates a CellList collider with robust defaults.
Defaults are chosen to avoid missing any contacts while keeping the neighbor stencil and the assumed cell occupancy as small as possible, given the information available from state. For this, spheres are assumed not to overlap.
The cost of computing forces for one particle is determined by the number of neighboring cells to check and the occupancy of each cell. This cost can be estimated as:
\[
\begin{aligned}
\text{cost} &= (2R + 1)^{dim} \cdot \text{max\_occupancy} \\
&= (2R + 1)^{dim} \cdot \left( \left\lceil \frac{L^{dim}}{V_{min}} \right\rceil + 2 \right)
\end{aligned}
\]
where \(R\) is the search radius, \(L\) is the cell size, and \(V_{min}\) is the volume of the smallest element. We take \(V_{min}\) to be the volume of the smallest sphere, without accounting for the packing fraction, to obtain a conservative upper bound. The search radius \(R\) is computed as:
\[
R = \left\lceil \frac{2 r_{max}}{L} \right\rceil
\]
By default, we choose the option that yields the lowest computational cost: \(L = 2 r_{max}\) if \(\alpha < 2.5\), else \(L = r_{max}/2\), where \(\alpha = r_{max}/r_{min}\) measures the polydispersity.
The complexity of searching neighbors is \(O(N)\), and the choice of cell size and \(R\) attempts to minimize the constant factor. This constant factor grows with polydispersity as \(O(\alpha^{dim})\). The cost of sorting and binary search remains \(O(N \log N)\).
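The cost estimate above can be evaluated directly. The sketch below is a plain-Python transcription of the two formulas, not the code Create uses, and `cell_list_cost` is a hypothetical name; \(V_{min}\) is taken as the area of the smallest disk in 2D and the volume of the smallest sphere in 3D.

```python
import math


def cell_list_cost(cell_size: float, r_min: float, r_max: float, dim: int) -> int:
    """Evaluate cost = (2R + 1)^dim * (ceil(L^dim / V_min) + 2) from the text."""
    R = math.ceil(2 * r_max / cell_size)              # search radius in cells
    if dim == 2:
        v_min = math.pi * r_min**2                    # area of the smallest disk
    elif dim == 3:
        v_min = 4.0 / 3.0 * math.pi * r_min**3        # volume of the smallest sphere
    else:
        raise ValueError("dim must be 2 or 3")
    max_occupancy = math.ceil(cell_size**dim / v_min) + 2
    return (2 * R + 1) ** dim * max_occupancy


r_max = 1.0
for L in (2 * r_max, r_max / 2):  # the two default candidates
    print(L, cell_list_cost(L, r_min=1.0, r_max=r_max, dim=3))
```

For monodisperse spheres in 3D (\(\alpha = 1\)), the two candidate cell sizes give a cost of 108 for \(L = 2 r_{max}\) versus 2187 for \(L = r_{max}/2\), which is why the larger cells are preferred when polydispersity is low.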
- Parameters:
state (State) – Reference state used to determine spatial dimension and default parameters.
cell_size (float, optional) – Cell edge length. If None, defaults to a value optimized for the radius distribution.
search_range (int, optional) – Neighbor range in cell units. If None, the smallest safe value is computed such that \(\text{search\_range} \cdot L \geq \text{cutoff}\).
max_occupancy (int, optional) – Assumed maximum particles per cell. If None, estimated from a conservative packing upper bound using the smallest radius.
- Returns:
Configured collider instance.
- Return type:
CellList
- static compute_force(state: State, system: System) → Tuple['State', 'System']
Computes the total force acting on each particle using an implicit cell list (\(O(N \log N)\)). This method sums the force contributions from all particle pairs (i, j), as computed by system.force_model, and updates the particle forces.
- static compute_potential_energy(state: State, system: System) → jax.Array
Computes the potential energy associated with each particle using an implicit cell list (\(O(N \log N)\)). This method sums the potential energy contributions from all particle pairs (i, j), as computed by system.force_model.
- class jaxdem.colliders.DynamicCellList(neighbor_mask: Array, cell_size: Array, max_occupancy: int)
Bases: Collider
Implicit cell-list (spatial hashing) collider using dynamic while loops.
This collider accelerates short-range pair interactions by partitioning the domain into a regular grid. Unlike the standard CellList, this implementation uses a dynamic jax.lax.while_loop to probe neighbor cells, which can be more efficient for systems with highly non-uniform particle distributions.
Complexity
Time: \(O(N \log N)\) from sorting, plus \(O(N M \langle K \rangle)\) for neighbor probing, where \(\langle K \rangle\) is the average cell occupancy.
Memory: \(O(N)\).
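The dynamic probing idea can be sketched with jax.lax.while_loop: instead of always visiting max_occupancy slots, the loop walks the sorted hash array and stops as soon as the hash changes, with max_occupancy acting only as a safety limit. The function `count_in_cell` is hypothetical and only counts particles; the spot where a real collider would accumulate pair interactions is marked by a comment.

```python
import jax.numpy as jnp
from jax import lax


def count_in_cell(sorted_hash, target_hash, max_occupancy):
    """Hypothetical helper: occupancy of one cell found by a dynamic while loop.

    Walks the sorted hash array from the first slot of `target_hash` until the
    hash changes, the array ends, or the `max_occupancy` safety limit is hit.
    """
    n = sorted_hash.shape[0]
    start = jnp.searchsorted(sorted_hash, target_hash, side="left")

    def cond(carry):
        slot, count = carry
        in_bounds = slot < n
        same_cell = sorted_hash[jnp.minimum(slot, n - 1)] == target_hash
        return in_bounds & same_cell & (count < max_occupancy)

    def body(carry):
        slot, count = carry
        # A real collider would evaluate the pair interaction with order[slot] here.
        return slot + 1, count + 1

    _, count = lax.while_loop(cond, body, (start, jnp.int32(0)))
    return count
```

Because the trip count depends on the data, the work per cell tracks the actual occupancy rather than the worst case, which corresponds to the \(\langle K \rangle\) term above. Under jax.vmap, a batched while_loop keeps iterating until every batch element has finished, so the benefit on GPUs depends on how uneven the occupancies are.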
- neighbor_mask: jax.Array
Integer offsets defining the neighbor stencil, with shape (M, dim).
- cell_size: jax.Array
Linear size of a grid cell (scalar).
- max_occupancy: int
Maximum number of particles assumed to occupy a single cell (used as a safety limit for the loop).
- class jaxdem.colliders.MaterializedCellList(neighbor_mask: Array, cell_size: Array, max_occupancy: int, grid_dims: Tuple[int, ...], strides: Array, num_cells: int)
Bases: Collider
Cell-list collider with an explicitly materialized per-cell particle table.
Uses “dummy particle” padding and wide vectorization to process entire neighbor blocks in single operations, maximizing GPU parallel throughput.
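One way to materialize such a table with dummy padding is sketched below. It is an illustration under stated assumptions, not MaterializedCellList's implementation: `build_cell_table` is a hypothetical helper, cell hashes are assumed to be precomputed integers in [0, num_cells), and empty slots are filled with the index N so that an extra dummy particle appended at index N can absorb padded lookups.

```python
import jax.numpy as jnp


def build_cell_table(cell_hash, num_cells, max_occupancy):
    """Hypothetical helper: (num_cells, max_occupancy) table of particle indices.

    Empty slots hold the dummy index N (one past the last real particle).
    Assumes every entry of `cell_hash` lies in [0, num_cells).
    """
    n = cell_hash.shape[0]
    order = jnp.argsort(cell_hash)
    sorted_hash = cell_hash[order]
    # Rank of each sorted particle inside its cell = slot - first slot of that cell.
    first = jnp.searchsorted(sorted_hash, sorted_hash, side="left")
    rank = jnp.arange(n) - first
    table = jnp.full((num_cells, max_occupancy), n, dtype=jnp.int32)
    # Particles ranked beyond max_occupancy fall outside the table and are dropped.
    return table.at[sorted_hash, rank].set(order.astype(jnp.int32), mode="drop")
```

With the table in hand, all candidates of a neighbor cell can be gathered as one fixed-width row of max_occupancy indices, which is the kind of block-wise access the description above refers to.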
- neighbor_mask: jax.Array
- cell_size: jax.Array
- max_occupancy: int
- grid_dims: Tuple[int, ...]
- strides: jax.Array
- num_cells: int
- class jaxdem.colliders.NeighborList(idx: Array, prev_pos: Array, did_buffer_overflow: Array, update_threshold: float, max_neighbors: int, cell_size: Array, neighbor_mask: Array, max_occupancy: int, grid_dims: Tuple[int, ...], strides: Array)
Bases: Collider
Neighbor list (Verlet list) collider following jax-md architectural patterns.
This implementation keeps a persistent neighbor-list state, records buffer overflows in did_buffer_overflow, and uses periodic-aware displacements to decide when to rebuild the list.
- idx: jax.Array
Neighbor indices of shape (N, max_neighbors). Entries equal to -1 or N mark padding.
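Because padded entries use -1 or N, any consumer of idx has to mask them before using the gathered data. A minimal sketch, where `neighbor_count_within` is a hypothetical function that counts listed neighbors closer than a cutoff (plain, non-periodic distances for brevity):

```python
import jax.numpy as jnp


def neighbor_count_within(pos, idx, cutoff):
    """Hypothetical helper: count listed neighbors closer than `cutoff`.

    `idx` has shape (N, max_neighbors); entries equal to -1 or N are padding.
    Distances are plain (non-periodic) for brevity.
    """
    n = pos.shape[0]
    valid = (idx >= 0) & (idx < n)                  # mask out both padding markers
    safe_idx = jnp.where(valid, idx, 0)             # any in-bounds index; masked below
    d = jnp.linalg.norm(pos[safe_idx] - pos[:, None, :], axis=-1)  # (N, max_neighbors)
    return jnp.sum(valid & (d < cutoff), axis=1)
```

The explicit mask is needed because JAX does not raise on bad gather indices: -1 wraps to the last particle and out-of-range positive indices are clamped by default, so unmasked padding would silently read real particles.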
- prev_pos: jax.Array
Positions at the time of the last neighbor-list rebuild.
- did_buffer_overflow: jax.Array
Boolean scalar indicating whether the neighbor list was too small to hold all pairs.
- update_threshold: float
Verlet skin distance. A rebuild is triggered when the maximum particle displacement since the last rebuild exceeds update_threshold / 2.
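The half-skin rule follows from a standard Verlet-list argument, assuming the list is built with radius cutoff + update_threshold: a pair missing from the list was farther apart than that when the list was built, and if no particle has moved more than half the skin, their separation can have shrunk by at most the full skin, so they cannot yet be within the cutoff. The sketch below checks that criterion with a minimum-image displacement; the helper name `needs_rebuild` and the orthorhombic periodic box `box` are assumptions for illustration.

```python
import jax.numpy as jnp


def needs_rebuild(pos, prev_pos, box, update_threshold):
    """Hypothetical helper: Verlet-list rebuild test with the half-skin rule.

    Assumes an orthorhombic periodic box with side lengths `box`.
    """
    disp = pos - prev_pos
    # Minimum-image convention: wrap each displacement into [-box/2, box/2].
    disp = disp - box * jnp.round(disp / box)
    max_disp = jnp.max(jnp.linalg.norm(disp, axis=-1))
    return max_disp > update_threshold / 2
```

Without the minimum-image wrap, a particle that crosses a periodic boundary would appear to jump almost a full box length and trigger needless rebuilds.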
- max_neighbors: int
Maximum number of neighbors stored per particle.
- cell_size: jax.Array
- neighbor_mask: jax.Array
- max_occupancy: int
- grid_dims: Tuple[int, ...]
- strides: jax.Array
Modules
- Cell list \(O(N \log N)\) collider implementation.
- Naive \(O(N^2)\) collider implementation.
- Sweep and prune \(O(N \log N)\) collider implementation.