jaxdem.colliders
Collision-detection interfaces and implementations.
Functions
- valid_interaction_mask – Pair mask shared by all colliders.
Classes
- Collider – The base interface for defining how contact detection and force computations are performed in a simulation.
- class jaxdem.colliders.Collider
Bases: Factory, ABC
The base interface for defining how contact detection and force computations are performed in a simulation.
Concrete subclasses of Collider implement the specific algorithms for calculating the interactions.
Notes:
Self-interaction (i.e., calling the force/energy computation for i = j) is allowed; the underlying force_model is responsible for correctly handling or ignoring this case.
Example:
To define a custom collider, inherit from Collider, register it, and implement its abstract methods:
>>> @Collider.register("CustomCollider")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class CustomCollider(Collider):
...     ...
Then, instantiate it:
>>> jaxdem.Collider.create("CustomCollider", **custom_collider_kw)
- static compute_force(state: State, system: System) -> tuple[State, System]
Abstract method to compute the total force acting on each particle in the simulation.
Implementations should calculate inter-particle forces and torques based on the current state and system configuration, then update the force and torque attributes of the state object with the resulting total force and torque for each particle.
- static compute_potential_energy(state: State, system: System) -> jax.Array
Abstract method to compute the potential energy associated with each particle.
Implementations should calculate the potential energy contributions for each particle based on the current state and system configuration.
- Parameters:
state (State) – The current state of the simulation.
system (System) – The configuration of the simulation.
- Returns:
A JAX array of shape (N,) containing the potential energy of each particle.
- Return type:
jax.Array
Example
>>> potential_energy = system.collider.compute_potential_energy(state, system)
>>> print(potential_energy.shape)  # (N,)
>>> print(f"Mean potential energy per particle: {potential_energy.mean():.4f}")
- static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) -> tuple[State, System, jax.Array, jax.Array]
Build a neighbor list for the current collider.
This is primarily used by neighbor-list-based algorithms and diagnostics. Implementations should match the cell-list semantics:
- Return a neighbor list of shape (N, max_neighbors), padded with -1.
- Neighbor indices must refer to the returned (possibly sorted) state.
- Also return an overflow boolean flag (True if any particle exceeded max_neighbors neighbors within the cutoff).
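A minimal pure-Python sketch of these conventions, assuming an open (non-periodic) domain and a brute-force search. `pad_neighbors` is a hypothetical helper, not part of jaxdem, and it excludes self-pairs by assumption. It shows how each row is padded with -1 and how overflow is raised when a neighbor is found but no slot is left.

```python
import math


def pad_neighbors(positions, cutoff, max_neighbors):
    """Build an (N, max_neighbors) neighbor list padded with -1.

    Hypothetical helper: brute-force search in an open domain.
    Returns (neighbor_list, overflow).
    """
    n = len(positions)
    neighbor_list = [[-1] * max_neighbors for _ in range(n)]
    overflow = False
    for i in range(n):
        count = 0
        for j in range(n):
            if i == j:
                continue  # self-pairs are excluded in this sketch
            if math.dist(positions[i], positions[j]) <= cutoff:
                if count < max_neighbors:
                    neighbor_list[i][count] = j
                else:
                    overflow = True  # neighbor found, but the row is full
                count += 1
    return neighbor_list, overflow
```

A JAX implementation would express the same bookkeeping with fixed-shape arrays so that it can be jit-compiled.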
- static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) -> tuple[jax.Array, jax.Array]
Build a cross-neighbor list between two sets of positions.
For each point in pos_a, finds all neighbors from pos_b within the given cutoff distance. This is useful for coupling different particle systems or computing interactions between distinct sets of objects.
The default implementation uses a naive \(O(N_A \times N_B)\) all-pairs search. Subclasses may override this with more efficient algorithms.
- Parameters:
pos_a (jax.Array) – Query positions, shape (N_A, dim).
pos_b (jax.Array) – Database positions, shape (N_B, dim).
system (System) – The configuration of the simulation (used for domain displacement).
cutoff (float) – Search radius.
max_neighbors (int) – Maximum number of neighbors to store per query point.
- Returns:
A tuple containing:
neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.
overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.
- Return type:
Tuple[jax.Array, jax.Array]
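The naive \(O(N_A \times N_B)\) cross search can be sketched in pure Python. The domain handling is an assumption: the sketch applies the minimum-image convention for an optional periodic box with per-axis side lengths `box`, as a stand-in for whatever displacement `system` provides. `naive_cross_neighbor_list` and `min_image` are hypothetical names.

```python
import math


def min_image(d, box):
    """Minimum-image displacement along one axis (box=None means open)."""
    if box is None:
        return d
    return d - box * round(d / box)


def naive_cross_neighbor_list(pos_a, pos_b, cutoff, max_neighbors, box=None):
    """For each query in pos_a, list indices into pos_b within cutoff."""
    dim = len(pos_a[0]) if pos_a else 0
    boxes = box if box is not None else [None] * dim
    neighbor_list, overflow = [], False
    for pa in pos_a:  # every query point ...
        row = []
        for j, pb in enumerate(pos_b):  # ... against every database point
            disp = [min_image(pb[k] - pa[k], boxes[k]) for k in range(dim)]
            if math.hypot(*disp) <= cutoff:
                row.append(j)
        overflow = overflow or len(row) > max_neighbors
        row = row[:max_neighbors]  # keep at most max_neighbors entries
        neighbor_list.append(row + [-1] * (max_neighbors - len(row)))
    return neighbor_list, overflow
```

No self-pair exclusion is needed here because the two position sets are distinct.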
- class jaxdem.colliders.DynamicCellList(neighbor_mask: Array, cell_size: Array)
Bases: Collider
Implicit cell-list (spatial hashing) collider using dynamic while-loops.
This collider accelerates short-range pair interactions by partitioning the domain into a regular grid of cubic (3D) or square (2D) cells of side length cell_size. Each particle is assigned to a cell, particles are sorted by cell hash, and interactions are evaluated only against particles in the same or neighboring cells given by neighbor_mask.
Unlike the static cell list, this implementation does not pad arrays to a fixed max_occupancy. Instead, it uses a dynamic jax.lax.while_loop to iterate over the exact number of particles present in each neighboring cell.
The operation of this collider can be understood as the following nested loop:
    for particle in particles:                  # parallel
        for hash in stencil(particle):          # parallel
            while next_neighbor in cell(hash):  # sequential
                ...
Because the innermost loop is evaluated sequentially, the computational cost is driven by the average cell occupancy rather than the maximum possible occupancy. This makes the total theoretical cost:
\[O(N \cdot \text{neighbor\_mask\_size} \cdot \langle K \rangle)\]
where \(\langle K \rangle\) is the average cell occupancy. To understand how this scales, let’s analyze the cost components:
- Stencil size:
The stencil size depends on the ratio between the cell size (\(L\)) and the radius of the largest particle (\(r_{max}\)).
\[\text{neighbor\_mask\_size} = \left( 2\left\lceil \frac{2r_{max}}{L} \right\rceil + 1 \right)^{dim}\]
- Average occupancy:
The average number of particles that occupy a cell depends only on the cell volume and the macroscopic number density (\(\rho\)), completely independent of the smallest particle size.
\[\langle K \rangle = \rho L^{dim}\]
To express this in terms of the local volume fraction \(\phi\) (the ratio of the volume actually occupied by particles to the total cell volume) and the normalized cell size \(L^\prime = L/r_{max}\), we use the average particle volume \(\langle V \rangle\):
\[\langle K \rangle = \phi \frac{L^{dim}}{\langle V \rangle} = \phi \frac{(L^\prime r_{max})^{dim}}{\langle V \rangle}\]
Knowing that the volume of the largest particle is \(V_{max} = k_v r_{max}^{dim}\) (where \(k_v\) is the geometric volume factor, such as \(4\pi/3\) in 3D or \(\pi\) in 2D), we find the final theoretical cost:
\[\text{cost} \approx N \left( 2\left\lceil \frac{2}{L^\prime} \right\rceil + 1 \right)^{dim} \left( \frac{\phi}{k_v} \frac{V_{max}}{\langle V \rangle} (L^\prime)^{dim} \right)\]
- The Polydispersity Advantage:
In the static cell list, cost scales with the ratio of the largest to the smallest particle volume (\(V_{max}/V_{min} \propto \alpha^{dim}\), where \(\alpha = r_{max}/r_{min}\)). In this dynamic list, the cost scales with the ratio of the largest to the average particle volume (\(V_{max}/\langle V \rangle\)). Since \(\langle V \rangle \geq V_{min}\), this ratio never exceeds \(V_{max}/V_{min}\), and it stays small unless small particles dominate the population by number, so the severe \(O(\alpha^{dim})\) padding penalty is greatly reduced. The dynamic loop only iterates over particles that actually exist.
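The cost model above can be evaluated directly. This plain-Python sketch only evaluates the formulas from this section, with \(k_v = \pi\) in 2D and \(4\pi/3\) in 3D as in the text. The function names are hypothetical, and `volume_ratios` assumes a bidisperse mixture with a given number fraction of small particles.

```python
import math


def stencil_size(l_prime, dim):
    """(2 * ceil(2 / L') + 1) ** dim, with L' = L / r_max."""
    return (2 * math.ceil(2.0 / l_prime) + 1) ** dim


def dynamic_cost_per_particle(l_prime, dim, phi, vmax_over_vmean):
    """Stencil size times <K> = (phi / k_v) * (V_max / <V>) * L' ** dim."""
    k_v = math.pi if dim == 2 else 4.0 * math.pi / 3.0  # assumes dim in {2, 3}
    mean_occupancy = (phi / k_v) * vmax_over_vmean * l_prime**dim
    return stencil_size(l_prime, dim) * mean_occupancy


def volume_ratios(alpha, small_fraction, dim):
    """Bidisperse mixture: return (V_max / V_min, V_max / <V>).

    small_fraction is the number fraction of small particles; volumes are
    measured in units of V_max.
    """
    v_small = alpha ** (-dim)
    v_mean = (1.0 - small_fraction) * 1.0 + small_fraction * v_small
    return 1.0 / v_small, 1.0 / v_mean
```

For \(\alpha = 10\) in 3D with equal numbers of large and small spheres, `volume_ratios(10, 0.5, 3)` gives \(V_{max}/V_{min} \approx 1000\) but \(V_{max}/\langle V \rangle \approx 2\).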
This collider is ideal for highly polydisperse systems, sparse systems (low packing fractions), or systems with rigid clumps that create massive local density spikes, as it completely avoids the memory bloat and wasted gather operations caused by array padding.
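The nested loop described above can be sketched in pure Python for an open (non-periodic) domain. Particles are hashed to integer cell coordinates and sorted by hash. For each stencil offset, a binary search finds the start of the neighboring cell's run, and a while loop then walks only the particles actually in that cell. This illustrates the idea under those assumptions and is not jaxdem's JAX implementation. `dynamic_cell_list_pairs` is a hypothetical name.

```python
import bisect
import itertools
import math


def dynamic_cell_list_pairs(positions, cell_size, search_range, cutoff):
    """Return sorted (i, j) pairs, i < j, whose distance is <= cutoff.

    Sketch of a sorted-hash cell list that walks each neighboring cell with a
    while loop; open (non-periodic) domain, any spatial dimension.
    """
    dim = len(positions[0])
    # Integer cell coordinates act as the cell "hash".
    cells = [tuple(math.floor(x / cell_size) for x in p) for p in positions]
    order = sorted(range(len(positions)), key=lambda k: cells[k])
    sorted_cells = [cells[k] for k in order]
    rng = range(-search_range, search_range + 1)
    stencil = list(itertools.product(rng, repeat=dim))
    pairs = set()
    for i in range(len(positions)):  # parallel over particles in JAX
        for offset in stencil:  # parallel over the stencil in JAX
            target = tuple(c + o for c, o in zip(cells[i], offset))
            k = bisect.bisect_left(sorted_cells, target)  # start of the cell's run
            # Sequential part: visit only the particles that are in this cell.
            while k < len(sorted_cells) and sorted_cells[k] == target:
                j = order[k]
                if i < j and math.dist(positions[i], positions[j]) <= cutoff:
                    pairs.add((i, j))
                k += 1
    return sorted(pairs)
```

The number of while-loop iterations per probed cell equals that cell's true occupancy, which is the property the cost analysis above relies on.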
Complexity
Time: \(O(N)\) to \(O(N \log N)\) from sorting (the state is close to sorted every frame), plus \(O(N \cdot M \cdot \langle K \rangle)\) for neighbor probing (\(M\) = neighbor_mask_size, \(\langle K \rangle\) = average occupancy).
Memory: \(O(N)\).
- neighbor_mask
Integer offsets defining the neighbor stencil (M, dim).
- Type:
jax.Array
- cell_size
Linear size of a grid cell (scalar).
- Type:
jax.Array
Notes
Batching with vmap: If you use jax.vmap to evaluate multiple simulation environments simultaneously, be aware of JAX’s SIMD execution model. Because the innermost while loop executes sequentially, the loop must continue running for all environments in the batch until the environment with the highest local cell occupancy finishes its iterations. Consequently, the computational cost of a batched execution is bottlenecked by the single worst-case occupancy across the entire batch.
- classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, box_size: ArrayLike | None = None) -> Self
Creates a DynamicCellList collider with robust defaults.
Defaults are chosen to prevent missing contacts while keeping the neighbor stencil and assumed cell occupancy as small as possible given the available information from state.
Because this collider uses a dynamic while-loop, its optimal parameters differ slightly from those of the static implementation. The optimal cell size, expressed here as the dimensionless ratio \(L^\prime = L/r_{max}\), is primarily driven by balancing the stencil-size overhead against the sequential loop cost.
Standard density: The optimal cell size is typically \(L^\prime = 2\) (the diameter of the largest particle). This reduces the search stencil to the 27 cells of the 3×3×3 neighborhood in 3D (9 cells in 2D), which is usually the most efficient balance for JIT-compiled JAX code.
High polydispersity ( \(\alpha \gg 1\) ): Unlike the static cell list, the dynamic cell list handles high polydispersity gracefully. We generally maintain \(L^\prime = 2\) unless the size ratio is extreme, as shrinking the cell size drastically inflates the neighbor stencil, which harms the parallelized outer loops.
By default, if cell_size is not provided, this method infers an optimal value based on the radius distribution in the reference state.
- Parameters:
state (State) – Reference state used to determine spatial dimension and default parameters.
cell_size (float, optional) – Cell edge length. If None, defaults to \(2 r_{max}\) for systems with low polydispersity (\(\alpha < 2.5\)), or \(0.5 r_{max}\) for highly polydisperse systems to balance stencil overhead.
search_range (int, optional) – Neighbor range in cell units. If None, the smallest safe value is computed such that \(\text{search\_range} \cdot L \geq \text{cutoff}\), where the cutoff is the largest contact distance \(2 r_{max}\).
box_size (jax.Array, optional) – Size of the periodic box, used to ensure there are at least 2 * search_range + 1 cells per axis. If None, these checks are skipped.
- Returns:
Configured collider instance.
- Return type:
DynamicCellList
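The documented defaults for cell_size and search_range can be restated as a small pure-Python function. This is a hypothetical sketch (`default_cell_parameters` is not a jaxdem function). It takes the cutoff to be the largest contact distance \(2 r_{max}\) and counts cells per axis as floor(box_size / cell_size). Both choices are assumptions, not confirmed details of Create.

```python
import math


def default_cell_parameters(radii, box_size=None):
    """Hypothetical restatement of the documented defaults.

    cell_size: 2 * r_max if alpha < 2.5, else 0.5 * r_max.
    search_range: smallest integer with search_range * cell_size >= 2 * r_max.
    """
    r_max, r_min = max(radii), min(radii)
    alpha = r_max / r_min  # polydispersity ratio
    cell_size = 2.0 * r_max if alpha < 2.5 else 0.5 * r_max
    search_range = math.ceil(2.0 * r_max / cell_size)
    enough_cells = True
    if box_size is not None:
        # Assumed check: at least 2 * search_range + 1 cells along every axis.
        enough_cells = all(
            math.floor(b / cell_size) >= 2 * search_range + 1 for b in box_size
        )
    return cell_size, search_range, enough_cells
```

With radii between 0.2 and 1.0 (\(\alpha = 5\)), this gives a cell size of 0.5 and a search range of 4, so the stencil grows to \(9^{dim}\) cells.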
- static compute_force(state: State, system: System) -> tuple[State, System]
Computes the total force acting on each particle using an implicit cell list (\(O(N \log N)\)). This method sums the force contributions from all particle pairs (i, j), as computed by the system.force_model, and updates the particle forces.
- static compute_potential_energy(state: State, system: System) -> jax.Array
Computes the potential energy associated with each particle using an implicit cell list (\(O(N \log N)\)). This method sums the energy contributions from all particle pairs (i, j), as computed by the system.force_model.
- static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) -> tuple[State, System, jax.Array, jax.Array]
Computes the list of neighbors for each particle. The shape of the list is (N, max_neighbors). If a particle has fewer neighbors than max_neighbors, the list is padded with -1. The indices in the list correspond to the indices of the returned sorted state.
Due to the nature of the cell list, neighbors are only guaranteed to be found within cell_size * search_range, where search_range is the number of cells probed in each direction. No pair separated by more than cell_size * (1 + search_range) along an axis can be found. If cutoff exceeds cell_size * search_range, the returned list may be incomplete.
- Parameters:
state (State) – The current state of the simulation.
system (System) – The configuration of the simulation.
cutoff (float) – Search radius.
max_neighbors (int) – Maximum number of neighbors to store per particle.
- Returns:
The sorted state, the system, the neighbor list, and a boolean flag for overflow.
- Return type:
tuple[State, System, jax.Array, jax.Array]
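A cell list probes a pair only when the two particles' cell indices differ by at most search_range along every axis. The pure-Python sketch below assumes an open domain and a cell index of floor(x / cell_size); both are assumptions, and `cells_apart` and `visible` are hypothetical helpers. It shows that a pair slightly farther apart than search_range * cell_size can be missed, while a pair almost (1 + search_range) * cell_size apart can still be found. That is why only search_range * cell_size is a safe cutoff.

```python
import math


def cells_apart(x_i, x_j, cell_size):
    """Number of cells separating two coordinates along one axis."""
    return abs(math.floor(x_j / cell_size) - math.floor(x_i / cell_size))


def visible(p_i, p_j, cell_size, search_range):
    """A pair is probed only if it lies inside the stencil along every axis."""
    return all(cells_apart(a, b, cell_size) <= search_range for a, b in zip(p_i, p_j))
```

With cell_size = 1 and search_range = 1, the pair (0.99, 0.5) and (2.01, 0.5), which is 1.02 apart, is not visible. The pair (0.01, 0.5) and (1.99, 0.5), which is 1.98 apart, is visible.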
- static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) -> tuple[jax.Array, jax.Array]
Build a cross-neighbor list between two sets of positions using a dynamic cell list.
For each point in pos_a, finds all neighbors from pos_b within the given cutoff distance. The pos_b array is hashed and sorted into cells, and the neighbor stencil of each query point in pos_a is used to probe the sorted pos_b hashes with a dynamic jax.lax.while_loop.
Due to the nature of the cell list, neighbors are only guaranteed to be found within cell_size * search_range. No pair separated by more than cell_size * (1 + search_range) along an axis can be found.
- Parameters:
pos_a (jax.Array) – Query positions, shape (N_A, dim).
pos_b (jax.Array) – Database positions, shape (N_B, dim).
system (System) – The configuration of the simulation.
cutoff (float) – Search radius.
max_neighbors (int) – Maximum number of neighbors to store per query point.
- Returns:
A tuple containing:
neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.
overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.
- Return type:
Tuple[jax.Array, jax.Array]
- class jaxdem.colliders.NaiveSimulator
Bases: Collider
Implementation that computes forces and potential energies using a naive \(O(N^2)\) all-pairs interaction loop.
Notes
Due to its \(O(N^2)\) complexity, NaiveSimulator is suitable only for simulations with a relatively small number of particles; for larger systems, a more efficient spatial-partitioning collider should be used. For small systems (fewer than roughly 1k to 5k spheres, depending on the GPU), however, this collider should be the fastest option.
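The all-pairs loop can be sketched in pure Python. The linear repulsive spring used here is a hypothetical stand-in for system.force_model. The sketch simply skips i == j instead of delegating that case to the force model, and `naive_forces` is a hypothetical name.

```python
import math


def naive_forces(positions, radii, stiffness=1.0):
    """All-pairs O(N^2) force sum with a linear repulsive spring.

    F_ij = k * overlap * unit vector from j to i, for overlapping spheres.
    Hypothetical force law; assumes no two particles coincide.
    """
    n, dim = len(positions), len(positions[0])
    forces = [[0.0] * dim for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # this sketch skips self-pairs explicitly
            rij = [positions[i][k] - positions[j][k] for k in range(dim)]
            dist = math.hypot(*rij)
            overlap = radii[i] + radii[j] - dist
            if overlap > 0.0:  # only overlapping spheres repel
                for k in range(dim):
                    forces[i][k] += stiffness * overlap * rij[k] / dist
    return forces
```

Because every pair is visited from both sides, the forces obey Newton's third law and sum to zero, which makes a convenient sanity check.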
- static compute_potential_energy(state: State, system: System) -> jax.Array
Computes the potential energy associated with each particle using a naive \(O(N^2)\) all-pairs loop.
This method iterates over all particle pairs (i, j) and sums the potential energy contributions computed by the system.force_model.
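- static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) -> tuple[State, System, jax.Array, jax.Array]
Computes a neighbor list using a naive \(O(N^2)\) all-pairs search.
- Parameters:
state (State) – The current state of the simulation.
system (System) – The configuration of the simulation.
cutoff (float) – Search radius.
max_neighbors (int) – Maximum number of neighbors to store per particle.
- Returns:
A tuple containing:
state: The simulation state.
system: The simulation system.
neighbor_list: Array of shape (N, max_neighbors) containing neighbor indices, padded with -1.
overflow: Boolean flag indicating if any particle exceeded max_neighbors.
- Return type:
tuple[State, System, jax.Array, jax.Array]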
- static compute_force(state: State, system: System) -> tuple[State, System]
Computes the total force acting on each particle using a naive \(O(N^2)\) all-pairs loop.
This method sums the force contributions from all particle pairs (i, j), as computed by the system.force_model, and updates the particle forces.
- class jaxdem.colliders.NeighborList(cell_list: Collider, neighbor_list: Array, old_pos: Array, n_build_times: Array, cutoff: Array, skin: Array, overflow: Array, max_neighbors: int)
Bases: Collider
Implementation of a Verlet neighbor list collider.
This collider caches a list of neighbors for every particle, significantly reducing the number of distance calculations required at each time step. The list is only rebuilt when the maximum displacement of any particle exceeds half of the specified skin distance.
Complexity
Time: \(O(N)\) between rebuilds. Rebuild complexity depends on the underlying cell_list collider (typically \(O(N \log N)\)).
Notes
You must provide a non-zero skin (e.g., 0.1 * radius) for this collider to be efficient. If skin = 0, the list is rebuilt every step, which is computationally expensive.
- neighbor_list: Array
Shape (N, max_neighbors). Contains the IDs of neighboring particles, padded with -1.
- old_pos: Array
Shape (N, dim). Positions of particles at the last build time.
- n_build_times: Array
Counter for how many times the list has been rebuilt.
- cutoff: Array
The interaction radius (force cutoff).
- skin: Array
Buffer distance. The list is built with radius = cutoff + skin and rebuilt when max_displacement > skin / 2.
- overflow: Array
Boolean flag indicating if the neighbor list overflowed during build.
- max_neighbors: int
Static buffer size for the neighbor list.
- classmethod Create(state: State, cutoff: float, skin: float = 0.05, max_neighbors: int | None = None, number_density: float = 1.0, safety_factor: float = 1.2, secondary_collider_type: str = 'CellList', secondary_collider_kw: dict[str, Any] | None = None) -> Self
Creates a NeighborList collider.
- Parameters:
state (State) – The initial simulation state used to determine system dimensions and particle count.
cutoff (float) – The physical interaction cutoff radius.
skin (float, default 0.05) – The buffer distance added to the cutoff for the neighbor list. Must be > 0.0 for performance.
max_neighbors (int, optional) – Maximum number of neighbors to store per particle. If not provided, it is estimated from number_density.
number_density (float, default 1.0) – Number density of the system, used to estimate max_neighbors if it is not explicitly provided.
safety_factor (float, default 1.2) – Multiplier applied to the estimated number of neighbors to account for fluctuations in local density.
secondary_collider_type (str, default "CellList") – Registered collider type used internally to build the neighbor lists.
secondary_collider_kw (dict[str, Any], optional) – Keyword arguments passed to the constructor of the internal collider. If None, cell_size is set to cutoff + skin.
- Returns:
A configured NeighborList collider instance.
- Return type:
NeighborList
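One plausible way to estimate max_neighbors from number_density and safety_factor is to count the expected number of particles inside a ball of radius cutoff + skin. This formula is an assumption for illustration only, since the exact estimate used by Create is not documented here. `estimate_max_neighbors` is a hypothetical name.

```python
import math


def estimate_max_neighbors(number_density, cutoff, skin, dim, safety_factor=1.2):
    """Hypothetical estimate: expected count inside the list radius, padded.

    Expected neighbors ~ number_density * volume of a dim-ball of radius
    cutoff + skin; safety_factor absorbs local density fluctuations.
    """
    radius = cutoff + skin
    # Volume of a dim-dimensional ball: pi^(d/2) r^d / Gamma(d/2 + 1).
    ball = math.pi ** (dim / 2) * radius**dim / math.gamma(dim / 2 + 1)
    return max(1, math.ceil(safety_factor * number_density * ball))
```

In 3D with unit density and a list radius of 1, this gives \(\lceil 1.2 \cdot 4\pi/3 \rceil = 6\) slots.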
- static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) -> tuple[State, System, jax.Array, jax.Array]
Returns the cached neighbor list from this collider.
This method does not rebuild the neighbor list. It returns the current cached neighbor_list and overflow flag stored in the collider.
- Parameters:
- Returns:
A tuple containing:
state: The simulation state.
system: The simulation system.
neighbor_list: The cached neighbor list of shape (N, max_neighbors).
overflow: Boolean flag indicating if the list overflowed during the last build.
- Return type:
tuple[State, System, jax.Array, jax.Array]
Notes
The returned neighbor indices refer to the internal particle ordering established during the most recent rebuild inside compute_force.
- static compute_force(state: State, system: System) -> tuple[State, System]
Computes total forces acting on each particle, rebuilding the neighbor list if necessary.
This method checks if any particle has moved enough to trigger a rebuild (displacement > skin/2). If so, it invokes the internal spatial partitioner to refresh the neighbor list. It then sums force contributions using the cached list.
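The rebuild criterion can be written as a short pure-Python check. It assumes an open domain, with no periodic wrapping of displacements, and `needs_rebuild` is a hypothetical helper. The docstring records why half the skin is the threshold.

```python
import math


def needs_rebuild(positions, old_positions, skin):
    """Verlet-list rebuild test: max displacement since last build > skin / 2.

    Two particles moving toward each other by skin / 2 each can close a gap of
    at most skin, so a pair farther apart than cutoff + skin at build time
    cannot be within cutoff while this test returns False.
    """
    max_disp = max(math.dist(p, q) for p, q in zip(positions, old_positions))
    return max_disp > 0.5 * skin
```

With skin = 0, any movement at all triggers a rebuild, which matches the warning about a zero skin above.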
- static compute_potential_energy(state: State, system: System) -> jax.Array
Computes the potential energy associated with each particle using the cached neighbor list.
This method iterates over the cached neighbors for each particle and sums the potential energy contributions computed by the system.force_model.
- static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) -> tuple[jax.Array, jax.Array]
Build a cross-neighbor list between two sets of positions.
Delegates to the internal cell_list collider’s create_cross_neighbor_list method.
- Parameters:
pos_a (jax.Array) – Query positions, shape (N_A, dim).
pos_b (jax.Array) – Database positions, shape (N_B, dim).
system (System) – The configuration of the simulation (used for domain displacement).
cutoff (float) – Search radius.
max_neighbors (int) – Maximum number of neighbors to store per query point.
- Returns:
A tuple containing:
neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.
overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.
- Return type:
Tuple[jax.Array, jax.Array]
- class jaxdem.colliders.StaticCellList(neighbor_mask: Array, cell_size: Array, max_occupancy: int)
Bases: Collider
Implicit cell-list (spatial hashing) collider.
This collider accelerates short-range pair interactions by partitioning the domain into a regular grid of cubic (3D) or square (2D) cells of side length cell_size. Each particle is assigned to a cell, particles are sorted by cell hash, and interactions are evaluated only against particles in the same or neighboring cells given by neighbor_mask. The cell list is implicit because we never store per-cell particle lists explicitly; instead, we exploit the sorted hashes and the fixed max_occupancy to probe neighbors in place.
The operation of this collider can be understood as the following nested loop:
    for particle in particles:
        for hash in stencil(particle):
            for i in range(max_occupancy):
                ...
The loop is flattened and everything happens in parallel. We can do this because the stencil size and max_occupancy are static values. This means that the cost of the calculations done with this cell list is:
\[O(N \cdot \text{neighbor\_mask\_size} \cdot \text{max\_occupancy})\]
plus the cost of hashing \(O(N)\) and sorting \(O(N \log N)\). As we sort every time step, the system remains nearly sorted, reducing the practical sorting complexity to something closer to \(O(N)\). This fixed max_occupancy makes the cell list ideal for certain types of systems but very bad for others. To understand the difference, let’s analyze the cost components:
- Stencil size:
The stencil size depends on the ratio between the cell size (\(L\)) and the radius of the largest particle (\(r_{max}\)).
\[\text{neighbor\_mask\_size} = \left( 2\left\lceil \frac{2r_{max}}{L} \right\rceil + 1 \right)^{dim}\]
- Max occupancy:
The maximum number of particles that can occupy a cell depends on the cell volume, the density (represented as local volume fraction \(\phi\)), and the volume of the smallest particle (\(V_{min}\)). To prevent missing contacts during local density fluctuations, we estimate the expected occupancy (\(\lambda\)) and add a safety margin of 3 standard deviations, treating the occupancy as Poisson distributed so that one standard deviation is \(\sqrt{\lambda}\):
\[\text{max\_occupancy} = \left\lceil \lambda + 3\sqrt{\lambda} \right\rceil \quad \text{where} \quad \lambda = \phi \frac{L^{dim}}{V_{min}}\]
Here, \(\phi\) is the local volume fraction, representing the ratio of the volume actually occupied by particles to the total volume of the cell. For dense packings of spheres in 3D, \(\phi \approx 0.74\), but in systems with highly overlapping internal spheres (like rigid clumps), \(\phi\) can be much larger.
Putting this together and expressing the cell size in units of the maximum radius, \(L^\prime = L/r_{max}\), we find the total theoretical cost to be:
\[\text{cost} \approx N \left( 2\left\lceil \frac{2}{L^\prime} \right\rceil + 1 \right)^{dim} \left\lceil \phi \frac{L^{dim}}{V_{min}} + 3\sqrt{\phi \frac{L^{dim}}{V_{min}}} \right\rceil\]
Then, if we define polydispersity as the ratio between the largest and smallest particle radii, \(\alpha = r_{max}/r_{min}\), we can express \(\lambda\) solely in terms of these dimensionless parameters. Knowing that the volume of the smallest particle is \(V_{min} = k_v r_{min}^{dim}\) (where \(k_v\) is the geometric volume factor, such as \(4\pi/3\) in 3D or \(\pi\) in 2D), we can write \(\lambda\) as:
\[\lambda = \frac{\phi}{k_v} (\alpha L^\prime)^{dim}\]
Substituting this back, we find the final cost function:
\[\text{cost} \approx N \left( 2\left\lceil \frac{2}{L^\prime} \right\rceil + 1 \right)^{dim} \left\lceil \frac{\phi}{k_v} (\alpha L^\prime)^{dim} + 3\sqrt{\frac{\phi}{k_v}} (\alpha L^\prime)^{dim/2} \right\rceil\]
- The Polydispersity Penalty:
Because we must safely bound the maximum possible number of particles in a cell, we base \(\lambda\) on the smallest particle. As shown in the equation above, the required array padding grows dramatically as \(O(\alpha^{dim})\), causing performance to degrade in highly polydisperse systems.
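The static cost model can be evaluated the same way. This plain-Python sketch only evaluates the formulas above, with \(k_v = \pi\) in 2D and \(4\pi/3\) in 3D. The default \(\phi = 1\) mirrors the assumption that Create is documented to use, and the function names are hypothetical.

```python
import math


def max_occupancy(l_prime, alpha, dim, phi=1.0):
    """ceil(lambda + 3 sqrt(lambda)), lambda = (phi / k_v) * (alpha * L') ** dim."""
    k_v = math.pi if dim == 2 else 4.0 * math.pi / 3.0  # assumes dim in {2, 3}
    lam = (phi / k_v) * (alpha * l_prime) ** dim
    return math.ceil(lam + 3.0 * math.sqrt(lam))


def static_cost_per_particle(l_prime, alpha, dim, phi=1.0):
    """Stencil size times max_occupancy (work per particle, up to a constant)."""
    stencil = (2 * math.ceil(2.0 / l_prime) + 1) ** dim
    return stencil * max_occupancy(l_prime, alpha, dim, phi)
```

For example, in 3D with \(L^\prime = 2\) and \(\phi = 1\), `max_occupancy(2.0, 1.0, 3)` returns 7 for a monodisperse system. Doubling \(\alpha\) multiplies \(\lambda\) by \(2^{dim}\).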
This collider is ideal for systems of spheres with low polydispersity and no dramatic overlaps; in those cases, it might be even faster than the dynamic cell list. However, it is not recommended for systems with clumps (where internal overlaps cause extreme local \(\phi\)) or with high polydispersity, as both drastically inflate the required max_occupancy padding.
Complexity
Time: \(O(N)\) to \(O(N \log N)\) from sorting (the state is close to sorted every frame), plus \(O(N \cdot M \cdot K)\) for neighbor probing (\(M\) = neighbor_mask_size, \(K\) = max_occupancy).
Memory: \(O(N)\).
- cell_size
Linear size of a grid cell (scalar).
- Type:
jax.Array
- max_occupancy
Maximum number of particles assumed to occupy a single cell. The algorithm probes exactly max_occupancy entries starting from the first particle in a neighbor cell. This should be set high enough that real cells rarely exceed it; otherwise, contacts and energy will be undercounted.
- Type:
int
Notes
max_occupancy is an upper bound on particles per cell. If a cell contains more than this many particles, some interactions might be missed, so choose cell_size carefully and make sure the local density never puts more particles in a cell than max_occupancy allows.
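The fixed-size probe can be sketched in pure Python using integer cell hashes. `probe_cell` is a hypothetical helper, not jaxdem's code. Because the loop always runs exactly max_occupancy times, a cell holding more particles than that is silently truncated, which is the undercounting described above.

```python
import bisect


def probe_cell(sorted_hashes, target, max_occupancy):
    """Fixed-size probe: inspect exactly max_occupancy slots from the cell start.

    Returns the sorted positions k whose hash equals target; particles of the
    cell beyond the first max_occupancy slots are silently skipped.
    """
    start = bisect.bisect_left(sorted_hashes, target)
    hits = []
    for t in range(max_occupancy):  # static trip count, so it can be unrolled
        k = start + t
        # Mask out slots past the array end or belonging to another cell.
        if k < len(sorted_hashes) and sorted_hashes[k] == target:
            hits.append(k)
    return hits
```

For `sorted_hashes = [0, 0, 1, 1, 1, 1, 2]`, `probe_cell(sorted_hashes, 1, 3)` returns `[2, 3, 4]` and misses the fourth particle at position 5.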
- neighbor_mask: Array
Integer offsets defining the neighbor stencil.
Shape is (M, dim), where each row is a displacement in cell coordinates. For search_range=1 in 2D this is the 3×3 Moore neighborhood (M=9); in 3D this is the 3×3×3 neighborhood (M=27).
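The stencil offsets can be generated with itertools.product. `make_neighbor_mask` is a hypothetical helper that returns a list of offset rows rather than a jax.Array.

```python
import itertools


def make_neighbor_mask(search_range, dim):
    """All integer offsets in [-search_range, search_range] ** dim, shape (M, dim)."""
    rng = range(-search_range, search_range + 1)
    return [list(offset) for offset in itertools.product(rng, repeat=dim)]
```

`len(make_neighbor_mask(1, 2))` is 9 and `len(make_neighbor_mask(1, 3))` is 27, matching the Moore neighborhoods above. In general \(M = (2 \cdot \text{search\_range} + 1)^{dim}\).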
- classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, box_size: ArrayLike | None = None, max_occupancy: int | None = None) -> Self
Creates a StaticCellList collider with robust defaults.
Defaults are chosen to prevent missing contacts while keeping the neighbor stencil and assumed cell occupancy as small as possible given the available information from state.
The optimal cell size, expressed here as the dimensionless ratio \(L^\prime = L/r_{max}\), depends heavily on the system’s volume fraction (\(\phi\)) and polydispersity (\(\alpha = r_{max}/r_{min}\)):
Standard density ( \(\phi \le 1\) ): The optimal cell size is typically \(L^\prime = 2\) (the diameter of the largest particle). This minimizes the search stencil to just the immediate neighboring cells.
Extreme local density ( \(\phi \gg 1\) ): For systems with heavy internal overlaps (like rigid clumps), the optimal cell size shrinks. Values like \(L^\prime = 0.5\) or \(L^\prime = 0.25\) often yield better performance by reducing the massive array padding penalty, even at the cost of a larger search stencil.
High polydispersity ( \(\alpha \gg 1\) ): High polydispersity severely degrades performance regardless of density, because the fixed occupancy arrays must always be padded to accommodate the volume of the smallest particles.
By default, if cell_size or max_occupancy are not provided, this method infers optimal safe values based on the radius distribution in the reference state, assuming a maximum local volume fraction of \(\phi = 1\). If your system contains clumps with internal overlaps where local \(\phi > 1\), you must override these defaults.
- Parameters:
state (State) – Reference state used to determine spatial dimension and default parameters.
cell_size (float, optional) – Cell edge length. If None, defaults to \(2 r_{max}\) for systems with low polydispersity (\(\alpha < 2.5\)), or \(0.5 r_{max}\) for highly polydisperse systems to mitigate the \(O(\alpha^{dim})\) array padding cost.
search_range (int, optional) – Neighbor range in cell units. If None, the smallest safe value is computed such that \(\text{search\_range} \cdot L \geq \text{cutoff}\), where the cutoff is the largest contact distance \(2 r_{max}\).
box_size (jax.Array, optional) – Size of the periodic box, used to ensure there are at least 2 * search_range + 1 cells per axis. If None, these checks are skipped.
max_occupancy (int, optional) – Assumed maximum number of particles per cell. If None, estimated using the statistical model \(\lambda + 3\sqrt{\lambda}\), assuming a worst-case local volume fraction of \(\phi = 1\) and the volume of the smallest particle.
- Returns:
Configured collider instance.
- Return type:
StaticCellList
- static compute_force(state: State, system: System) -> tuple[State, System]
Computes the total force acting on each particle using an implicit cell list (\(O(N \log N)\)). This method sums the force contributions from all particle pairs (i, j), as computed by the system.force_model, and updates the particle forces.
- static compute_potential_energy(state: State, system: System) -> jax.Array
Computes the potential energy associated with each particle using an implicit cell list (\(O(N \log N)\)). This method sums the energy contributions from all particle pairs (i, j), as computed by the system.force_model.
- static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) -> tuple[State, System, jax.Array, jax.Array]
Computes the list of neighbors for each particle. The shape of the list is (N, max_neighbors). If a particle has fewer neighbors than max_neighbors, the list is padded with -1. The indices in the list correspond to the indices of the returned sorted state.
Due to the nature of the cell list, neighbors are only guaranteed to be found within cell_size * search_range, where search_range is the number of cells probed in each direction. No pair separated by more than cell_size * (1 + search_range) along an axis can be found. If cutoff exceeds cell_size * search_range, the returned list may be incomplete. In addition, if a cell contains more particles than max_occupancy, some neighbors may be missing.
- Parameters:
state (State) – The current state of the simulation.
system (System) – The configuration of the simulation.
cutoff (float) – Search radius.
max_neighbors (int) – Maximum number of neighbors to store per particle.
- Returns:
The sorted state, the system, the neighbor list, and a boolean flag for overflow.
- Return type:
tuple[State, System, jax.Array, jax.Array]
- static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) -> tuple[jax.Array, jax.Array]
Build a cross-neighbor list between two sets of positions using an implicit cell list.
For each point in pos_a, finds all neighbors from pos_b within the given cutoff distance. The pos_b array is hashed and sorted into cells, and the neighbor stencil of each query point in pos_a is used to probe the sorted pos_b hashes with a fixed-size unrolled loop.
Due to the nature of the cell list, neighbors are only guaranteed to be found within cell_size * search_range. No pair separated by more than cell_size * (1 + search_range) along an axis can be found. If a cell contains more particles than max_occupancy, some neighbors may be missed.
- Parameters:
pos_a (jax.Array) – Query positions, shape (N_A, dim).
pos_b (jax.Array) – Database positions, shape (N_B, dim).
system (System) – The configuration of the simulation.
cutoff (float) – Search radius.
max_neighbors (int) – Maximum number of neighbors to store per query point.
- Returns:
A tuple containing:
neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.
overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.
- Return type:
Tuple[jax.Array, jax.Array]
- jaxdem.colliders.valid_interaction_mask(clump_i: Array, clump_j: Array, bond_i: Array, bond_j: Array, interact_same_bond_id: Array) -> Array
Pair mask shared by all colliders.
Interactions are always disabled for particles in the same clump. Interactions for particles with equal bond_id are controlled by interact_same_bond_id.
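A scalar pure-Python reading of this rule follows. It assumes interact_same_bond_id acts as a boolean switch, and `valid_pair` is a hypothetical helper. An array version could apply the same logic elementwise, for example as `(clump_i != clump_j) & ((bond_i != bond_j) | interact_same_bond_id)` with jax.numpy arrays, though jaxdem's actual code may differ.

```python
def valid_pair(clump_i, clump_j, bond_i, bond_j, interact_same_bond_id):
    """Scalar reading of the documented rule (not jaxdem's source).

    Same clump -> never interact. Same bond_id -> interact only if
    interact_same_bond_id is true. Otherwise -> interact.
    """
    if clump_i == clump_j:
        return False  # same clump: always masked out
    if bond_i == bond_j:
        return bool(interact_same_bond_id)  # same bond id: user switch
    return True
```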
Modules
- Cell List \(O(N \log N)\) collider implementation.
- Naive \(O(N^2)\) collider implementation.
- Neighbor List collider implementation.
- Sweep and prune \(O(N \log N)\) collider implementation.