# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Sweep and prune :math:`O(N log N)` collider implementation."""
from __future__ import annotations
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl
from jax import tree_util
from dataclasses import dataclass, field
from typing import Tuple, TYPE_CHECKING, cast
from functools import partial
from . import Collider
if TYPE_CHECKING: # pragma: no cover
from ..state import State
from ..system import System
[docs]
@jax.jit
@partial(jax.named_call, name="pad_to_power2")
def pad_to_power2(x):
    """
    Zero-pad the trailing axis of a 2-D array up to the next even length.

    Used to pad 3D simulation data to 4D (Pallas kernel limitations).
    Despite the name, the width is rounded up to the next *even* number
    (1 -> 2, 3 -> 4, 5 -> 6), which coincides with the next power of two
    only for small widths such as 3.

    Parameters
    ----------
    x
        Any array. Only arrays with ``x.ndim == 2`` are padded.

    Returns
    -------
    The input unchanged when it is not 2-D or its trailing axis is already
    even; otherwise a copy with one extra trailing column filled with zeros.
    """
    if x.ndim != 2:
        return x
    dim = x.shape[1]
    # An odd width needs exactly one extra column; an even width needs none.
    extra = dim % 2
    return jnp.pad(x, ((0, 0), (0, extra)), constant_values=0.0)
from dataclasses import replace
from dataclasses import replace
# --- 1. Define and register RefProxy ---
# RefProxy must be defined at module level, outside of and before the kernel.
[docs]
class RefProxy:
    """Wraps a Pallas Ref to behave like a JAX Array with auto-loading.

    Attribute access and arithmetic are forwarded to the value obtained by
    loading the wrapped ref with ``pl.load``; indexing loads only the
    requested row or element.
    """

    def __init__(self, ref):
        self.ref = ref

    # Properties to mimic array attributes (read from the ref, no load needed)
    @property
    def shape(self):
        return self.ref.shape

    @property
    def dtype(self):
        return self.ref.dtype

    @property
    def ndim(self):
        return len(self.ref.shape)

    # Auto-load on indexing
    def __getitem__(self, idx):
        # Heuristic: if 2D and the index is not a tuple, load the whole row.
        # Otherwise load exactly what the index describes.
        if self.ndim == 2 and not isinstance(idx, tuple):
            return pl.load(self.ref, (idx, slice(None)))
        return pl.load(self.ref, (idx,) if not isinstance(idx, tuple) else idx)

    # Helper to load the full array for math ops
    def _load_all(self):
        return pl.load(self.ref, (slice(None),) * self.ndim)

    # Math operators (RefProxy + x, x / RefProxy, etc.) operate on the
    # fully loaded value.
    def __add__(self, o):
        return self._load_all() + o

    def __radd__(self, o):
        return o + self._load_all()

    def __sub__(self, o):
        return self._load_all() - o

    def __rsub__(self, o):
        return o - self._load_all()

    def __mul__(self, o):
        return self._load_all() * o

    def __rmul__(self, o):
        return o * self._load_all()

    def __truediv__(self, o):
        return self._load_all() / o

    def __rtruediv__(self, o):
        return o / self._load_all()

    def __neg__(self):
        return -self._load_all()

    # Catch-all for other attributes (like .T)
    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``ref`` itself is
        # missing (instance created via __new__, e.g. by copy/pickle),
        # forwarding would call _load_all(), which reads self.ref and lands
        # back here: unbounded recursion. Fail with AttributeError instead.
        if name == "ref":
            raise AttributeError(name)
        return getattr(self._load_all(), name)
# --- CRITICAL: JAX PyTree Registration ---
def _ref_proxy_flatten(proxy):
    """Flatten a RefProxy: the wrapped ref is the only child, no aux data."""
    children = (proxy.ref,)
    return children, None


def _ref_proxy_unflatten(aux, children):
    """Rebuild a RefProxy around the first (and only) flattened child."""
    wrapped = children[0]
    return RefProxy(wrapped)


