jaxdem.colliders.neighbor_list

Neighbor List Collider implementation.

Classes

NeighborList(cell_list, neighbor_list, ...)

Implementation of a Verlet neighbor list collider.

class jaxdem.colliders.neighbor_list.NeighborList(cell_list: Collider, neighbor_list: Array, old_pos: Array, n_build_times: Array, cutoff: Array, skin: Array, overflow: Array, max_neighbors: int)

Bases: Collider

Implementation of a Verlet neighbor list collider.

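This collider caches a list of neighbors for every particle, which greatly reduces the number of distance calculations needed at each time step. The list is built with radius cutoff + skin and is rebuilt only when the maximum displacement of any particle since the last build exceeds half of the skin distance. Two particles that each move less than skin / 2 cannot close their separation by more than skin, so no pair missing from the list can have come within the cutoff.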
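The rebuild criterion can be sketched as a single reduction over per-particle displacements. This is a NumPy illustration, not JaxDEM's code: `needs_rebuild` is a hypothetical helper, and it ignores periodic wrapping of positions. Inside a jitted JAX step, the resulting boolean would typically select between "rebuild" and "reuse" with `jax.lax.cond`.

```python
import numpy as np


def needs_rebuild(pos, old_pos, skin):
    """Hypothetical helper: True if any particle moved more than skin / 2 since the last build."""
    # Magnitude of each particle's displacement since the last rebuild
    # (no periodic minimum-image correction in this sketch).
    disp = np.linalg.norm(pos - old_pos, axis=-1)
    return bool(np.max(disp) > 0.5 * skin)


old_pos = np.array([[0.0, 0.0], [1.0, 0.0]])
pos = old_pos + np.array([[0.02, 0.0], [0.0, 0.0]])
print(needs_rebuild(pos, old_pos, skin=0.05))  # False: 0.02 < 0.025
print(needs_rebuild(pos, old_pos, skin=0.03))  # True: 0.02 > 0.015
```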

Complexity

  • Time: \(O(N)\) per step between rebuilds, i.e. \(O(N \cdot \text{max\_neighbors})\) pair evaluations over the cached list. Rebuild cost depends on the underlying cell_list collider (typically \(O(N \log N)\)).

Notes

Provide a non-zero skin (e.g., 0.1 * radius) for this collider to be efficient. With skin = 0 the rebuild threshold is zero, so the list is rebuilt at every step, which defeats the purpose of caching it.

cell_list: Collider

The underlying collider used to build the list via create_neighbor_list.

neighbor_list: Array

Shape (N, max_neighbors). Contains the indices of neighboring particles (in the particle ordering of the most recent rebuild), padded with -1.

old_pos: Array

Shape (N, dim). Positions of particles at the last build time.

n_build_times: Array

Counter for how many times the list has been rebuilt.

cutoff: Array

The interaction radius (force cutoff).

skin: Array

Buffer distance. The list is built with radius = cutoff + skin and rebuilt when max_displacement > skin / 2.

overflow: Array

Boolean flag indicating whether the neighbor list overflowed during the last build, i.e. whether some particle had more than max_neighbors neighbors.

max_neighbors: int

Static buffer size: the number of neighbor slots per particle (the second dimension of neighbor_list).
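The attributes above describe a fixed-size buffer with -1 padding and an overflow flag. The sketch below shows one way such a buffer can be filled. It is an illustration only: it uses a brute-force \(O(N^2)\) distance search as a stand-in for the CellList build, ignores periodic boundaries, and `build_padded_neighbor_list` is a hypothetical name. The fixed width matters in JAX because array shapes must be known at trace time, which is why max_neighbors is a static int rather than an Array.

```python
import numpy as np


def build_padded_neighbor_list(pos, radius, max_neighbors):
    """Hypothetical brute-force build of an (N, max_neighbors) list padded with -1.

    Returns (neighbor_list, overflow); overflow is True if any particle had more
    than max_neighbors neighbors within `radius` (extra neighbors are truncated).
    """
    n = pos.shape[0]
    diff = pos[:, None, :] - pos[None, :, :]            # (N, N, dim) pair vectors
    dist = np.linalg.norm(diff, axis=-1)                 # (N, N) pair distances
    within = (dist < radius) & ~np.eye(n, dtype=bool)    # exclude self-pairs
    overflow = bool(np.any(within.sum(axis=1) > max_neighbors))
    nl = np.full((n, max_neighbors), -1, dtype=np.int64)
    for i in range(n):
        ids = np.flatnonzero(within[i])[:max_neighbors]  # keep at most max_neighbors
        nl[i, : ids.size] = ids
    return nl, overflow


pos = np.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]])
# Build radius is cutoff + skin, here 1.0 + 0.05.
nl, overflow = build_padded_neighbor_list(pos, radius=1.0 + 0.05, max_neighbors=2)
# nl rows: [1, -1], [0, -1], [-1, -1]; overflow is False
```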

classmethod Create(state: State, cutoff: float, skin: float = 0.05, max_neighbors: int | None = None, number_density: float = 1.0, safety_factor: float = 1.2, secondary_collider_type: str = 'CellList', secondary_collider_kw: dict[str, Any] | None = None) → Self

Creates a NeighborList collider.

Parameters:
  • state (State) – The initial simulation state used to determine system dimensions and particle count.

  • cutoff (float) – The physical interaction cutoff radius.

  • skin (float, default 0.05) – The buffer distance added to the cutoff for the neighbor list. Should be > 0.0; with skin = 0 the list is rebuilt at every step.

  • max_neighbors (int, optional) – Maximum number of neighbors to store per particle. If not provided, it is estimated from the number_density.

  • number_density (float, default 1.0) – Number density of the system used to estimate max_neighbors if not explicitly provided.

  • safety_factor (float, default 1.2) – Multiplier applied to the estimated number of neighbors to account for fluctuations in local density.

  • secondary_collider_type (str, default "CellList") – Registered collider type used internally to build the neighbor lists.

  • secondary_collider_kw (dict[str, Any], optional) – Keyword arguments passed to the constructor of the internal collider. If None, cell_size is set to cutoff + skin.

Returns:

A configured NeighborList collider instance.

Return type:

NeighborList
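The text states that max_neighbors is estimated from number_density and scaled by safety_factor, but it does not give the formula. The sketch below shows one plausible estimate, which is an assumption and not necessarily what Create computes: the expected number of particles inside a d-dimensional ball of radius cutoff + skin, multiplied by safety_factor and rounded up. `estimate_max_neighbors` is a hypothetical helper.

```python
import math


def estimate_max_neighbors(dim, cutoff, skin, number_density=1.0, safety_factor=1.2):
    """Hypothetical estimate: expected particle count in a ball of radius cutoff + skin."""
    r = cutoff + skin
    # Volume of a d-dimensional ball: pi^(d/2) / Gamma(d/2 + 1) * r^d
    ball_volume = math.pi ** (dim / 2) / math.gamma(dim / 2 + 1) * r ** dim
    expected = number_density * ball_volume
    # Inflate for local density fluctuations and keep at least one slot.
    return max(1, math.ceil(safety_factor * expected))


print(estimate_max_neighbors(dim=3, cutoff=1.0, skin=0.05))  # 6
print(estimate_max_neighbors(dim=2, cutoff=1.0, skin=0.05))  # 5
```

If the estimate is too small, the overflow flag is the signal to rebuild the collider with a larger buffer.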

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

Returns the cached neighbor list from this collider.

This method does not rebuild the neighbor list. It returns the current cached neighbor_list and overflow flag stored in the collider.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

  • cutoff (float) – Ignored; provided for API compatibility.

  • max_neighbors (int) – Ignored; provided for API compatibility.

Returns:

A tuple containing:

  • state: The simulation state.

  • system: The simulation system.

  • neighbor_list: The cached neighbor list of shape (N, max_neighbors).

  • overflow: Boolean flag indicating if the list overflowed during the last build.

Return type:

Tuple[State, System, jax.Array, jax.Array]

Notes

  • The returned neighbor indices refer to the internal particle ordering established during the most recent rebuild inside compute_force.
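Because the indices refer to the ordering of the most recent rebuild, code that tracks particles by a persistent identifier has to translate them. The sketch below assumes a hypothetical per-slot array `particle_ids` holding each particle's persistent ID in the current ordering. It also guards against a common pitfall: indexing with -1 silently wraps around to the last element in both NumPy and JAX, so padded slots must be masked before and after the gather.

```python
import numpy as np


def to_persistent_ids(neighbor_list, particle_ids):
    """Map padded neighbor indices (current ordering) to persistent IDs, keeping -1 padding."""
    valid = neighbor_list >= 0
    # Gather with a safe index so -1 never wraps to the last particle.
    mapped = particle_ids[np.where(valid, neighbor_list, 0)]
    return np.where(valid, mapped, -1)


particle_ids = np.array([7, 3, 5])  # hypothetical persistent ID stored in each slot
nl = np.array([[1, -1], [0, 2], [-1, -1]])
print(to_persistent_ids(nl, particle_ids).tolist())  # [[3, -1], [7, 5], [-1, -1]]
```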

static compute_force(state: State, system: System) → tuple[State, System]

Computes total forces acting on each particle, rebuilding the neighbor list if necessary.

This method checks whether any particle has moved far enough since the last build to trigger a rebuild (displacement > skin / 2). If so, it rebuilds the neighbor list with the internal cell_list collider. It then sums the force contributions over the cached list.
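The force pass over a cached list can be sketched as a masked gather. Two masks are needed: padded slots (-1) and stored pairs that lie in the skin shell beyond the cutoff, since the list was built with cutoff + skin. The harmonic repulsion used here and the name `forces_from_neighbor_list` are hypothetical stand-ins for system.force_model; this is a NumPy sketch, not JaxDEM's implementation.

```python
import numpy as np


def forces_from_neighbor_list(pos, neighbor_list, cutoff, k=1.0):
    """Sum pair forces over a padded neighbor list (hypothetical harmonic repulsion)."""
    valid = neighbor_list >= 0
    j = np.where(valid, neighbor_list, 0)          # safe gather index for padded slots
    rij = pos[:, None, :] - pos[j]                 # (N, M, dim): vector from j to i
    r = np.linalg.norm(rij, axis=-1)               # (N, M) pair distances
    # Drop padding, skin-shell pairs beyond the cutoff, and coincident points.
    active = valid & (r < cutoff) & (r > 0.0)
    safe_r = np.where(active, r, 1.0)              # avoid division by zero
    mag = np.where(active, k * (cutoff - r), 0.0)  # repulsive force magnitude
    f_pair = (mag / safe_r)[..., None] * rij       # pushes i away from j
    return f_pair.sum(axis=1)                      # (N, dim) total force per particle


pos = np.array([[0.0, 0.0], [0.5, 0.0]])
nl = np.array([[1, -1], [0, -1]])
F = forces_from_neighbor_list(pos, nl, cutoff=1.0)
# F is approximately [[-0.5, 0.0], [0.5, 0.0]]; the pair forces cancel.
```

With a full list, each particle accumulates its own side of every pair, so no scatter-add is needed; the cost is evaluating each pair twice.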

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object with computed forces and the updated System object (with refreshed collider cache).

Return type:

Tuple[State, System]

static compute_potential_energy(state: State, system: System) → jax.Array

Computes the potential energy associated with each particle using the cached neighbor list.

This method iterates over the cached neighbors for each particle and sums the potential energy contributions computed by the system.force_model.
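An energy pass over a cached list must avoid double counting. The sketch below assumes a full (symmetric) list, in which j appears in row i and i appears in row j, so each particle is credited with half of every pair energy. It returns both the per-particle values and their sum. The harmonic potential and the helper name `potential_energy_from_neighbor_list` are hypothetical; the actual contributions come from system.force_model.

```python
import numpy as np


def potential_energy_from_neighbor_list(pos, neighbor_list, cutoff, k=1.0):
    """Per-particle and total energy for a hypothetical harmonic repulsion."""
    valid = neighbor_list >= 0
    j = np.where(valid, neighbor_list, 0)          # safe gather index for padded slots
    r = np.linalg.norm(pos[:, None, :] - pos[j], axis=-1)
    active = valid & (r < cutoff)                  # drop padding and skin-shell pairs
    u_pair = np.where(active, 0.5 * k * (cutoff - r) ** 2, 0.0)
    # Each pair is stored twice in a full list, so give each particle half.
    per_particle = 0.5 * u_pair.sum(axis=1)
    return per_particle, per_particle.sum()


pos = np.array([[0.0, 0.0], [0.5, 0.0]])
nl = np.array([[1, -1], [0, -1]])
per_particle, total = potential_energy_from_neighbor_list(pos, nl, cutoff=1.0)
# per_particle is approximately [0.0625, 0.0625]; total is approximately 0.125
```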

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

Scalar containing the total potential energy of the system.

Return type:

jax.Array

static create_cross_neighbor_list(pos_a: jax.Array, pos_b: jax.Array, system: System, cutoff: float, max_neighbors: int) → tuple[jax.Array, jax.Array]

Build a cross-neighbor list between two sets of positions.

Delegates to the internal cell_list collider’s create_cross_neighbor_list method.
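The semantics of a cross-neighbor list can be sketched by brute force: for each query point in pos_a, collect indices into pos_b within the cutoff, pad with -1, and flag overflow. Unlike the self list, there is no self-exclusion. The optional `box` argument is a hypothetical stand-in for the domain displacement provided by system: it applies the minimum-image convention for a rectangular periodic box. This does not reproduce the cell-list algorithm the method delegates to.

```python
import numpy as np


def cross_neighbor_list(pos_a, pos_b, cutoff, max_neighbors, box=None):
    """Brute-force cross list: indices into pos_b within cutoff of each pos_a point."""
    diff = pos_a[:, None, :] - pos_b[None, :, :]   # (N_A, N_B, dim)
    if box is not None:
        # Minimum-image convention for a periodic box with side lengths `box`.
        diff = diff - box * np.round(diff / box)
    within = np.linalg.norm(diff, axis=-1) < cutoff
    overflow = bool(np.any(within.sum(axis=1) > max_neighbors))
    nl = np.full((pos_a.shape[0], max_neighbors), -1, dtype=np.int64)
    for i in range(pos_a.shape[0]):
        ids = np.flatnonzero(within[i])[:max_neighbors]
        nl[i, : ids.size] = ids
    return nl, overflow


pos_a = np.array([[0.1, 0.0]])
pos_b = np.array([[0.0, 0.0], [9.9, 0.0], [5.0, 0.0]])
box = np.array([10.0, 10.0])
nl, overflow = cross_neighbor_list(pos_a, pos_b, cutoff=0.5, max_neighbors=2, box=box)
# nl is [[0, 1]]: pos_b[1] is 0.2 away through the periodic boundary; overflow is False
```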

Parameters:
  • pos_a (jax.Array) – Query positions, shape (N_A, dim).

  • pos_b (jax.Array) – Database positions, shape (N_B, dim).

  • system (System) – The configuration of the simulation; its domain is used to compute displacements between positions (e.g., across periodic boundaries).

  • cutoff (float) – Search radius.

  • max_neighbors (int) – Maximum number of neighbors to store per query point.

Returns:

A tuple containing:

  • neighbor_list: Array of shape (N_A, max_neighbors) containing indices into pos_b, padded with -1.

  • overflow: Boolean flag indicating if any query point exceeded max_neighbors neighbors within the cutoff.

Return type:

Tuple[jax.Array, jax.Array]