jaxdem.colliders

Collision-detection interfaces and implementations.

Functions

valid_interaction_mask(clump_i, clump_j, ...)

Pair mask shared by all colliders.

Classes

Collider()

The base interface for defining how contact detection and force computations are performed in a simulation.

class jaxdem.colliders.Collider

Bases: Factory, ABC

The base interface for defining how contact detection and force computations are performed in a simulation.

Concrete subclasses of Collider implement the specific algorithms for calculating the interactions.

Notes:

Self-interaction (i.e., calling the force/energy computation for i=j) is allowed, and the underlying force_model is responsible for correctly handling or ignoring this case.
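Handling i = j safely matters in practice: the separation vector of a self-pair is zero, so a naive unit vector r/|r| produces NaN in both the value and its gradient. The sketch below shows one way a force model could deal with this, assuming a hypothetical linear-spring contact law (pair_force and stiffness are illustrative names, not jaxdem's force-model API). It uses the "double where" pattern so that neither the forward value nor jax.grad ever sees sqrt(0).

```python
import jax
import jax.numpy as jnp


def pair_force(r_ij, rad_i, rad_j, stiffness=1.0):
    """Hypothetical linear-spring contact force on particle i from particle j.

    r_ij is the separation vector pos_i - pos_j. For the self-pair
    (r_ij == 0) this returns zeros instead of NaN.
    """
    dist2 = jnp.sum(r_ij * r_ij)
    is_self = dist2 == 0.0
    # Double-where trick: feed a dummy value into sqrt for the self-pair so
    # that neither the value nor its gradient becomes NaN.
    dist = jnp.sqrt(jnp.where(is_self, 1.0, dist2))
    overlap = jnp.maximum(rad_i + rad_j - dist, 0.0)
    force = stiffness * overlap * r_ij / dist  # repulsive along r_ij
    return jnp.where(is_self, jnp.zeros_like(r_ij), force)
```

Because the self-pair contributes exactly zero, a collider can evaluate all (i, j) combinations, including i = j, without special-casing the diagonal.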

Example:

To define a custom collider, inherit from Collider, register it and implement its abstract methods:

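>>> import jax
>>> import jaxdem
>>> from dataclasses import dataclass
>>> from jaxdem.colliders import Collider
>>> @Collider.register("CustomCollider")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class CustomCollider(Collider):
...     ...  # implement the abstract methods here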

Then, instantiate it:

>>> jaxdem.Collider.create("CustomCollider", **custom_collider_kw)

static compute_force(state: State, system: System) → tuple[State, System]

Abstract method to compute the total force acting on each particle in the simulation.

Implementations should calculate inter-particle forces and torques based on the current state and system configuration, then update the force and torque attributes of the state object with the resulting total force and torque for each particle.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object (with computed forces) and the System object.

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Abstract method to compute the potential energy of each particle in the simulation.

Implementations should calculate every particle's potential energy contribution based on the current state and system configuration.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A JAX array of shape (N,) containing the potential energy of each particle.

Return type:

jax.Array

Example

>>> potential_energy = system.collider.compute_potential_energy(state, system)
>>> print(potential_energy.shape)  # (N,)
>>> print(f"Mean potential energy per particle: {potential_energy.mean():.4f}")

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Build a neighbor list for the current collider.

This is primarily used by neighbor-list-based algorithms and diagnostics. Implementations should match the cell-list semantics:

  • Returns a neighbor list of shape (N, max_neighbors) padded with -1.

  • Neighbor indices must refer to the returned (possibly sorted) state.

  • Also returns an overflow boolean flag (True if any particle exceeded max_neighbors neighbors within the cutoff).
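Consumers of a -1-padded neighbor list have to mask the padding explicitly: JAX follows NumPy indexing, so an index of -1 silently refers to the last particle rather than raising an error. A minimal sketch of a safe reduction over neighbors, assuming only the (N, max_neighbors) layout described above (sum_over_neighbors is a hypothetical helper):

```python
import jax.numpy as jnp


def sum_over_neighbors(neighbor_list, values):
    """Sum values[j] over the valid neighbors j of each particle.

    neighbor_list: int array of shape (N, max_neighbors), padded with -1.
    values:        array of shape (N,) with one quantity per particle.
    """
    valid = neighbor_list >= 0
    # Replace padding by a harmless index before gathering; -1 would
    # otherwise wrap around to the last particle.
    safe_idx = jnp.where(valid, neighbor_list, 0)
    gathered = values[safe_idx]  # (N, max_neighbors)
    return jnp.sum(jnp.where(valid, gathered, 0.0), axis=1)


nl = jnp.array([[1, 2, -1], [0, -1, -1], [0, -1, -1]])
values = jnp.array([10.0, 20.0, 30.0])
print(sum_over_neighbors(nl, values))
```

Without the mask, values[nl].sum(axis=1) would add values[-1] once for every padded slot.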

static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]

Build a cross-neighbor list between two sets of positions.

For each point in pos_a, finds all neighbors from pos_b within the given cutoff distance. This is useful for coupling different particle systems or computing interactions between distinct sets of objects.

The default implementation uses a naive \(O(N_A \times N_B)\) all-pairs search. Subclasses may override this with more efficient algorithms.

Parameters:
  • pos_a (jax.Array) – Query positions, shape (N_A, dim).

  • pos_b (jax.Array) – Database positions, shape (N_B, dim).

  • system (System) – The configuration of the simulation (used for domain displacement).

  • cutoff (float) – Search radius.

  • max_neighbors (int) – Maximum number of neighbors to store per query point.

Returns:

A tuple containing:

  • neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.

  • overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.

Return type:

Tuple[jax.Array, jax.Array]
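The naive all-pairs contract can be sketched in a few lines of plain JAX. This is an illustration, not the library's implementation: naive_cross_neighbor_list is a hypothetical name, and box_size (a periodic box handled with the minimum-image convention) stands in for whatever domain displacement the System actually provides. jnp.nonzero with a static size and fill_value=-1 yields the fixed-shape, -1-padded rows.

```python
import jax
import jax.numpy as jnp


def naive_cross_neighbor_list(pos_a, pos_b, box_size, cutoff, max_neighbors):
    """All-pairs O(N_A * N_B) cross-neighbor search (illustrative sketch)."""
    disp = pos_a[:, None, :] - pos_b[None, :, :]         # (N_A, N_B, dim)
    disp = disp - box_size * jnp.round(disp / box_size)  # minimum image
    within = jnp.sum(disp * disp, axis=-1) <= cutoff**2  # (N_A, N_B)

    # First max_neighbors hits of each row, padded with -1 (static shape).
    def first_hits(row):
        return jnp.nonzero(row, size=max_neighbors, fill_value=-1)[0]

    neighbor_list = jax.vmap(first_hits)(within)
    overflow = jnp.any(jnp.sum(within, axis=1) > max_neighbors)
    return neighbor_list, overflow


pos_a = jnp.array([[0.5, 0.5], [5.0, 5.0]])
pos_b = jnp.array([[0.9, 0.5], [9.8, 0.5], [5.0, 5.5], [2.0, 2.0]])
box = jnp.array([10.0, 10.0])
nl, overflow = naive_cross_neighbor_list(pos_a, pos_b, box, 1.0, 3)
```

Here pos_b[1] is a neighbor of pos_a[0] only through the periodic image (distance 0.7 across the boundary).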

class jaxdem.colliders.DynamicCellList(neighbor_mask: Array, cell_size: Array)

Bases: Collider

Implicit cell-list (spatial hashing) collider using dynamic while-loops.

This collider accelerates short-range pair interactions by partitioning the domain into a regular grid of cubic/square cells of side length cell_size. Each particle is assigned to a cell, particles are sorted by cell hash, and interactions are evaluated only against particles in the same or neighboring cells given by neighbor_mask.

Unlike the static cell list, this implementation does not pad arrays to a fixed max_occupancy. Instead, it uses a dynamic jax.lax.while_loop to iterate over exactly the particles present in each neighboring cell.

The operation of this collider can be understood as the following nested loop:

for particle in particles: # parallel
    for hash in stencil(particle): # parallel
        while next_neighbor in cell(hash): # sequential
            ...

Because the innermost loop is evaluated sequentially, the computational cost is driven by the average cell occupancy rather than the maximum possible occupancy. This makes the total theoretical cost:

\[O(N \cdot \text{neighbor\_mask\_size} \cdot \langle K \rangle)\]

where \(\langle K \rangle\) is the average cell occupancy. To understand how this scales, let’s analyze the cost components:

  • Stencil size:

    The stencil size depends on the ratio between the cell size (\(L\)) and the radius of the largest particle (\(r_{max}\)).

    \[\text{neighbor\_mask\_size} = \left( 2\left\lceil \frac{2r_{max}}{L} \right\rceil + 1 \right)^{dim}\]
  • Average occupancy:

    The average number of particles that occupy a cell depends only on the cell volume and the macroscopic number density (\(\rho\)); it does not depend on the size of the smallest particle.

    \[\langle K \rangle = \rho L^{dim}\]

To express this in terms of the local volume fraction \(\phi\) (the ratio of volume actually occupied by particles to the total cell volume) and our normalized cell size \(L^\prime = L/r_{max}\), we use the average particle volume \(\langle V \rangle\):

\[\langle K \rangle = \phi \frac{L^{dim}}{\langle V \rangle} = \phi \frac{(L^\prime r_{max})^{dim}}{\langle V \rangle}\]

Knowing that the volume of the largest particle is \(V_{max} = k_v r_{max}^{dim}\) (where \(k_v\) is the geometric volume factor, such as \(4\pi/3\) in 3D or \(\pi\) in 2D), we find the final theoretical cost:

\[\text{cost} \approx N \left( 2\left\lceil \frac{2}{L^\prime} \right\rceil + 1 \right)^{dim} \left( \frac{\phi}{k_v} \frac{V_{max}}{\langle V \rangle} (L^\prime)^{dim} \right)\]
  • The Polydispersity Advantage:

    In the static cell list, cost scales with the ratio of the largest to the smallest particle volume (\(V_{max}/V_{min} \propto \alpha^{dim}\), where \(\alpha = r_{max}/r_{min}\)). In this dynamic list, the cost scales with the ratio of the largest to the average particle volume (\(V_{max}/\langle V \rangle\)). Because \(\langle V \rangle\) is a number-weighted mean, it is at least \(x_{large} V_{max}\), where \(x_{large}\) is the number fraction of the largest particles, so \(V_{max}/\langle V \rangle \le 1/x_{large}\) however small the smallest particles are. The \(O(\alpha^{dim})\) padding penalty is therefore largely avoided unless small particles vastly outnumber large ones: the dynamic loop only iterates over particles that actually exist.
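To make the comparison concrete, both volume ratios can be evaluated for a bidisperse mixture. The sketch assumes a number fraction x_large of large spheres and 1 - x_large of small ones with radius ratio alpha; volume_ratios is a hypothetical helper, and the geometric factor k_v cancels in both ratios.

```python
def volume_ratios(alpha, x_large, dim=3):
    """Return (V_max / V_min, V_max / <V>) for a bidisperse mixture.

    alpha   -- radius ratio r_max / r_min
    x_large -- number fraction of the large particles
    """
    v_max = 1.0
    v_min = alpha ** (-dim)  # (r_min / r_max) ** dim, with k_v cancelled
    v_mean = x_large * v_max + (1.0 - x_large) * v_min
    return v_max / v_min, v_max / v_mean


for alpha in (2.0, 10.0):
    static_ratio, dynamic_ratio = volume_ratios(alpha, x_large=0.5)
    print(f"alpha={alpha}: V_max/V_min={static_ratio:.1f}  V_max/<V>={dynamic_ratio:.3f}")
```

For a 50:50 mixture in 3D, alpha = 2 gives 8 versus about 1.78, and alpha = 10 gives 1000 versus just under 2: the static padding factor grows as alpha cubed, while the dynamic factor stays below 1/x_large = 2.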

This collider is ideal for highly polydisperse systems, sparse systems (low packing fractions), or systems with rigid clumps that create massive local density spikes, as it completely avoids the memory bloat and wasted gather operations caused by array padding.
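The nested loop described above can be sketched in plain JAX for a periodic box. This is an illustrative reconstruction, not jaxdem's code: count_neighbors_dynamic and cell_hash are hypothetical names, the hash is a simple row-major index, and the sketch only counts neighbors within a cutoff instead of calling a force model. Particles are sorted by cell hash, jnp.searchsorted locates where each stencil cell starts and ends in the sorted array, and a jax.lax.while_loop walks exactly the particles in that range. As in the Create documentation, the grid must have at least 2 * search_range + 1 cells per axis.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np


def cell_hash(cell, n_cells):
    """Row-major linear index of an integer cell coordinate, wrapped periodically."""
    cell = jnp.mod(cell, n_cells)
    strides = jnp.concatenate([jnp.cumprod(n_cells[::-1])[::-1][1:], jnp.ones(1, n_cells.dtype)])
    return jnp.sum(cell * strides, axis=-1)


def count_neighbors_dynamic(pos, box, cutoff, search_range=1):
    """Count neighbors within `cutoff` of every particle in a periodic box."""
    dim = pos.shape[1]
    # Static grid whose cells are at least cutoff / search_range wide.
    n_cells_np = np.maximum(np.floor(box * search_range / cutoff), 1).astype(np.int32)
    assert np.all(n_cells_np >= 2 * search_range + 1), "need 2*search_range+1 cells per axis"
    n_cells, box = jnp.asarray(n_cells_np), jnp.asarray(box)
    cell_len = box / n_cells

    # Hash every particle and sort by cell hash.
    hashes = cell_hash(jnp.floor(pos / cell_len).astype(jnp.int32), n_cells)
    order = jnp.argsort(hashes)
    sorted_hash, sorted_pos = hashes[order], pos[order]
    offsets = jnp.asarray(
        list(itertools.product(range(-search_range, search_range + 1), repeat=dim)),
        dtype=jnp.int32,
    )

    def count_one(i):  # parallel over particles
        p = sorted_pos[i]
        cell = jnp.floor(p / cell_len).astype(jnp.int32)

        def count_in_cell(offset):  # parallel over the stencil
            h = cell_hash(cell + offset, n_cells)
            start = jnp.searchsorted(sorted_hash, h, side="left")
            end = jnp.searchsorted(sorted_hash, h, side="right")

            def body(carry):  # sequential over the particles actually in the cell
                j, n = carry
                d = sorted_pos[j] - p
                d = d - box * jnp.round(d / box)  # minimum image
                hit = (jnp.sum(d * d) <= cutoff**2) & (j != i)
                return j + 1, n + hit.astype(jnp.int32)

            _, n = jax.lax.while_loop(lambda c: c[0] < end, body, (start, jnp.int32(0)))
            return n

        return jnp.sum(jax.vmap(count_in_cell)(offsets))

    counts_sorted = jax.vmap(count_one)(jnp.arange(pos.shape[0]))
    # Undo the sort so the counts line up with the input order.
    return jnp.zeros_like(counts_sorted).at[order].set(counts_sorted)
```

In this sketch the while_loop is batched by jax.vmap; a batched jax.lax.while_loop keeps stepping until every lane has finished, so lanes whose cell is already exhausted idle (masked) while the most crowded cell in the batch is traversed.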

Complexity

  • Time: \(O(N)\) - \(O(N \log N)\) from sorting, plus \(O(N \cdot M \cdot \langle K \rangle)\) for neighbor probing (M = neighbor_mask_size, \(\langle K \rangle\) = average occupancy). The state is close to sorted every frame.

  • Memory: \(O(N)\).

neighbor_mask

Integer offsets defining the neighbor stencil (M, dim).

Type:

jax.Array

cell_size

Linear size of a grid cell (scalar).

Type:

jax.Array

Notes

  • Batching with vmap: If you use jax.vmap to evaluate multiple simulation environments simultaneously, be aware of JAX’s SIMD execution model. Because the innermost while loop executes sequentially, the loop must continue running for all environments in the batch until the environment with the highest local cell occupancy finishes its iterations. Consequently, the computational cost of a batched execution is bottlenecked by the single worst-case occupancy across the entire batch.

classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, box_size: ArrayLike | None = None) → Self

Creates a DynamicCellList collider with robust defaults.

Defaults are chosen to prevent missing contacts while keeping the neighbor stencil and assumed cell occupancy as small as possible given the available information from state.

Because this collider uses a dynamic while-loop, its optimal parameters differ slightly from the static implementation. The optimal cell size parameter, denoted here as the dimensionless ratio \(L^\prime = L/r_{max}\), is primarily driven by balancing the stencil size overhead against the sequential loop cost.

  • Standard density: The optimal cell size is typically \(L^\prime = 2\) (the diameter of the largest particle). This limits the search stencil to the immediate 3×3×3 block of 27 cells in 3D (3×3 = 9 cells in 2D), which is usually the most efficient balance under JAX.

  • High polydispersity ( \(\alpha \gg 1\) ): Unlike the static cell list, the dynamic cell list handles high polydispersity gracefully. We generally maintain \(L^\prime = 2\) unless the size ratio is extreme, as shrinking the cell size drastically inflates the neighbor stencil, which harms the parallelized outer loops.

By default, if cell_size is not provided, this method will infer an optimal value based on the radius distribution in the reference state.

Parameters:
  • state (State) – Reference state used to determine spatial dimension and default parameters.

  • cell_size (float, optional) – Cell edge length. If None, defaults to \(2 r_{max}\) for systems with low polydispersity (\(\alpha < 2.5\)), or \(0.5 r_{max}\) for highly polydisperse systems to balance stencil overhead.

  • search_range (int, optional) – Neighbor range in cell units. If None, the smallest safe value is computed such that \(\text{search\_range} \cdot L \geq \text{cutoff}\), where the cutoff is the largest possible contact distance, \(2 r_{max}\).

  • box_size (jax.Array, optional) – Size of the periodic box used to ensure there are at least 2 * search_range + 1 cells per axis. If None, these checks are skipped.

Returns:

Configured collider instance.

Return type:

DynamicCellList
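The documented default rules can be written out as a small helper. This is a hedged reconstruction from the parameter descriptions only, not the library's code: default_cell_list_params is a hypothetical name, the alpha < 2.5 threshold and the 2 r_max / 0.5 r_max choices come from the cell_size description, and search_range is assumed to be the smallest integer with search_range · L ≥ 2 r_max, the largest contact distance implied by the stencil-size formula.

```python
import math

import numpy as np


def default_cell_list_params(radii, box_size=None):
    """Hypothetical reconstruction of the documented Create defaults."""
    radii = np.asarray(radii, dtype=float)
    r_max, r_min = radii.max(), radii.min()
    alpha = r_max / r_min
    cell_size = 2.0 * r_max if alpha < 2.5 else 0.5 * r_max
    # Smallest integer range whose stencil reaches the largest contact distance.
    search_range = math.ceil(2.0 * r_max / cell_size)
    if box_size is not None:
        cells_per_axis = np.floor(np.asarray(box_size, dtype=float) / cell_size)
        if np.any(cells_per_axis < 2 * search_range + 1):
            raise ValueError("box too small: need at least 2*search_range + 1 cells per axis")
    return float(cell_size), int(search_range)


print(default_cell_list_params([0.5, 0.5, 0.4]))  # low polydispersity
print(default_cell_list_params([1.0, 0.25]))      # alpha = 4
```

Note how the polydisperse case trades a smaller cell (0.5 r_max) for a wider stencil (search_range = 4, i.e. 9 cells per axis).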

static compute_force(state: State, system: System) → tuple[State, System]

Computes the total force acting on each particle using an implicit cell list (\(O(N \log N)\)). This method sums the force contributions from all particle pairs (i, j) as computed by the system.force_model and updates the particle forces.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object with computed forces and the unmodified System object.

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the potential energy of each particle using an implicit cell list (\(O(N \log N)\)). This method sums the energy contributions from all particle pairs (i, j) as computed by the system.force_model.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

Array of shape (N,) containing the potential energy of each particle.

Return type:

jax.Array

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Computes the list of neighbors for each particle. The shape of the list is (N, max_neighbors). If a particle has fewer neighbors than max_neighbors, the list is padded with -1. The indices of the list correspond to the indices of the returned sorted state.

Note that, due to the nature of the cell list, no neighbors farther than cell_size * (1 + search_range) can be found (search_range is the number of neighboring cells probed along each axis). If cutoff is greater than this value, some neighbors within the cutoff may be missing from the list.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Search radius.

  • max_neighbors (int) – Maximum number of neighbors to store per particle.

Returns:

The sorted state, the system, the neighbor list, and a boolean flag for overflow.

Return type:

tuple[State, System, jax.Array, jax.Array]

static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]

Build a cross-neighbor list between two sets of positions using a dynamic cell list.

For each point in pos_a, finds all neighbors from pos_b within the given cutoff distance. The pos_b array is hashed and sorted into cells, and the neighbor stencil of each query point in pos_a is used to probe the sorted pos_b hashes with a dynamic jax.lax.while_loop.

No neighbors further than cell_size * (1 + search_range) can be found due to the nature of the cell list.

Parameters:
  • pos_a (jax.Array) – Query positions, shape (N_A, dim).

  • pos_b (jax.Array) – Database positions, shape (N_B, dim).

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Search radius.

  • max_neighbors (int) – Maximum number of neighbors to store per query point.

Returns:

A tuple containing:

  • neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.

  • overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.

Return type:

Tuple[jax.Array, jax.Array]

class jaxdem.colliders.NaiveSimulator

Bases: Collider

Implementation that computes forces and potential energies using a naive \(O(N^2)\) all-pairs interaction loop.

Notes

Due to its \(O(N^2)\) complexity, NaiveSimulator is suited to simulations with a relatively small number of particles; for larger systems, use a spatial-partitioning collider instead. For small systems (fewer than roughly 1k to 5k spheres, depending on the GPU), however, this collider should be the fastest option.
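The all-pairs pattern can be sketched in plain JAX. The sketch assumes a hypothetical harmonic contact energy 0.5 k overlap² in free space (no periodic box), with each pair energy split equally between its two particles; per_particle_energy and forces are illustrative names, not jaxdem functions. Forces are taken as the negative gradient of the total energy, and the (N, N) intermediates make the O(N²) time and memory cost explicit.

```python
import jax
import jax.numpy as jnp


def per_particle_energy(pos, rad, stiffness=1.0):
    """Naive O(N^2) harmonic contact energy; half of each pair goes to each particle."""
    r_ij = pos[:, None, :] - pos[None, :, :]         # (N, N, dim)
    dist2 = jnp.sum(r_ij * r_ij, axis=-1)            # (N, N)
    is_self = jnp.eye(pos.shape[0], dtype=bool)
    dist = jnp.sqrt(jnp.where(is_self, 1.0, dist2))  # safe sqrt on the diagonal
    overlap = jnp.maximum(rad[:, None] + rad[None, :] - dist, 0.0)
    pair_energy = jnp.where(is_self, 0.0, 0.5 * stiffness * overlap**2)
    return 0.5 * jnp.sum(pair_energy, axis=1)        # (N,)


def forces(pos, rad, stiffness=1.0):
    """Forces as -dE/dpos of the total energy."""
    def total_energy(p):
        return jnp.sum(per_particle_energy(p, rad, stiffness))

    return -jax.grad(total_energy)(pos)


pos = jnp.array([[0.0, 0.0], [0.8, 0.0]])
rad = jnp.array([0.5, 0.5])
print(per_particle_energy(pos, rad), forces(pos, rad))
```

For the two overlapping disks above (overlap 0.2), each particle carries 0.01 of the 0.02 pair energy, and the forces are ±0.2 along x, pushing the pair apart.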

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the potential energy associated with each particle using a naive \(O(N^2)\) all-pairs loop.

This method iterates over all particle pairs (i, j) and sums the potential energy contributions computed by the system.force_model.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

Array of shape (N,) containing the potential energy of each particle.

Return type:

jax.Array

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Computes a neighbor list using a naive \(O(N^2)\) all-pairs search.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – The interaction radius (force cutoff).

  • max_neighbors (int) – Maximum number of neighbors to store per particle.

Returns:

A tuple containing:

  • state: The simulation state.

  • system: The simulation system.

  • neighbor_list: Array of shape (N, max_neighbors) containing neighbor indices.

  • overflow: Boolean flag indicating if any particle exceeded max_neighbors.

Return type:

Tuple[State, System, jax.Array, jax.Array]

static compute_force(state: State, system: System) → tuple[State, System]

Computes the total force acting on each particle using a naive \(O(N^2)\) all-pairs loop.

This method sums the force contributions from all particle pairs (i, j) as computed by the system.force_model and updates the particle forces.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object with computed forces and the unmodified System object.

Return type:

Tuple[State, System]

class jaxdem.colliders.NeighborList(cell_list: Collider, neighbor_list: Array, old_pos: Array, n_build_times: Array, cutoff: Array, skin: Array, overflow: Array, max_neighbors: int)

Bases: Collider

Implementation of a Verlet neighbor list collider.

This collider caches a list of neighbors for every particle, significantly reducing the number of distance calculations required at each time step. The list is only rebuilt when the maximum displacement of any particle exceeds half of the specified skin distance.

Complexity

  • Time: \(O(N)\) between rebuilds. Rebuild complexity depends on the underlying cell_list collider (typically \(O(N \log N)\)).

Notes

You must provide a non-zero skin (e.g., 0.1 * radius) for this collider to be efficient. If skin = 0, the list is rebuilt every step, which is computationally expensive.
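The rebuild criterion can be sketched as below. The threshold skin/2 is safe because two particles approaching each other can each move at most skin/2 before a rebuild, so their separation shrinks by at most skin and no pair can cross from beyond cutoff + skin to inside cutoff unnoticed. needs_rebuild is a hypothetical helper; the minimum-image wrap assumes a periodic box of size box_size and that particles move less than half a box length between rebuilds.

```python
import jax.numpy as jnp


def needs_rebuild(pos, old_pos, skin, box_size):
    """True if any particle moved more than skin / 2 since the last build."""
    disp = pos - old_pos
    disp = disp - box_size * jnp.round(disp / box_size)  # periodic wrap
    max_disp2 = jnp.max(jnp.sum(disp * disp, axis=-1))
    return max_disp2 > (0.5 * skin) ** 2


box = jnp.array([10.0, 10.0])
old = jnp.zeros((2, 2))
print(needs_rebuild(jnp.array([[0.03, 0.0], [0.0, 0.0]]), old, 0.05, box))
```

Comparing squared distances avoids a square root per particle; the wrap ensures that a particle crossing the periodic boundary is not mistaken for one that jumped a full box length.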

cell_list: Collider

The underlying collider used to build the list via create_neighbor_list.

neighbor_list: Array

Shape (N, max_neighbors). Contains the IDs of neighboring particles, padded with -1.

old_pos: Array

Shape (N, dim). Positions of particles at the last build time.

n_build_times: Array

Counter for how many times the list has been rebuilt.

cutoff: Array

The interaction radius (force cutoff).

skin: Array

Buffer distance. The list is built with radius = cutoff + skin and rebuilt when max_displacement > skin / 2.

overflow: Array

Boolean flag indicating if the neighbor list overflowed during build.

max_neighbors: int

Static buffer size for the neighbor list.

classmethod Create(state: State, cutoff: float, skin: float = 0.05, max_neighbors: int | None = None, number_density: float = 1.0, safety_factor: float = 1.2, secondary_collider_type: str = 'CellList', secondary_collider_kw: dict[str, Any] | None = None) → Self

Creates a NeighborList collider.

Parameters:
  • state (State) – The initial simulation state used to determine system dimensions and particle count.

  • cutoff (float) – The physical interaction cutoff radius.

  • skin (float, default 0.05) – The buffer distance added to the cutoff for the neighbor list. Must be > 0.0 for performance.

  • max_neighbors (int, optional) – Maximum number of neighbors to store per particle. If not provided, it is estimated from the number_density.

  • number_density (float, default 1.0) – Number density of the system used to estimate max_neighbors if not explicitly provided.

  • safety_factor (float, default 1.2) – Multiplier applied to the estimated number of neighbors to account for fluctuations in local density.

  • secondary_collider_type (str, default "CellList") – Registered collider type used internally to build the neighbor lists.

  • secondary_collider_kw (dict[str, Any], optional) – Keyword arguments passed to the constructor of the internal collider. If None, cell_size is set to cutoff + skin.

Returns:

A configured NeighborList collider instance.

Return type:

NeighborList
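One plausible way to turn number_density and safety_factor into a static buffer size is to count the expected number of particles inside a ball of radius cutoff + skin. This is an assumption for illustration only: estimate_max_neighbors is a hypothetical helper, and the formula Create actually applies may differ.

```python
import math


def estimate_max_neighbors(cutoff, skin, number_density, safety_factor=1.2, dim=3):
    """Hypothetical estimate: safety_factor * density * volume of a ball of radius cutoff + skin."""
    radius = cutoff + skin
    # Volume of a d-ball: pi^(d/2) / Gamma(d/2 + 1) * r^d (pi r^2 in 2D, 4/3 pi r^3 in 3D).
    ball_volume = math.pi ** (dim / 2) / math.gamma(dim / 2 + 1) * radius**dim
    return math.ceil(safety_factor * number_density * ball_volume)


print(estimate_max_neighbors(1.0, 0.05, 1.0))         # 3D
print(estimate_max_neighbors(1.0, 0.05, 1.0, dim=2))  # 2D
```

Because max_neighbors fixes an array shape, it must be a Python int at construction time; if the overflow flag is ever set, the buffer was too small and the collider should be recreated with a larger value.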

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Returns the cached neighbor list from this collider.

This method does not rebuild the neighbor list. It returns the current cached neighbor_list and overflow flag stored in the collider.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Ignored; provided for API compatibility.

  • max_neighbors (int) – Ignored; provided for API compatibility.

Returns:

A tuple containing:

  • state: The simulation state.

  • system: The simulation system.

  • neighbor_list: The cached neighbor list of shape (N, max_neighbors).

  • overflow: Boolean flag indicating if the list overflowed during the last build.

Return type:

Tuple[State, System, jax.Array, jax.Array]

Notes

  • The returned neighbor indices refer to the internal particle ordering established during the most recent rebuild inside compute_force.

static compute_force(state: State, system: System) → tuple[State, System]

Computes total forces acting on each particle, rebuilding the neighbor list if necessary.

This method checks if any particle has moved enough to trigger a rebuild (displacement > skin/2). If so, it invokes the internal spatial partitioner to refresh the neighbor list. It then sums force contributions using the cached list.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object with computed forces and the updated System object (with refreshed collider cache).

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the potential energy associated with each particle using the cached neighbor list.

This method iterates over the cached neighbors for each particle and sums the potential energy contributions computed by the system.force_model.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

Array of shape (N,) containing the potential energy of each particle.

Return type:

jax.Array

static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]

Build a cross-neighbor list between two sets of positions.

Delegates to the internal cell_list collider’s create_cross_neighbor_list method.

Parameters:
  • pos_a (jax.Array) – Query positions, shape (N_A, dim).

  • pos_b (jax.Array) – Database positions, shape (N_B, dim).

  • system (System) – The configuration of the simulation (used for domain displacement).

  • cutoff (float) – Search radius.

  • max_neighbors (int) – Maximum number of neighbors to store per query point.

Returns:

A tuple containing:

  • neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.

  • overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.

Return type:

Tuple[jax.Array, jax.Array]

class jaxdem.colliders.StaticCellList(neighbor_mask: Array, cell_size: Array, max_occupancy: int)

Bases: Collider

Implicit cell-list (spatial hashing) collider.

This collider accelerates short-range pair interactions by partitioning the domain into a regular grid of cubic/square cells of side length cell_size. Each particle is assigned to a cell, particles are sorted by cell hash, and interactions are evaluated only against particles in the same or neighboring cells given by neighbor_mask. The cell list is implicit because we never store per-cell particle lists explicitly; instead, we exploit the sorted hashes and fixed max_occupancy to probe neighbors in-place.

The operation of this collider can be understood as the following nested loop:

for particle in particles:
    for hash in stencil(particle):
        for i in range(max_occupancy):
            ...

The loop is flattened and everything happens in parallel. We can do this because the stencil size and max_occupancy are static values. This means that the cost of the calculations done with this cell list is:

\[O(N \cdot \text{neighbor\_mask\_size} \cdot \text{max\_occupancy})\]

Plus the cost of hashing \(O(N)\) and sorting \(O(N \log N)\). As we sort every time step, the particles remain mostly sorted, bringing the practical sorting cost closer to \(O(N)\). Because max_occupancy is fixed, this cell list is well suited to some types of systems and poorly suited to others. To understand the difference, let’s analyze the cost components:

  • Stencil size:

    The stencil size depends on the ratio between the cell size (\(L\)) and the radius of the largest particle (\(r_{max}\)).

    \[\text{neighbor\_mask\_size} = \left( 2\left\lceil \frac{2r_{max}}{L} \right\rceil + 1 \right)^{dim}\]
  • Max occupancy:

    The maximum number of particles that can occupy a cell depends on the cell volume, the density (represented as local volume fraction \(\phi\)), and the volume of the smallest particle (\(V_{min}\)). To prevent missing contacts during local density fluctuations, we estimate the expected occupancy (\(\lambda\)) and add a safety margin of 3 standard deviations:

    \[\text{max\_occupancy} = \left\lceil \lambda + 3\sqrt{\lambda} \right\rceil \quad \text{where} \quad \lambda = \phi \frac{L^{dim}}{V_{min}}\]

    Here, \(\phi\) is the local volume fraction, representing the ratio of the volume actually occupied by particles to the total volume of the cell. For the densest packings of equal spheres, \(\phi \approx 0.74\), but in systems with highly overlapping internal spheres (like rigid clumps), \(\phi\) can be much larger.

Putting this together and expressing the cell size in units of the maximum radius \(L^\prime = L/r_{max}\), we find the total theoretical cost to be:

\[\text{cost} \approx N \left( 2\left\lceil \frac{2}{L^\prime} \right\rceil + 1 \right)^{dim} \left\lceil \phi \frac{L^{dim}}{V_{min}} + 3\sqrt{\phi \frac{L^{dim}}{V_{min}}} \right\rceil\]

Then, if we define polydispersity as the ratio between the largest and smallest particle radii, \(\alpha = r_{max}/r_{min}\), we can express \(\lambda\) solely in terms of these dimensionless parameters. Knowing that the volume of the smallest particle is \(V_{min} = k_v r_{min}^{dim}\) (where \(k_v\) is the geometric volume factor, such as \(4\pi/3\) in 3D or \(\pi\) in 2D), we can write \(\lambda\) as:

\[\lambda = \frac{\phi}{k_v} (\alpha L^\prime)^{dim}\]

Substituting this back, we find the final cost function:

\[\text{cost} \approx N \left( 2\left\lceil \frac{2}{L^\prime} \right\rceil + 1 \right)^{dim} \left\lceil \frac{\phi}{k_v} (\alpha L^\prime)^{dim} + 3\sqrt{\frac{\phi}{k_v}} (\alpha L^\prime)^{dim/2} \right\rceil\]
  • The Polydispersity Penalty:

    Because we must safely bound the maximum possible number of particles in a cell, we base \(\lambda\) on the smallest particle. As shown in the equation above, the required array padding grows dramatically as \(O(\alpha^{dim})\), causing performance to degrade in highly polydisperse systems.

This collider is ideal for systems of spheres with low polydispersity and no significant overlaps. In those cases, it can be even faster than the dynamic cell list. However, it is not recommended for systems with clumps (where internal overlaps cause extreme local \(\phi\)) or high polydispersity, as both drastically inflate the required max_occupancy padding.
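The flattened loop can be sketched in plain JAX for a periodic box. This is an illustrative reconstruction, not jaxdem's code: count_neighbors_static and cell_hash are hypothetical names, and only neighbor counts are computed. Each stencil cell is probed at the static positions start + arange(max_occupancy), and entries whose hash does not match the target cell are masked out, so every shape is known at trace time and no while-loop is needed.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np


def cell_hash(cell, n_cells):
    """Row-major linear index of an integer cell coordinate, wrapped periodically."""
    cell = jnp.mod(cell, n_cells)
    strides = jnp.concatenate([jnp.cumprod(n_cells[::-1])[::-1][1:], jnp.ones(1, n_cells.dtype)])
    return jnp.sum(cell * strides, axis=-1)


def count_neighbors_static(pos, box, cutoff, max_occupancy, search_range=1):
    """Count neighbors within `cutoff`, probing a fixed max_occupancy slots per cell."""
    n, dim = pos.shape
    n_cells_np = np.maximum(np.floor(box * search_range / cutoff), 1).astype(np.int32)
    assert np.all(n_cells_np >= 2 * search_range + 1), "need 2*search_range+1 cells per axis"
    n_cells, box = jnp.asarray(n_cells_np), jnp.asarray(box)
    cell_len = box / n_cells

    hashes = cell_hash(jnp.floor(pos / cell_len).astype(jnp.int32), n_cells)
    order = jnp.argsort(hashes)
    sorted_hash, sorted_pos = hashes[order], pos[order]
    offsets = jnp.asarray(
        list(itertools.product(range(-search_range, search_range + 1), repeat=dim)),
        dtype=jnp.int32,
    )
    probe = jnp.arange(max_occupancy)  # static probe width

    def count_one(i):
        p = sorted_pos[i]
        cell = jnp.floor(p / cell_len).astype(jnp.int32)
        h = cell_hash(cell + offsets, n_cells)                 # (M,)
        start = jnp.searchsorted(sorted_hash, h, side="left")  # (M,)
        j = start[:, None] + probe[None, :]                    # (M, max_occupancy)
        j_safe = jnp.minimum(j, n - 1)
        # Keep only slots that really belong to the target cell (and are not i).
        valid = (j < n) & (sorted_hash[j_safe] == h[:, None]) & (j_safe != i)
        d = sorted_pos[j_safe] - p
        d = d - box * jnp.round(d / box)                       # minimum image
        return jnp.sum(valid & (jnp.sum(d * d, axis=-1) <= cutoff**2))

    counts_sorted = jax.vmap(count_one)(jnp.arange(n))
    return jnp.zeros_like(counts_sorted).at[order].set(counts_sorted)
```

With max_occupancy large enough the counts are exact; if a cell holds more particles than max_occupancy, the extra slots are never probed and contacts are silently missed.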

Complexity

  • Time: \(O(N)\) - \(O(N \log N)\) from sorting, plus \(O(N \cdot M \cdot K)\) for neighbor probing (M = neighbor_mask_size, K = max_occupancy). The state is close to sorted every frame.

  • Memory: \(O(N)\).

cell_size

Linear size of a grid cell (scalar).

Type:

jax.Array

max_occupancy

Maximum number of particles assumed to occupy a single cell. The algorithm probes exactly max_occupancy entries starting from the first particle in a neighbor cell. This should be set high enough that real cells rarely exceed it; otherwise, contacts and energy will be undercounted.

Type:

int

Notes

  • max_occupancy is an upper bound on particles per cell. If a cell contains more than this many particles, some interactions might be missed, so choose cell_size carefully and make sure that, at your expected local density, the number of particles per cell stays below max_occupancy.

neighbor_mask: Array

Integer offsets defining the neighbor stencil.

Shape is (M, dim), where each row is a displacement in cell coordinates. For search_range=1 in 2D this is the 3×3 Moore neighborhood (M=9); in 3D this is the 3×3×3 neighborhood (M=27).
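A stencil with this layout can be generated with itertools.product. make_neighbor_mask is a hypothetical helper, and the row order it produces is an assumption; only the set of offsets matters for the search. In general M = (2 · search_range + 1)^dim.

```python
import itertools

import jax.numpy as jnp


def make_neighbor_mask(search_range, dim):
    """All integer offsets in [-search_range, search_range]^dim, shape (M, dim)."""
    r = range(-search_range, search_range + 1)
    return jnp.asarray(list(itertools.product(r, repeat=dim)), dtype=jnp.int32)


print(make_neighbor_mask(1, 2).shape)  # 2D Moore neighborhood
```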

classmethod Create(state: State, cell_size: ArrayLike | None = None, search_range: ArrayLike | None = None, box_size: ArrayLike | None = None, max_occupancy: int | None = None) → Self

Creates a StaticCellList collider with robust defaults.

Defaults are chosen to prevent missing contacts while keeping the neighbor stencil and assumed cell occupancy as small as possible given the available information from state.

The optimal cell size parameter, denoted here as the dimensionless ratio \(L^\prime = L/r_{max}\), depends heavily on the system’s volume fraction (\(\phi\)) and polydispersity (\(\alpha = r_{max}/r_{min}\)):

  • Standard density ( \(\phi \le 1\) ): The optimal cell size is typically \(L^\prime = 2\) (the diameter of the largest particle). This minimizes the search stencil to just the immediate neighboring cells.

  • Extreme local density ( \(\phi \gg 1\) ): For systems with heavy internal overlaps (like rigid clumps), the optimal cell size shrinks. Values like \(L^\prime = 0.5\) or \(L^\prime = 0.25\) often yield better performance by reducing the massive array padding penalty, even at the cost of a larger search stencil.

  • High polydispersity ( \(\alpha \gg 1\) ): High polydispersity severely degrades performance regardless of density, because the fixed occupancy arrays must always be padded to accommodate the volume of the smallest particles.

By default, if cell_size or max_occupancy are not provided, this method infers optimal safe values based on the radius distribution in the reference state, assuming a maximum local volume fraction of \(\phi = 1\). If your system contains clumps with internal overlaps where local \(\phi > 1\), you must override these defaults.

Parameters:
  • state (State) – Reference state used to determine spatial dimension and default parameters.

  • cell_size (float, optional) – Cell edge length. If None, defaults to \(2 r_{max}\) for systems with low polydispersity (\(\alpha < 2.5\)), or \(0.5 r_{max}\) for highly polydisperse systems to mitigate the \(O(\alpha^{dim})\) array padding cost.

  • search_range (int, optional) – Neighbor range in cell units. If None, the smallest safe value is computed such that \(\text{search\_range} \cdot L \geq \text{cutoff}\), where the cutoff is the largest possible contact distance, \(2 r_{max}\).

  • box_size (jax.Array, optional) – Size of the periodic box used to ensure there are at least 2 * search_range + 1 cells per axis. If None, these checks are skipped.

  • max_occupancy (int, optional) – Assumed maximum particles per cell. If None, estimated using the statistical model \(\lambda + 3\sqrt{\lambda}\), assuming a worst-case local volume fraction of \(\phi = 1\) and the volume of the smallest particle.

Returns:

Configured collider instance.

Return type:

StaticCellList
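The statistical default for max_occupancy can be reproduced from the class-level formula \(\lambda = \phi L^{dim} / V_{min}\). This sketch assumes spheres (or disks in 2D) and \(\phi = 1\); default_max_occupancy is a hypothetical name, and the library may round or clamp differently.

```python
import math


def default_max_occupancy(cell_size, r_min, dim=3, phi=1.0):
    """ceil(lam + 3 * sqrt(lam)) with lam = phi * cell_size**dim / V_min."""
    k_v = math.pi ** (dim / 2) / math.gamma(dim / 2 + 1)  # pi in 2D, 4*pi/3 in 3D
    v_min = k_v * r_min**dim
    lam = phi * cell_size**dim / v_min
    return math.ceil(lam + 3.0 * math.sqrt(lam))


print(default_max_occupancy(cell_size=1.0, r_min=0.5))   # monodisperse, L = 2 r_max
print(default_max_occupancy(cell_size=2.0, r_min=0.25))  # alpha = 4, L = 2 r_max
print(default_max_occupancy(cell_size=0.5, r_min=0.25))  # alpha = 4, L = 0.5 r_max
```

For monodisperse spheres with L = 2 r_max this gives 7. With alpha = 4 in 3D, keeping L = 2 r_max would require 156 slots per cell, whereas the 0.5 r_max default brings it back to 7 at the price of a wider stencil.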

static compute_force(state: State, system: System) → tuple[State, System]

Computes the total force acting on each particle using an implicit cell list (\(O(N \log N)\)). This method sums the force contributions from all particle pairs (i, j) as computed by the system.force_model and updates the particle forces.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object with computed forces and the unmodified System object.

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the potential energy of each particle using an implicit cell list (\(O(N \log N)\)). This method sums the energy contributions from all particle pairs (i, j) as computed by the system.force_model.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

Array of shape (N,) containing the potential energy of each particle.

Return type:

jax.Array

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Computes the list of neighbors for each particle. The shape of the list is (N, max_neighbors). If a particle has fewer neighbors than max_neighbors, the list is padded with -1. The indices of the list correspond to the indices of the returned sorted state.

Note that, due to the nature of the cell list, no neighbors farther than cell_size * (1 + search_range) can be found (search_range is the number of neighboring cells probed along each axis). If cutoff is greater than this value, some neighbors within the cutoff may be missing from the list. Neighbors may also be missing if a cell contains more particles than max_occupancy.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Search radius.

  • max_neighbors (int) – Maximum number of neighbors to store per particle.

Returns:

The sorted state, the system, the neighbor list, and a boolean flag for overflow.

Return type:

tuple[State, System, jax.Array, jax.Array]

static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]

Build a cross-neighbor list between two sets of positions using an implicit cell list.

For each point in pos_a, finds all neighbors from pos_b within the given cutoff distance. The pos_b array is hashed and sorted into cells, and the neighbor stencil of each query point in pos_a is used to probe the sorted pos_b hashes with a fixed-size unrolled loop.

No neighbors further than cell_size * (1 + search_range) can be found due to the nature of the cell list. If a cell contains more particles than max_occupancy, some neighbors may be missed.

Parameters:
  • pos_a (jax.Array) – Query positions, shape (N_A, dim).

  • pos_b (jax.Array) – Database positions, shape (N_B, dim).

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Search radius.

  • max_neighbors (int) – Maximum number of neighbors to store per query point.

Returns:

A tuple containing:

  • neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.

  • overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.

Return type:

Tuple[jax.Array, jax.Array]

jaxdem.colliders.valid_interaction_mask(clump_i: Array, clump_j: Array, bond_i: Array, bond_j: Array, interact_same_bond_id: Array) → Array

Pair mask shared by all colliders.

Interactions are always disabled for particles in the same clump. Interactions for particles with equal bond_id are controlled by interact_same_bond_id.
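The two rules combine into a single boolean expression. The sketch below restates that logic, assuming integer clump and bond IDs and a boolean (or 0/1) interact_same_bond_id; pair_mask_sketch is a hypothetical stand-in, not the library function itself.

```python
import jax.numpy as jnp


def pair_mask_sketch(clump_i, clump_j, bond_i, bond_j, interact_same_bond_id):
    """Same-clump pairs never interact; same-bond pairs interact only if allowed."""
    different_clump = clump_i != clump_j
    bond_ok = (bond_i != bond_j) | jnp.asarray(interact_same_bond_id, dtype=bool)
    return different_clump & bond_ok


ci, cj = jnp.array([0, 0, 1]), jnp.array([0, 1, 2])
bi, bj = jnp.array([5, 5, 7]), jnp.array([5, 5, 8])
print(pair_mask_sketch(ci, cj, bi, bj, False))
print(pair_mask_sketch(ci, cj, bi, bj, True))
```

The first pair shares a clump and is always masked; the second shares only a bond ID and is enabled only when interact_same_bond_id is true.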

Modules

cell_list

Cell List \(O(N \log N)\) collider implementation.

naive

Naive \(O(N^2)\) collider implementation.

neighbor_list

Neighbor List Collider implementation.

sweep_and_prune

Sweep and prune \(O(N \log N)\) collider implementation.