jax.tree_util.register_pytree_node(
    RefProxy,
    _ref_proxy_flatten,
    _ref_proxy_unflatten,
)
# --- 2. The Universal Kernel ---
[docs]
@partial(jax.profiler.annotate_function, name="sap_kernel_full")
def sap_kernel_full(
    state_ref, system_ref, aabb_ref, m_ref, M_ref, HASH_ref, forces_ref
):
    """Pallas kernel body: sweep forward from row ``i`` and accumulate pair forces.

    Each program handles one row ``i``. Starting at ``j = i + 1`` it walks
    forward while ``m_ref[j] <= M_ref[i]``, tests per-axis overlap of the two
    rows using ``aabb_ref``, and on overlap adds the force to row ``i`` and
    its negative to row ``j`` of ``forces_ref``.

    Parameters
    ----------
    state_ref, system_ref
        Pytrees whose array leaves are refs; ``state_ref.pos`` is read here.
    aabb_ref
        Per-row extents, compared against ``|r_ij|`` axis by axis.
    m_ref, M_ref
        Lower / upper sweep bounds, indexed by row.
    HASH_ref
        Not read anywhere in this kernel.
        NOTE(review): presumably meant for skipping same-cell pairs -- confirm.
    forces_ref
        Output ref; row ``i`` is zeroed and then accumulated into.
    """
    # Linear program index built from the 2-D launch grid.
    # NOTE(review): nothing below checks ``i`` against the number of rows;
    # confirm the launch grid never produces an out-of-range ``i``.
    i = pl.num_programs(1) * pl.program_id(0) + pl.program_id(1)
    # Wrap Inputs in Proxies
    # We use duck typing (hasattr "shape") to find Refs, as pl.Ref isn't always exposed
    state_proxy = jax.tree.map(
        lambda x: RefProxy(x) if hasattr(x, "shape") else x, state_ref
    )
    system_proxy = jax.tree.map(
        lambda x: RefProxy(x) if hasattr(x, "shape") else x, system_ref
    )
    # Data for row i is loaded once and reused for every candidate j.
    M_i = pl.load(M_ref, (i,))
    pos_i = pl.load(state_ref.pos, (i, slice(None)))
    aabb_i = pl.load(aabb_ref, (i, slice(None)))
    # Row i of the output is zeroed by its own program.
    # NOTE(review): other programs atomically add into this same row (as
    # their ``j``); confirm this store cannot land after such an add.
    pl.store(forces_ref, (i, slice(None)), jnp.zeros_like(pos_i))

    def cond(j):
        # Continue while j is a valid row and its lower bound is within
        # row i's upper bound.
        # NOTE(review): ``*`` does not short-circuit, so m_ref[j] is loaded
        # even when j == n -- confirm this is safe.
        n = state_ref.pos.shape[0]
        return (j < n) * (pl.load(m_ref, (j,)) <= M_i)

    def body(j):
        pos_j = pl.load(state_ref.pos, (j, slice(None)))
        aabb_j = pl.load(aabb_ref, (j, slice(None)))
        # Now works seamlessly with JIT-ed displacement
        r_ij = system_proxy.domain.displacement(pos_i, pos_j, system_proxy)
        # Overlap requires the extent test to pass on every column of pos.
        overlap = jnp.sum(jnp.abs(r_ij) <= (aabb_i + aabb_j)) == state_ref.pos.shape[1]

        def compute_force_wrapper(_):
            # Now works seamlessly with JIT-ed force
            # JAX flattens system_proxy -> sees Ref -> passes Ref -> rebuilds RefProxy inside
            # NOTE(review): arguments are passed as (state, system, i, j);
            # confirm this matches the force model's signature.
            return system_ref.force_model.force(state_proxy, system_proxy, i, j)[0]

        # Only evaluate the force model when the boxes overlap; otherwise zero.
        f = jax.lax.cond(
            overlap,
            compute_force_wrapper,
            lambda _: jnp.zeros_like(pos_i),
            operand=None,
        )
        # Equal and opposite contributions to the two rows.
        pl.atomic_add(forces_ref, (i, slice(None)), f)
        pl.atomic_add(forces_ref, (j, slice(None)), -f)
        return j + 1

    jax.lax.while_loop(cond, body, i + 1)
