Source code for jaxdem.colliders.neighbor_list

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project - https://github.com/cdelv/JaxDEM
"""Neighbor List Collider implementation."""

from __future__ import annotations

import jax
import jax.numpy as jnp

from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, cast
from functools import partial

try:
    from typing import Self
except ImportError:
    from typing_extensions import Self

from . import Collider, DynamicCellList

if TYPE_CHECKING:
    from ..state import State
    from ..system import System


@Collider.register("NeighborList")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class NeighborList(Collider):
    r"""
    Verlet Neighbor List collider.

    This collider caches a list of neighbors for every particle. It only
    rebuilds the list when some particle has moved more than half the
    'skin' distance since the last build.

    **Performance Note:**
    A non-zero ``skin`` is required for this collider to be efficient. If
    ``skin == 0``, the list is rebuilt whenever any particle has moved at all,
    i.e. effectively every step. Note that :meth:`Create` takes ``skin`` as a
    *fraction of the cutoff* (absolute skin = ``skin * cutoff``); the ``skin``
    attribute below stores the resulting absolute distance.

    Attributes
    ----------
    cell_list : DynamicCellList
        The underlying spatial partitioner used to build the list.
    neighbor_list : jax.Array
        Shape (N, max_neighbors). Contains the IDs of neighboring particles,
        padded with -1.
    old_pos : jax.Array
        Shape (N, dim). Positions of particles at the last build time.
    n_build_times : int
        Counter for how many times the list has been rebuilt. A value of 0
        forces a rebuild on the next :meth:`compute_force` call.
    cutoff : jax.Array
        Scalar. The interaction radius (force cutoff).
    skin : jax.Array
        Scalar. Absolute buffer distance. The list is built with
        ``radius = cutoff + skin`` and rebuilt when
        ``max_displacement > skin / 2``.
    overflow : jax.Array
        Boolean flag indicating whether the neighbor list overflowed during
        the last build (as reported by the cell list).
    max_neighbors : int
        Static buffer size for the neighbor list.
    """

    cell_list: DynamicCellList
    neighbor_list: jax.Array
    old_pos: jax.Array
    n_build_times: int
    cutoff: jax.Array
    skin: jax.Array
    overflow: jax.Array
    max_neighbors: int = field(metadata={"static": True})

    @classmethod
    def Create(
        cls,
        state: State,
        cutoff: float,
        box_size: Optional[jax.Array] = None,
        skin: float = 0.05,
        max_neighbors: Optional[int] = None,
        number_density: float = 1.0,
        safety_factor: float = 1.2,
        cell_size: Optional[float] = None,
    ) -> Self:
        r"""
        Creates a NeighborList collider.

        Parameters
        ----------
        state : State
            Initial simulation state.
        cutoff : float
            The physical interaction cutoff radius.
        box_size : jax.Array, optional
            The size of the periodic box, if used. Forwarded to
            ``DynamicCellList.Create``.
        skin : float, default 0.05
            Buffer distance expressed as a **fraction of** ``cutoff``. The
            absolute skin stored on the collider is ``skin * cutoff``.
            **Must be > 0.0 for performance.**
        max_neighbors : int, optional
            Maximum neighbors to store per particle. If not provided, it is
            estimated from ``number_density`` (with a minimum of 10).
        number_density : float, default 1.0
            Number density of the state. Used only to estimate
            ``max_neighbors`` when it is not provided.
        safety_factor : float, default 1.2
            Multiplies the list radius when estimating ``max_neighbors`` from
            ``number_density``. Empirically obtained.
        cell_size : float, optional
            Override for the underlying cell list size. Defaults to the list
            radius ``cutoff * (1 + skin)``.

        Returns
        -------
        NeighborList
            A collider with ``n_build_times == 0``, so the first
            :meth:`compute_force` call builds the list.
        """
        # ``skin`` is given relative to the cutoff; convert it to a distance.
        skin *= cutoff
        list_cutoff = cutoff + skin
        if cell_size is None:
            cell_size = list_cutoff

        if max_neighbors is None:
            # Estimate max_neighbors as (area of a disk in 2D / volume of a
            # sphere otherwise) of radius safety_factor * list_cutoff, times
            # the number density. The prefactor assumes dim is 2 or 3.
            prefactor = 1 if state.dim == 2 else 4 / 3
            nl_volume = jnp.pi * (safety_factor * list_cutoff) ** state.dim * prefactor
            max_neighbors = max(int(nl_volume * number_density), 10)

        # Inner spatial partitioner used to (re)build the list.
        cl = DynamicCellList.Create(state, cell_size=cell_size, box_size=box_size)

        # Placeholder buffers: an all -1 (empty) list and the current
        # positions. n_build_times=0 forces an immediate rebuild in the first
        # compute_force call.
        return cls(
            cell_list=cl,
            neighbor_list=jnp.full((state.N, max_neighbors), -1, dtype=int),
            old_pos=state.pos,
            n_build_times=0,
            cutoff=jnp.asarray(cutoff, dtype=float),
            skin=jnp.asarray(skin, dtype=float),
            overflow=jnp.asarray(False, dtype=bool),
            max_neighbors=int(max_neighbors),
        )

    @staticmethod
    @jax.jit(static_argnames=("max_neighbors",))
    @partial(jax.named_call, name="NeighborList.create_neighbor_list")
    def create_neighbor_list(
        state: State,
        system: System,
        cutoff: float,
        max_neighbors: int,
    ) -> Tuple[State, System, jax.Array, jax.Array]:
        """
        Return the **cached** neighbor list from this collider.

        Notes
        -----
        - This method does **not** rebuild the neighbor list. It simply
          returns the last cached ``neighbor_list`` and ``overflow`` stored in
          ``system.collider``.
        - The returned neighbor indices refer to the collider's internal
          particle ordering at the time the cache was last updated (i.e.,
          after the most recent rebuild inside :meth:`compute_force`).
        - The ``cutoff`` and ``max_neighbors`` arguments are accepted for API
          compatibility but are currently ignored; the cache was built using
          this collider's configured ``cutoff + skin`` and ``max_neighbors``.
        """
        collider = cast(NeighborList, system.collider)
        return state, system, collider.neighbor_list, collider.overflow

    @staticmethod
    @partial(jax.named_call, name="NeighborList._rebuild")
    def _rebuild(
        collider: NeighborList, state: State, system: System
    ) -> Tuple[State, jax.Array, jax.Array, int, jax.Array]:
        """
        Rebuild the neighbor list with radius ``cutoff + skin``.

        Note: this assigns the inner cell list to ``system.collider`` on the
        ``system`` object passed in. :meth:`compute_force` only calls it inside
        ``jax.lax.cond``, where ``system`` is the branch's own rebuilt copy of
        the operand. Take care if calling it from anywhere else.

        Returns
        -------
        Tuple[State, jax.Array, jax.Array, int, jax.Array]
            ``(sorted_state, neighbor_list, sorted_state.pos,
            n_build_times + 1, overflow)``. The neighbor indices point into
            ``sorted_state``.
        """
        list_cutoff = collider.cutoff + collider.skin

        # Create a view of the system using the inner collider.
        system.collider = collider.cell_list

        # Get neighbors using the spatial partitioner. Returns the sorted
        # state, ..., and neighbor indices pointing into the sorted state.
        (
            sorted_state,
            _,
            sorted_nl_indices,
            overflow_flag,
        ) = collider.cell_list.create_neighbor_list(
            state, system, list_cutoff, collider.max_neighbors
        )

        # Return the sorted state to avoid having to un-sort the neighbor list.
        return (
            sorted_state,
            sorted_nl_indices,
            sorted_state.pos,
            collider.n_build_times + 1,
            overflow_flag,
        )

    @staticmethod
    @jax.jit(donate_argnames=("state", "system"))
    @partial(jax.named_call, name="NeighborList.compute_force")
    def compute_force(state: State, system: System) -> Tuple[State, System]:
        """
        Compute per-particle forces and torques from the cached neighbor list,
        rebuilding the list first when needed.

        The list is rebuilt when ``n_build_times == 0`` (first call) or when
        some particle has moved more than ``skin / 2`` since the last build.
        A rebuild returns the state in the cell list's sorted order, so the
        returned ``state`` may be a permutation of the input.

        Parameters
        ----------
        state : State
            Current simulation state (donated to the jitted call).
        system : System
            System holding this collider in ``system.collider`` (donated to
            the jitted call).

        Returns
        -------
        Tuple[State, System]
            The state with ``force`` and ``torque`` assigned, and the system
            with the collider cache (list, old positions, build counter,
            overflow flag) updated.
        """
        iota = jax.lax.iota(dtype=int, size=state.N)  # should this be cached?
        collider = cast(NeighborList, system.collider)

        # 1. Check displacement since the last build and decide whether to
        # rebuild. This is intentionally a plain (non-periodic) difference.
        # NOTE(review): if positions are wrapped at periodic boundaries, a
        # wrapping particle looks like a large jump and just triggers an extra
        # rebuild -- confirm against Domain behavior.
        disp = state.pos - collider.old_pos
        max_disp_sq = jnp.max(jnp.sum(disp * disp, axis=-1))
        trigger_dist_sq = collider.skin**2 / 4  # (skin / 2) ** 2

        # Rebuild if displacement is large OR this is the first step.
        should_rebuild = jnp.logical_or(
            max_disp_sq > trigger_dist_sq, collider.n_build_times == 0
        )

        def rebuild_branch(
            operands: Tuple[State, System, NeighborList],
        ) -> Tuple[State, jax.Array, jax.Array, int, jax.Array]:
            s, sys, col = operands
            return col._rebuild(col, s, sys)

        def no_rebuild_branch(
            operands: Tuple[State, System, NeighborList],
        ) -> Tuple[State, jax.Array, jax.Array, int, jax.Array]:
            s, _, col = operands
            return (
                s,
                col.neighbor_list,
                col.old_pos,
                col.n_build_times,
                col.overflow,
            )

        state, nl, old_pos, n_build, overflow = jax.lax.cond(
            should_rebuild,
            rebuild_branch,
            no_rebuild_branch,
            (state, system, collider),
        )

        # 2. Compute forces.
        # Pre-calculate contact points in the global frame for torque.
        pos_p_global = state.q.rotate(state.q, state.pos_p)
        pos = state.pos_c + pos_p_global

        def per_particle_force(
            i: jax.Array, pos_pi: jax.Array, neighbors: jax.Array
        ) -> Tuple[jax.Array, jax.Array]:
            # i: ID of the current particle
            # pos_pi: vector from COM to surface for particle i
            # neighbors: array of neighbor IDs (-1 = padding)

            def per_neighbor_force(j_id: jax.Array) -> Tuple[jax.Array, jax.Array]:
                # Padding entries (-1) are evaluated against particle 0 and
                # then zeroed. jnp.where (instead of multiplying by the mask)
                # keeps a NaN/inf from such a dummy evaluation out of the sum.
                valid = j_id != -1
                safe_j = jnp.maximum(j_id, 0)
                f, t = system.force_model.force(i, safe_j, pos, state, system)
                return jnp.where(valid, f, 0.0), jnp.where(valid, t, 0.0)

            forces, torques = jax.vmap(per_neighbor_force)(neighbors)
            f_sum = jnp.sum(forces, axis=0)
            # Add contact torque: T_total = Sum(T_ij) + (r_i x F_total)
            t_sum = jnp.sum(torques, axis=0) + jnp.cross(pos_pi, f_sum)
            return f_sum, t_sum

        # Vmap over particle IDs [0, 1, ..., N-1].
        state.force, state.torque = jax.vmap(per_particle_force)(
            iota, pos_p_global, nl
        )

        # Update collider cache.
        system.collider = replace(
            collider,
            neighbor_list=nl,
            old_pos=old_pos,
            n_build_times=n_build,
            overflow=overflow,
        )
        return state, system

    @staticmethod
    @jax.jit
    @partial(jax.named_call, name="NeighborList.compute_potential_energy")
    def compute_potential_energy(state: State, system: System) -> jax.Array:
        """
        Compute the per-particle potential energy from the cached neighbor list.

        The list is **not** rebuilt here. It must already match ``state``,
        including its particle ordering (e.g. ``state`` as returned by
        :meth:`compute_force`). Each pair energy is multiplied by 0.5 because
        every pair appears in both particles' lists.

        Returns
        -------
        jax.Array
            One energy value per particle (vmapped over ``0..N-1``).
        """
        iota = jax.lax.iota(dtype=int, size=state.N)
        collider = cast(NeighborList, system.collider)

        def per_particle_energy(i: jax.Array) -> jax.Array:
            neighbors = collider.neighbor_list[i]

            def per_neighbor_energy(j_id: jax.Array) -> jax.Array:
                # Padding (-1) is evaluated against particle 0 and zeroed with
                # jnp.where so a NaN/inf from that dummy pair cannot leak in.
                valid = j_id != -1
                safe_j = jnp.maximum(j_id, 0)
                e = system.force_model.energy(i, safe_j, state.pos, state, system)
                return jnp.where(valid, e, 0.0)

            # Halve to undo double counting in the neighbor list.
            return 0.5 * jnp.sum(jax.vmap(per_neighbor_energy)(neighbors))

        return jax.vmap(per_particle_energy)(iota)