jaxdem.colliders
Collision-detection interfaces and implementations.
Classes
- Collider: The base interface for defining how contact detection and force computations are performed in a simulation.
- class jaxdem.colliders.Collider
Bases: Factory, ABC
The base interface for defining how contact detection and force computations are performed in a simulation.
Concrete subclasses of Collider implement the specific algorithms for calculating the interactions.
Notes
Self-interaction (i.e., calling the force/energy computation for i=j) is allowed, and the underlying force_model is responsible for correctly handling or ignoring this case.
Example
To define a custom collider, inherit from Collider, register it and implement its abstract methods:
>>> @Collider.register("CustomCollider")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class CustomCollider(Collider):
...     ...
Then, instantiate it:
>>> jaxdem.Collider.create("CustomCollider", **custom_collider_kw)
- static compute_force(state: State, system: System) → Tuple[State, System]
Abstract method to compute the total force acting on each particle in the simulation.
Implementations should calculate inter-particle forces and torques based on the current state and system configuration, then update the force and torque attributes of the state object with the resulting total force and torque for each particle.
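- Parameters:
state (State) – Current simulation state.
system (System) – Current system configuration.
- Returns:
A tuple containing the updated State object (with computed forces) and the System object.
- Return type:
Tuple[State, System]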
Note
This method donates state and system: JAX may reuse their buffers, so the input objects should not be used after the call.
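The donation note refers to JAX buffer donation. As a generic illustration (not jaxdem code), the sketch below donates the first argument of a jitted function through `jax.jit`'s `donate_argnums`; `scale_forces` is a hypothetical stand-in for a force update.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a force update. donate_argnums=(0,) allows XLA to
# reuse the buffer of `forces` for the output, depending on the backend.
@partial(jax.jit, donate_argnums=(0,))
def scale_forces(forces: jax.Array, factor: float) -> jax.Array:
    return forces * factor


forces = jnp.ones((4, 3))
new_forces = scale_forces(forces, 2.0)
# Treat `forces` as consumed from here on and keep only `new_forces`.
print(new_forces.sum())  # 24.0
```

The same rule applies to a donating `compute_force`: rebind the returned state and system and drop the old references.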
- static compute_potential_energy(state: State, system: System) → jax.Array
Abstract method to compute the potential energy of the system, resolved per particle.
Implementations should compute, for each particle, the sum of all potential energy contributions present in the system, based on the current state and system configuration.
- Parameters:
state (State) – Current simulation state.
system (System) – Current system configuration.
- Returns:
A JAX array of shape (N,) containing the total potential energy of each particle.
- Return type:
jax.Array
Example
>>> potential_energy = system.collider.compute_potential_energy(state, system)
>>> print(f"Total potential energy: {float(potential_energy.sum()):.4f}")
>>> print(potential_energy.shape)  # (N,)
- static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → Tuple[State, System, jax.Array, jax.Array]
Build a neighbor list for the current collider.
This is primarily used by neighbor-list-based algorithms and diagnostics. Implementations should match the cell-list semantics:
- Return a neighbor list of shape (N, max_neighbors), padded with -1.
- Neighbor indices must refer to the returned (possibly sorted) state.
- Also return an overflow boolean flag (True if any particle has more than max_neighbors neighbors within the cutoff).
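Downstream code has to mask the -1 padding before gathering. The sketch below uses a hand-written, hypothetical neighbor list; the masking step matters because JAX treats index -1 as "last element", so an unmasked gather would silently pull in the wrong particle.

```python
import jax.numpy as jnp

# Hypothetical neighbor list for N=3 particles with max_neighbors=2, padded with -1.
neighbor_list = jnp.array([[1, 2], [0, -1], [0, -1]])
values = jnp.array([10.0, 20.0, 30.0])  # some per-particle quantity

valid = neighbor_list >= 0                     # mask out the -1 padding
safe_idx = jnp.where(valid, neighbor_list, 0)  # never gather with index -1
neighbor_sum = jnp.sum(jnp.where(valid, values[safe_idx], 0.0), axis=1)
neighbor_count = jnp.sum(valid, axis=1)
print(neighbor_sum)    # [50. 10. 10.]
print(neighbor_count)  # [2 1 1]
```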
- class jaxdem.colliders.NaiveSimulator
Bases: Collider
Implementation that computes forces and potential energies using a naive \(O(N^2)\) all-pairs interaction loop.
Notes
Due to its \(O(N^2)\) complexity, NaiveSimulator is suitable for simulations with a relatively small number of particles; for larger systems, a more efficient spatial-partitioning collider should be used. For small systems (fewer than roughly 1k to 5k spheres, depending on the GPU), however, this collider should be the fastest option.
- static compute_potential_energy(state: State, system: System) → jax.Array
Computes the potential energy associated with each particle using a naive \(O(N^2)\) all-pairs loop.
This method iterates over all particle pairs (i, j) and sums the potential energy contributions computed by the system.force_model.
- Parameters:
state (State) – Current simulation state.
system (System) – Current system configuration.
- Returns:
One-dimensional array containing the total potential energy contribution for each particle.
- Return type:
jax.Array
Note
This method donates state and system: JAX may reuse their buffers, so the input objects should not be used after the call.
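A minimal sketch of an all-pairs per-particle energy using nested `jax.vmap`. The harmonic contact law and the convention of assigning half of each pair energy to each partner are assumptions for illustration, not necessarily what `system.force_model` does; the i == j terms are masked explicitly here.

```python
import jax
import jax.numpy as jnp


def pair_energy(pos_i, pos_j, rad_i, rad_j, k=1.0):
    # Hypothetical harmonic contact model: U = k/2 * overlap^2 when overlapping.
    dist = jnp.linalg.norm(pos_i - pos_j)
    overlap = jnp.maximum(rad_i + rad_j - dist, 0.0)
    return 0.5 * k * overlap**2


def naive_energy_per_particle(pos, rad):
    n = pos.shape[0]
    # Inner vmap runs over j, outer over i: the full (N, N) pair-energy matrix.
    e_ij = jax.vmap(
        lambda pi, ri: jax.vmap(lambda pj, rj: pair_energy(pi, pj, ri, rj))(pos, rad)
    )(pos, rad)
    e_ij = jnp.where(jnp.eye(n, dtype=bool), 0.0, e_ij)  # drop i == j terms
    return 0.5 * jnp.sum(e_ij, axis=1)  # split each pair energy between i and j


pos = jnp.array([[0.0, 0.0], [1.5, 0.0], [5.0, 0.0]])
rad = jnp.array([1.0, 1.0, 1.0])
energy = naive_energy_per_particle(pos, rad)  # shape (N,)
```

With this split, `energy.sum()` equals the total pair energy (0.125 for the overlapping pair above).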
- static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → Tuple[State, System, jax.Array, jax.Array]
Naive \(O(N^2)\) neighbor list build.
Matches the cell-list neighbor-list API: returns (state, system, neighbor_list, overflow), where neighbor indices refer to the returned state. The naive implementation does not sort the state, so the indices follow the input particle ordering.
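A sketch of an \(O(N^2)\) build that produces the (N, max_neighbors), -1-padded layout plus an overflow flag. `naive_neighbor_list` is a hypothetical helper that returns only the list and flag (not the state and system); when a row overflows, it keeps the lowest-indexed neighbors.

```python
import jax.numpy as jnp


def naive_neighbor_list(pos, cutoff, max_neighbors):
    """Hypothetical O(N^2) build: -1-padded (N, max_neighbors) list and overflow flag."""
    n = pos.shape[0]
    dist = jnp.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)  # (N, N)
    is_nbr = (dist < cutoff) & ~jnp.eye(n, dtype=bool)  # exclude self-pairs
    overflow = jnp.any(jnp.sum(is_nbr, axis=1) > max_neighbors)
    # Non-neighbors get the sentinel n; extra sentinel columns guarantee the width.
    idx = jnp.where(is_nbr, jnp.arange(n)[None, :], n)
    pad = jnp.full((n, max_neighbors), n, dtype=idx.dtype)
    idx = jnp.sort(jnp.concatenate([idx, pad], axis=1), axis=1)[:, :max_neighbors]
    return jnp.where(idx == n, -1, idx), overflow  # sentinel -> -1 padding


pos = jnp.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]])
neighbor_list, overflow = naive_neighbor_list(pos, cutoff=1.5, max_neighbors=2)
```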
- static compute_force(state: State, system: System) → Tuple[State, System]
Computes the total force acting on each particle using a naive \(O(N^2)\) all-pairs loop.
This method sums the force contributions from all particle pairs (i, j) as computed by the system.force_model and updates the particle forces.
- Parameters:
state (State) – Current simulation state.
system (System) – Current system configuration.
- Returns:
A tuple containing the updated State object with computed forces and the unmodified System object.
- Return type:
Tuple[State, System]
Note
This method donates state and system: JAX may reuse their buffers, so the input objects should not be used after the call.
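A broadcasting sketch of all-pairs forces for the same kind of hypothetical harmonic contact model (F = k · overlap along the unit vector from j to i). The diagonal distance is replaced by 1 before dividing so that the i == j terms stay finite before they are masked out.

```python
import jax.numpy as jnp


def naive_forces(pos, rad, k=1.0):
    """All-pairs O(N^2) forces for a hypothetical harmonic contact model."""
    n = pos.shape[0]
    rij = pos[:, None, :] - pos[None, :, :]           # (N, N, dim): r_i - r_j
    not_self = ~jnp.eye(n, dtype=bool)
    dist = jnp.sqrt(jnp.sum(rij**2, axis=-1))
    safe_dist = jnp.where(not_self, dist, 1.0)        # avoid 0/0 on the diagonal
    overlap = jnp.maximum(rad[:, None] + rad[None, :] - safe_dist, 0.0)
    overlap = jnp.where(not_self, overlap, 0.0)       # explicitly ignore i == j
    # F_i = sum_j k * overlap_ij * (r_i - r_j) / |r_i - r_j|  (repulsive)
    f_ij = (k * overlap / safe_dist)[..., None] * rij
    return jnp.sum(f_ij, axis=1)


pos = jnp.array([[0.0, 0.0], [1.5, 0.0], [5.0, 0.0]])
rad = jnp.array([1.0, 1.0, 1.0])
forces = naive_forces(pos, rad)  # overlapping pair pushed apart along x
```

Because each pair contributes equal and opposite terms, the forces sum to zero, which is a quick sanity check for any all-pairs implementation.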
- class jaxdem.colliders.StaticCellList(neighbor_mask: Array, cell_size: Array, max_occupancy: int)
Bases: Collider
Implicit cell-list (spatial hashing) collider.
This collider accelerates short-range pair interactions by partitioning the domain into a regular grid of cubic (3D) or square (2D) cells of side length cell_size. Each particle is assigned to a cell, particles are sorted by cell hash, and interactions are evaluated only against particles in the same or neighboring cells given by neighbor_mask. The cell list is implicit because we never store per-cell particle lists explicitly; instead, we exploit the sorted hashes and the fixed max_occupancy to probe neighbors in place.
This collider is ideal for systems of spheres with low polydispersity and no large overlaps; in that regime it can be even faster than the default cell list. It is not recommended for systems with clumps or large overlaps, where it may skip some contacts, or for polydisperse systems, where its performance degrades.
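A sketch of the assign/hash/sort step described above, assuming periodic wrapping of cell indices and a row-major linear hash; `sort_by_cell` is a hypothetical helper, and jaxdem's actual hashing scheme may differ. The final `searchsorted` call is the binary search that locates the first particle of a given cell in the sorted array.

```python
import jax.numpy as jnp


def sort_by_cell(pos, cell_size, grid_shape):
    """Hypothetical helper: cell coordinates -> linear hash -> order sorted by hash."""
    cell = jnp.floor(pos / cell_size).astype(jnp.int32)
    cell = jnp.mod(cell, jnp.asarray(grid_shape))  # periodic wrap of cell indices
    # Row-major strides, e.g. (ny, 1) in 2D, so that hash = cx * ny + cy.
    strides, s = [], 1
    for n_axis in reversed(grid_shape):
        strides.append(s)
        s *= n_axis
    hashes = jnp.sum(cell * jnp.asarray(strides[::-1]), axis=1)
    order = jnp.argsort(hashes)
    return order, hashes[order]


pos = jnp.array([[0.5, 0.5], [2.5, 0.5], [0.6, 0.4], [1.5, 2.5]])
order, sorted_hashes = sort_by_cell(pos, 1.0, (3, 3))
# Binary search for the first particle of cell hash 5 in the sorted array.
start = jnp.searchsorted(sorted_hashes, 5, side="left")
```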
Complexity
Time: \(O(N)\) to \(O(N \log N)\) for sorting (the state stays close to sorted from frame to frame), plus \(O(N M K)\) for neighbor probing, where M is the number of neighbor cells and K = max_occupancy.
Memory: \(O(N)\).
Notes
max_occupancy is an upper bound on the number of particles per cell. If a cell contains more particles than this, some interactions may be missed; choose cell_size and max_occupancy so that this does not happen.
- neighbor_mask: Array
Integer offsets defining the neighbor stencil.
Shape is (M, dim), where each row is a displacement in cell coordinates. For search_range=1 in 2D this is the 3×3 Moore neighborhood (M=9); in 3D it is the 3×3×3 neighborhood (M=27).
- cell_size: Array
Linear size of a grid cell (scalar).
- max_occupancy: int
Maximum number of particles assumed to occupy a single cell.
The algorithm probes exactly max_occupancy entries starting from the first particle in a neighbor cell. This should be set high enough that real cells rarely exceed it; otherwise contacts and energies will be undercounted.
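A stencil with the shape described for neighbor_mask can be built as the Cartesian product of offsets in [-R, R] along each axis. `make_neighbor_mask` is a hypothetical helper; the row ordering jaxdem uses internally is not specified here.

```python
import itertools

import jax
import jax.numpy as jnp


def make_neighbor_mask(search_range: int, dim: int) -> jax.Array:
    """All integer offsets in [-R, R]^dim: the (2R+1)^dim-cell stencil."""
    offsets = range(-search_range, search_range + 1)
    return jnp.array(list(itertools.product(offsets, repeat=dim)), dtype=jnp.int32)


print(make_neighbor_mask(1, 2).shape)  # (9, 2)
print(make_neighbor_mask(1, 3).shape)  # (27, 3)
```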
- classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, max_occupancy: int | None = None) → Self
Creates a StaticCellList collider with robust defaults.
Defaults are chosen to avoid missing any contacts while keeping the neighbor stencil and assumed cell occupancy as small as possible given the information available from state. For this, we assume that spheres do not overlap.
The cost of computing forces for one particle is determined by the number of neighboring cells to check and the occupancy of each cell. This cost can be estimated as:
\[\begin{split}\text{cost} &= (2R + 1)^{dim} \cdot \text{max\_occupancy} \\ \text{cost} &= (2R + 1)^{dim} \cdot \left(\left\lceil \frac{L^{dim}}{V_{min}} \right\rceil + 2 \right)\end{split}\]
where \(R\) is the search radius, \(L\) is the cell size, and \(V_{min}\) is the volume of the smallest element. We assume \(V_{min}\) to be the volume of the smallest sphere, without accounting for the packing fraction, to provide a conservative upper bound. The search radius \(R\) is computed as:
\[R = \left\lceil \frac{2 r_{max}}{L} \right\rceil\]
By default, we choose the options that yield the lowest computational cost: \(L = 2 \cdot r_{max}\) if \(\alpha < 2.5\), else \(L = r_{max}/2\).
The complexity of searching neighbors is \(O(N)\), where the choice of cell size and \(R\) attempts to minimize the constant factor. The constant factor grows with polydispersity (\(\alpha\)) as \(O(\alpha^{dim})\) with \(\alpha = r_{max}/r_{min}\). The cost for sorting and binary search remains \(O(N \log N)\).
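A direct transcription of the cost estimate above, useful for comparing candidate cell sizes. Treating \(V_{min}\) as the area of the smallest disk in 2D is an assumption of this sketch; `estimated_cost` is a hypothetical helper, not part of jaxdem.

```python
import math


def estimated_cost(cell_size, r_min, r_max, dim):
    """Per-particle cost estimate (2R+1)^dim * (ceil(L^dim / V_min) + 2)."""
    R = math.ceil(2.0 * r_max / cell_size)  # search radius in cells
    if dim == 2:
        v_min = math.pi * r_min**2  # area of the smallest disk (assumed for 2D)
    else:
        v_min = 4.0 / 3.0 * math.pi * r_min**3  # volume of the smallest sphere
    occupancy = math.ceil(cell_size**dim / v_min) + 2
    return (2 * R + 1) ** dim * occupancy


# Monodisperse 3D spheres with r = 0.5: compare the two default cell sizes.
cost_large = estimated_cost(1.0, 0.5, 0.5, 3)   # L = 2 * r_max
cost_small = estimated_cost(0.25, 0.5, 0.5, 3)  # L = r_max / 2
```

For this monodisperse case, L = 2 r_max gives 27 · 4 = 108 while L = r_max/2 gives 729 · 3 = 2187, consistent with preferring the larger cell at low polydispersity.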
- Parameters:
state (State) – Reference state used to determine spatial dimension and default parameters.
cell_size (float, optional) – Cell edge length. If None, defaults to a value optimized for the radius distribution.
search_range (int, optional) – Neighbor range in cell units. If None, the smallest safe value is computed such that \(\text{search\_range} \cdot L \geq \text{cutoff}\).
max_occupancy (int, optional) – Assumed maximum particles per cell. If None, estimated from a conservative packing upper bound using the smallest radius.
- Returns:
Configured collider instance.
- Return type:
StaticCellList
- static compute_force(state: State, system: System) → Tuple[State, System]
Computes the total force acting on each particle using an implicit cell list in \(O(N \log N)\) time. This method sums the force contributions from all neighboring particle pairs (i, j), as computed by the system.force_model, and updates the particle forces.
- static compute_potential_energy(state: State, system: System) → jax.Array
Computes the potential energy associated with each particle using an implicit cell list in \(O(N \log N)\) time. This method sums the energy contributions from all neighboring particle pairs (i, j), as computed by the system.force_model.
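The fixed-size probe is what makes the static variant cheap and also what makes it lossy. In this sketch (`probe_cell` is a hypothetical helper over already-sorted hashes), every query reads exactly max_occupancy slots, so array shapes stay static under `jit`/`vmap`, but a third particle in a cell is never visited when max_occupancy=2.

```python
import jax.numpy as jnp


def probe_cell(sorted_hashes, query_hash, max_occupancy):
    """Candidate indices for one cell plus a validity mask (fixed-size probe)."""
    n = sorted_hashes.shape[0]
    start = jnp.searchsorted(sorted_hashes, query_hash, side="left")
    idx = start + jnp.arange(max_occupancy)  # always max_occupancy slots
    in_bounds = idx < n
    safe_idx = jnp.where(in_bounds, idx, 0)
    valid = in_bounds & (sorted_hashes[safe_idx] == query_hash)
    return safe_idx, valid


sorted_hashes = jnp.array([0, 0, 0, 5, 6])  # cell 0 holds three particles
idx, valid = probe_cell(sorted_hashes, 0, max_occupancy=2)
# Both slots are valid, but the third particle of cell 0 (index 2) is skipped.
```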
- static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]
Computes the list of neighbors for each particle. The list has shape (N, max_neighbors); if a particle has fewer than max_neighbors neighbors, its row is padded with -1. The indices in the list refer to the returned, sorted state.
Because of the cell-list structure, no neighbors farther than cell_size * (1 + search_range) can be found, where search_range is the number of neighboring cells checked along each axis. If cutoff exceeds this distance, the returned list may be incomplete. Likewise, if a cell contains more spheres than max_occupancy, some neighbors may be missing.
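- Parameters:
state (State) – Current simulation state.
system (System) – Current system configuration.
cutoff (float) – Neighbor search radius.
max_neighbors (int) – Maximum number of neighbors stored per particle.
- Returns:
The sorted state, the system, the neighbor list, and a boolean flag for overflow.
- Return type:
tuple[State, System, jax.Array, jax.Array]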
- class jaxdem.colliders.DynamicCellList(neighbor_mask: Array, cell_size: Array)
Bases: Collider
Implicit cell-list (spatial hashing) collider using dynamic while-loops.
This collider accelerates short-range pair interactions by partitioning the domain into a regular grid. Unlike the static cell list, this implementation uses a dynamic jax.lax.while_loop to probe neighbor cells, which can be more efficient for polydisperse systems or systems at low packing fractions. It is also useful for systems with a high occupancy per cell, for example, systems with clumps.
Complexity
Time: \(O(N)\) to \(O(N \log N)\) for sorting (the state stays close to sorted from frame to frame), plus \(O(N M \langle K \rangle)\) for neighbor probing, where \(\langle K \rangle\) is the average cell occupancy.
Memory: \(O(N)\).
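A sketch of the dynamic alternative: instead of a fixed number of slots, a `jax.lax.while_loop` walks the contiguous run of one cell in the sorted hash array until the hash changes, so there is no occupancy cap. `count_in_cell` is hypothetical and only counts occupants; a real collider would accumulate pair interactions in the loop body. Under `jax.vmap`, a batched while_loop keeps iterating until every lane's condition is false, so the cost per batch follows the busiest cell.

```python
import jax
import jax.numpy as jnp


def count_in_cell(sorted_hashes, query_hash):
    """Walk the run of `query_hash` in a sorted hash array with lax.while_loop."""
    n = sorted_hashes.shape[0]
    start = jnp.searchsorted(sorted_hashes, query_hash, side="left")

    def cond(carry):
        i, _ = carry
        # Continue while inside the array and still in the queried cell.
        return (i < n) & (sorted_hashes[jnp.minimum(i, n - 1)] == query_hash)

    def body(carry):
        i, count = carry
        return i + 1, count + 1  # a collider would evaluate pair (self, i) here

    _, count = jax.lax.while_loop(cond, body, (start, jnp.int32(0)))
    return count


sorted_hashes = jnp.array([0, 0, 0, 5, 6])
occupants = jax.jit(count_in_cell)(sorted_hashes, 0)  # all three, no cap
```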
- neighbor_mask: Array
Integer offsets defining the neighbor stencil (M, dim).
- cell_size: Array
Linear size of a grid cell (scalar).
- classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, box_size: ArrayLike | None = None) → Self
Creates a DynamicCellList collider with robust defaults.
Defaults are chosen to avoid missing any contacts while keeping the neighbor stencil and assumed cell occupancy as small as possible given the information available from state.
The cost of computing forces for one particle is determined by the number of neighboring cells to check and the occupancy of each cell. This cost can be estimated as:
\[\begin{split}\text{cost} &= (2R + 1)^{dim} \cdot \text{max\_occupancy} \\ \text{cost} &= (2R + 1)^{dim} \cdot \left(\left\lceil \frac{L^{dim}}{V_{min}} \right\rceil + 2 \right)\end{split}\]
where \(R\) is the search radius, \(L\) is the cell size, and \(V_{min}\) is the volume of the smallest element. We assume \(V_{min}\) to be the volume of the smallest sphere, without accounting for the packing fraction, to provide a conservative upper bound. The search radius \(R\) is computed as:
\[R = \left\lceil \frac{2 r_{max}}{L} \right\rceil\]
By default, we choose the options that yield the lowest computational cost: \(L = 2 \cdot r_{max}\) if \(\alpha < 2.5\), else \(L = r_{max}/2\).
The complexity of searching neighbors is \(O(N)\), where the choice of cell size and \(R\) attempts to minimize the constant factor. The constant factor grows with polydispersity; however, the dynamic nature of the collider greatly reduces the impact of polydispersity.
- Parameters:
state (State) – Reference state used to determine spatial dimension and default parameters.
cell_size (float, optional) – Cell edge length. If None, defaults to a value optimized for the radius distribution.
box_size (jax.Array, optional) – Size of the periodic box, used to ensure there are at least 3 cells per axis. If None, this check is skipped, and having fewer than 3 cells per axis will lead to errors.
search_range (int, optional) – Neighbor range in cell units. If None, the smallest safe value is computed such that \(\text{search\_range} \cdot L \geq \text{cutoff}\).
- Returns:
Configured collider instance.
- Return type:
DynamicCellList
- static compute_force(state: State, system: System) → Tuple[State, System]
Computes the total force acting on each particle using an implicit cell list in \(O(N \log N)\) time. This method sums the force contributions from all neighboring particle pairs (i, j), as computed by the system.force_model, and updates the particle forces.
- static compute_potential_energy(state: State, system: System) → jax.Array
Computes the potential energy associated with each particle using an implicit cell list in \(O(N \log N)\) time. This method sums the energy contributions from all neighboring particle pairs (i, j), as computed by the system.force_model.
- static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → Tuple[State, System, jax.Array, jax.Array]
Computes the list of neighbors for each particle. The list has shape (N, max_neighbors); if a particle has fewer than max_neighbors neighbors, its row is padded with -1. The indices in the list refer to the returned, sorted state.
Because of the cell-list structure, no neighbors farther than cell_size * (1 + search_range) can be found, where search_range is the number of neighboring cells checked along each axis. If cutoff exceeds this distance, the returned list may be incomplete.
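- Parameters:
state (State) – Current simulation state.
system (System) – Current system configuration.
cutoff (float) – Neighbor search radius.
max_neighbors (int) – Maximum number of neighbors stored per particle.
- Returns:
The sorted state, the system, the neighbor list, and a boolean flag for overflow.
- Return type:
Tuple[State, System, jax.Array, jax.Array]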
- class jaxdem.colliders.NeighborList(cell_list: DynamicCellList, neighbor_list: Array, old_pos: Array, n_build_times: int, cutoff: Array, skin: Array, overflow: Array, max_neighbors: int)
Bases: Collider
Verlet neighbor list collider.
This collider caches a list of neighbors for every particle and rebuilds it only when some particle has moved more than half the skin distance since the last build.
Performance note: a non-zero skin (e.g., 0.1 * radius) is required for this collider to be efficient. If skin=0, the list is rebuilt every step.
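The skin/2 threshold is conservative: two particles moving toward each other can each travel skin/2, closing their separation by at most the full skin, so no pair can enter the cutoff without already being in the cached list. The sketch below (`needs_rebuild` is a hypothetical helper) checks that criterion, using the minimum-image convention when a periodic box size is given.

```python
import jax.numpy as jnp


def needs_rebuild(pos, old_pos, skin, box_size=None):
    """True if any particle moved more than skin / 2 since the last build."""
    disp = pos - old_pos
    if box_size is not None:
        disp = disp - box_size * jnp.round(disp / box_size)  # minimum image
    max_disp = jnp.max(jnp.linalg.norm(disp, axis=1))
    return max_disp > 0.5 * skin


old_pos = jnp.zeros((3, 2))
moved = old_pos.at[1, 0].set(0.03)  # one particle moved 0.03 > 0.05 / 2
rebuild = needs_rebuild(moved, old_pos, skin=0.05)
```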
- cell_list
The underlying spatial partitioner used to build the list.
- Type:
DynamicCellList
- neighbor_list
Shape (N, max_neighbors). Contains the IDs of neighboring particles, padded with -1.
- Type:
jax.Array
- old_pos
Shape (N, dim). Positions of particles at the last build time.
- Type:
jax.Array
- n_build_times
Counter for how many times the list has been rebuilt.
- Type:
int
- cutoff
The interaction radius (force cutoff).
- Type:
float
- skin
Buffer distance. The list is built with radius = cutoff + skin and rebuilt when max_displacement > skin / 2.
- Type:
float
- overflow
Boolean flag indicating whether the neighbor list overflowed during the build.
- Type:
jax.Array
- max_neighbors
Static buffer size for the neighbor list.
- Type:
int
- cell_list: DynamicCellList
- neighbor_list: Array
- old_pos: Array
- n_build_times: int
- cutoff: Array
- skin: Array
- overflow: Array
- max_neighbors: int
- classmethod Create(state: State, cutoff: float, box_size: jax.Array | None = None, skin: float = 0.05, max_neighbors: int | None = None, number_density: float = 1.0, safety_factor: float = 1.2, cell_size: float | None = None) → Self
Creates a NeighborList collider.
- Parameters:
state (State) – Initial simulation state.
cutoff (float) – The physical interaction cutoff radius.
box_size (jax.Array, optional) – The size of the periodic box, if used.
skin (float, default 0.05) – The buffer distance. Must be > 0.0 for performance.
max_neighbors (int, optional) – Maximum neighbors to store per particle. If not provided, it is estimated from the number_density.
number_density (float, default 1.0) – Number density of the state, used to estimate max_neighbors when it is not provided.
safety_factor (float, default 1.2) – Factor used to adjust the max_neighbors value estimated from number_density. The default was obtained empirically.
cell_size (float, optional) – Override for the underlying cell list size.
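The exact estimate Create uses for max_neighbors is not given above. One plausible sketch, stated as an assumption, multiplies the expected number of particles inside a sphere (or disk) of radius cutoff + skin by number_density and then by safety_factor; `estimate_max_neighbors` is hypothetical and may not match jaxdem's formula.

```python
import math


def estimate_max_neighbors(cutoff, skin, dim, number_density=1.0, safety_factor=1.2):
    """Hypothetical estimate: expected neighbors inside the list radius, inflated."""
    r_list = cutoff + skin  # the list is built with radius cutoff + skin
    if dim == 2:
        volume = math.pi * r_list**2
    else:
        volume = 4.0 / 3.0 * math.pi * r_list**3
    return max(1, math.ceil(safety_factor * number_density * volume))


max_nb_3d = estimate_max_neighbors(cutoff=1.0, skin=0.05, dim=3)  # ceil(1.2 * 4.85)
max_nb_2d = estimate_max_neighbors(cutoff=1.0, skin=0.05, dim=2)  # ceil(1.2 * 3.46)
```

If the guess is too small, the overflow flag is the signal to rebuild with a larger max_neighbors.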
- static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → Tuple[State, System, jax.Array, jax.Array]
Return the cached neighbor list from this collider.
Notes
This method does not rebuild the neighbor list. It simply returns the last cached neighbor_list and overflow stored in system.collider.
The returned neighbor indices refer to the collider's internal particle ordering at the time the cache was last updated (i.e., after the most recent rebuild inside compute_force()).
The cutoff and max_neighbors arguments are accepted for API compatibility but are currently ignored; the cache was built using this collider's configured cutoff + skin and max_neighbors.
Modules
- Cell List \(O(N \log N)\) collider implementation.
- Naive \(O(N^2)\) collider implementation.
- Neighbor List collider implementation.
- Sweep and prune \(O(N \log N)\) collider implementation.