[docs]
@jax.jit
@partial(jax.profiler.annotate_function, name="compute_hash")
def compute_hash(state, proj_perp, aabb, shift):
    """Map each projected row to an integer grid-cell hash.

    Parameters
    ----------
    state
        Unused by this function.
    proj_perp
        Coordinates to be binned, one row per item.
    aabb
        Extents; the cell edge is four times their maximum.
    shift
        Offset applied before binning, in units of half a cell.

    Returns
    -------
    One integer per row of ``proj_perp``: the dot product of its cell indices
    with the running product of the grid dimensions.
    """
    cell_size = 4 * jnp.max(aabb)

    # Number of cells per axis, with one cell of slack on each side.
    lower = proj_perp.min(axis=0)
    upper = proj_perp.max(axis=0)
    cells_per_axis = jnp.ceil((upper - lower + 2 * cell_size) / cell_size).astype(int)
    grid_dims = jnp.maximum(1, cells_per_axis)

    # Stride of each axis when flattening a multi-index: 1, d0, d0*d1, ...
    leading_one = jnp.ones(1, dtype=int)
    strides = jnp.concatenate([leading_one, jnp.cumprod(grid_dims[:-1])])

    shifted = proj_perp + shift * cell_size / 2
    cell_idx = jnp.floor(shifted / cell_size).astype(int)
    return jnp.dot(cell_idx, strides)
[docs]
@jax.jit
@partial(jax.profiler.annotate_function, name="compute_virtual_shift")
def compute_virtual_shift(m, M, HASH):
    """Offset sweep bounds by a per-item multiple of the total sweep extent.

    The extent is ``M.max() - m.min()``; each item's bounds are moved by
    ``2 * HASH * extent``.

    Returns
    -------
    Tuple ``(m_shifted, M_shifted)``.
    """
    extent = M.max() - m.min()
    offset = 2 * HASH * extent
    shifted_lower = m + offset
    shifted_upper = M + offset
    return shifted_lower, shifted_upper
[docs]
@jax.jit
@partial(jax.profiler.annotate_function, name="sort")
def sort(state, iota, m, M):
    """Sort by ``m`` and reorder ``M`` and every leaf of ``state`` to match.

    ``iota`` is carried through the sort, so the returned permutation holds,
    for each sorted position, the value ``iota`` had at the source position.

    Returns
    -------
    Tuple ``(state, m, M, perm)`` in sorted order.
    """
    # Only the first operand (m) is used as the sort key.
    sorted_m, sorted_M, perm = jax.lax.sort((m, M, iota), num_keys=1)

    def reorder(leaf):
        return leaf[perm]

    sorted_state = tree_util.tree_map(reorder, state)
    return sorted_state, sorted_m, sorted_M, perm
[docs]
@jax.jit
@partial(jax.profiler.annotate_function, name="padd")
def padd(state):
    """Return ``state`` with ``pad_to_power2`` applied to every leaf."""
    padded_state = jax.tree.map(pad_to_power2, state)
    return padded_state
[docs]
@partial(jax.jit, inline=True)
@partial(jax.named_call, name="SpringForce.force")
def force(
    i: int, j: int, state: "State", system: "System"
) -> Tuple[jax.Array, jax.Array]:
    """Linear spring force between rows ``i`` and ``j``, read from refs.

    Parameters
    ----------
    i, j
        Row indices of the two interacting items.
    state
        Provides ``mat_id``, ``rad`` and ``pos`` refs (read with ``pl.load``).
    system
        Provides ``mat_table.young_eff`` (read with ``pl.load``) and
        ``domain.displacement``.

    Returns
    -------
    Tuple ``(k * s * rij, zeros)`` where ``s = max(R / r - 1, 0)`` with
    ``R = rad_i + rad_j``, and ``zeros`` has the shape of one position row.
    """
    whole_row = slice(None)

    # Per-item data is fetched from the refs one entry (or one row) at a time.
    mat_i = pl.load(state.mat_id, (i,))
    mat_j = pl.load(state.mat_id, (j,))
    radius_i = pl.load(state.rad, (i,))
    radius_j = pl.load(state.rad, (j,))
    pos_i = pl.load(state.pos, (i, whole_row))
    pos_j = pl.load(state.pos, (j, whole_row))

    # Stiffness is looked up in the pair table by the two material ids.
    stiffness = pl.load(system.mat_table.young_eff, (mat_i, mat_j))

    rij = system.domain.displacement(pos_i, pos_j, system)

    # Squared norm as an element-wise product plus sum: jnp.vecdot triggered
    # a "must be 2D" error in Triton. eps keeps the square root away from 0.
    dist = jnp.sqrt(jnp.sum(rij * rij) + jnp.finfo(pos_i.dtype).eps)

    # Relative overlap, zeroed when the pair is not in contact.
    rel_overlap = (radius_i + radius_j) / dist - 1.0
    rel_overlap = rel_overlap * (rel_overlap > 0)

    return stiffness * rel_overlap * rij, jnp.zeros_like(pos_i)
