# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Naive :math:`O(N^2)` collider implementation."""
from __future__ import annotations
import jax
import jax.numpy as jnp
from dataclasses import dataclass
from typing import Tuple, TYPE_CHECKING, cast
from functools import partial
from . import Collider
from ..utils.linalg import cross
if TYPE_CHECKING: # pragma: no cover
from ..state import State
from ..system import System
[docs]
@Collider.register("naive")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class NaiveSimulator(Collider):
    r"""
    Collider that evaluates every particle pair directly, i.e. an
    :math:`O(N^2)` all-pairs sweep for both forces and potential energies.

    Notes
    -----
    Because the cost grows as :math:`O(N^2)`, `NaiveSimulator` is meant for
    simulations with a relatively small number of particles; larger systems
    should use a collider based on spatial partitioning. For small systems
    (:math:`<1k-5k` spheres depending on the GPU) this collider should
    nevertheless be the fastest option.
    """

    @staticmethod
    @jax.jit
    @partial(jax.named_call, name="NaiveSimulator.compute_potential_energy")
    def compute_potential_energy(state: "State", system: "System") -> jax.Array:
        r"""
        Evaluate the per-particle potential energy with an :math:`O(N^2)`
        all-pairs sweep.

        For every particle ``i`` the pair energy returned by
        ``system.force_model.energy`` is evaluated against every particle
        ``j``; entries whose ``state.ID`` values coincide are zeroed, and
        half of the remaining sum is attributed to particle ``i``.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        jax.Array
            One entry per particle holding that particle's share of the
            potential energy.

        Note
        -----
        - This method is wrapped in a plain ``jax.jit``; no argument
          donation is configured on it.
        """
        particle_idx = jax.lax.iota(dtype=int, size=state.N)

        def energy_of(i: jax.Array, st: "State", sim: "System") -> jax.Array:
            partners = jax.lax.iota(dtype=int, size=st.N)
            # Vectorise the pair energy over the partner index only.
            pair_energy = jax.vmap(
                sim.force_model.energy, in_axes=(None, 0, None, None)
            )
            # Pairs sharing an ID (this includes i == j) contribute nothing;
            # presumably these are members of the same body -- confirm.
            distinct_id = st.ID[i] != st.ID[partners]
            contributions = pair_energy(i, partners, st, sim) * distinct_id
            # Every (i, j) pair is visited from both rows, hence the 1/2.
            return 0.5 * contributions.sum(axis=0)

        all_rows = jax.vmap(energy_of, in_axes=(0, None, None))
        return all_rows(particle_idx, state, system)

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"), inline=True)
    @partial(jax.named_call, name="NaiveSimulator.compute_force")
    def compute_force(state: "State", system: "System") -> Tuple["State", "System"]:
        r"""
        Accumulate pairwise forces and torques with an :math:`O(N^2)`
        all-pairs sweep.

        For every particle ``i`` the contributions returned by
        ``system.force_model.force`` against every particle ``j`` are summed
        (pairs with equal ``state.ID`` are zeroed). The torque additionally
        receives the moment ``cross(pos_p_i, force)`` of each pair force
        about the rotated ``state.pos_p`` offset. The per-particle totals are
        then summed over all particles that share an ID, and that total is
        added to ``state.force`` / ``state.torque`` of each of them.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        Tuple[State, System]
            The ``State`` with the computed contributions added to its
            ``force`` and ``torque`` fields, and the ``System`` object, which
            this method does not modify.

        Note
        -----
        - This method donates state and system
        """
        particle_idx = jax.lax.iota(dtype=int, size=state.N)
        # Offsets rotated by the particle orientation (original note: to lab).
        arm = state.q.rotate(state.q, state.pos_p)

        def row_totals(
            i: jax.Array,
            partners: jax.Array,
            arm_i: jax.Array,
            st: "State",
            sim: "System",
        ) -> Tuple[jax.Array, jax.Array]:
            # Vectorise the pair interaction over the partner index only.
            pair_force = jax.vmap(
                sim.force_model.force, in_axes=(None, 0, None, None)
            )
            f_ij, t_ij = pair_force(i, partners, st, sim)
            # Trailing axis lets the mask broadcast over vector components.
            distinct_id = (st.ID[i] != st.ID[partners])[..., None]
            f_ij = f_ij * distinct_id
            # Masked pair torque plus the moment of the masked pair force.
            t_ij = t_ij * distinct_id + cross(arm_i, f_ij)
            return f_ij.sum(axis=0), t_ij.sum(axis=0)

        all_rows = jax.vmap(row_totals, in_axes=(0, None, 0, None, None))
        force_rows, torque_rows = all_rows(
            particle_idx, particle_idx, arm, state, system
        )

        ids = state.ID
        count = state.N

        def shared_total(per_particle: jax.Array) -> jax.Array:
            # Sum over particles with the same ID, then hand that total back
            # to every particle carrying the ID.
            summed = jax.ops.segment_sum(per_particle, ids, num_segments=count)
            return summed[ids]

        torque_total = shared_total(torque_rows)
        force_total = shared_total(force_rows)

        # Added on top of whatever force/torque the state already holds.
        state.force = state.force + force_total
        state.torque = state.torque + torque_total
        return state, system
__all__ = ["NaiveSimulator"]