[docs]
@Collider.register("sap")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class SweeAPrune(Collider):
    """Sweep-and-prune collider registered under the key ``"sap"``.

    Particles are sorted along a sweep axis, offset by a grid-cell hash of the
    remaining axes, and pair forces are accumulated by two Pallas kernel
    launches that use two different hash grids.
    """

    @staticmethod
    @jax.jit
    @partial(jax.named_call, name="SweeAPrune.compute_potential_energy")
    def compute_potential_energy(state: "State", system: "System"):
        """Placeholder: not implemented, returns ``None``."""
        pass

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"))
    @partial(jax.named_call, name="SweeAPrune.compute_force")
    def compute_force(state: "State", system: "System") -> Tuple["State", "System"]:
        """Compute forces with two sweep-and-prune passes.

        Parameters
        ----------
        state
            Reads ``pos``, ``rad``, ``mass`` and ``force``. Donated to the
            jitted call.
        system
            Passed through to the kernels and returned. Donated as well.

        Returns
        -------
        Tuple ``(state, system)`` where ``state`` is in the sort order of the
        first pass and its ``force`` holds the sum of both passes divided by
        ``mass``.
        """
        # Per-axis half extents: the radius repeated along every axis.
        aabb = state.rad[:, None] * jnp.ones((1, state.pos.shape[1]))
        chunk_size = 16
        n, dim = state.pos.shape
        iota = jax.lax.iota(int, n)
        quantization = jnp.min(aabb) / 4.0
        m = state.pos - aabb
        M = state.pos + aabb

        # 1) Sweep direction: first coordinate axis; the remaining axes span
        #    the perpendicular space. (A PCA-based direction is not used.)
        I = jnp.eye(dim, dtype=state.pos.dtype)
        v = I[:, 0]
        v_perp = I[:, 1:]

        # 2) Project into perpendicular plane
        proj_perp = jnp.dot(m, v_perp)

        # 3) Two hash grids in the perpendicular directions, the second one
        #    offset by half a cell.
        HASH1 = compute_hash(state, proj_perp, aabb, 0.0)
        HASH2 = compute_hash(state, proj_perp, aabb, 1.0)

        # Project onto the sweep direction and quantize to integers for
        # performance.
        m = (jnp.dot(m / quantization, v)).astype(int)
        M = (jnp.dot(M / quantization, v)).astype(int)
        m1, M1 = compute_virtual_shift(m, M, HASH1)
        m2, M2 = compute_virtual_shift(m, M, HASH2)

        # Sort particles by shifted sweep coordinates
        state1, m1, M1, perm1 = sort(state, iota, m1, M1)
        state2, m2, M2, perm2 = sort(state, iota, m2, M2)

        # The kernel reads aabb at the same row index as the sorted state, so
        # aabb has to follow the same permutation as the state it goes with.
        # NOTE(review): aabb has ``dim`` columns while the padded ``pos`` may
        # have more -- confirm the kernel's overlap test for odd ``dim``.
        aabb1 = aabb[perm1]
        aabb2 = aabb[perm2]

        # NOTE(review): this grid launches more programs than there are
        # particles; confirm the kernel tolerates indices >= n.
        grid = (n // chunk_size + 1, chunk_size)

        # First SaP pass - compute all interactions in the cell
        state_padded1 = padd(state1)
        state_padded1.force = pl.pallas_call(
            sap_kernel_full,
            out_shape=state_padded1.force,
            grid=grid,
            interpret=False,
            name="First pass",
        )(state_padded1, system, aabb1, m1, M1, iota)

        # Second SaP pass - skip same hash interactions
        state_padded2 = padd(state2)
        state_padded2.force = pl.pallas_call(
            sap_kernel_full,
            out_shape=state_padded2.force,
            grid=grid,
            interpret=False,
            name="Second pass",
        )(state_padded2, system, aabb2, m2, M2, HASH1[perm2])

        # Combine forces and unpermute: invert perm2 to bring the second-pass
        # forces back to the original order, then reorder them with perm1 so
        # they line up with the first pass.
        inv_perm2 = perm2.at[perm2].set(iota)
        state_padded2.force = state_padded2.force[:, :dim][inv_perm2]  # unpadd
        state1.force = (
            state_padded1.force[:, :dim] + state_padded2.force[perm1]
        ) / state_padded1.mass[:, None]
        return state1, system