Source code for jaxdem.colliders
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Collision-detection interfaces and implementations."""
from __future__ import annotations
import jax
import jax.numpy as jnp
from abc import ABC
from dataclasses import dataclass
from typing import Tuple, TYPE_CHECKING
from functools import partial
from ..factory import Factory
if TYPE_CHECKING: # pragma: no cover
from ..state import State
from ..system import System
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class Collider(Factory, ABC):
    r"""
    The base interface for defining how contact detection and force computations are performed in a simulation.

    Concrete subclasses of `Collider` implement the specific algorithms for calculating the interactions.
    The base class itself only provides no-op defaults: :meth:`compute_force` returns its
    inputs unchanged and :meth:`compute_potential_energy` returns zeros.

    Notes
    -----
    Self-interaction (i.e., calling the force/energy computation for `i=j`) is allowed,
    and the underlying `force_model` is responsible for correctly handling or
    ignoring this case.

    Example
    -------
    To define a custom collider, inherit from `Collider`, register it and override its methods:

    >>> @Collider.register("CustomCollider")
    ... @jax.tree_util.register_dataclass
    ... @dataclass(slots=True)
    ... class CustomCollider(Collider):
    ...     ...

    Then, instantiate it:

    >>> jaxdem.Collider.create("CustomCollider", **custom_collider_kw)
    """

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"), inline=True)
    def compute_force(state: "State", system: "System") -> Tuple["State", "System"]:
        """
        Compute the total force and torque acting on each particle in the simulation.

        Subclasses should override this method to calculate inter-particle forces and
        torques from the current `state` and `system` configuration, and then update the
        `force` and `torque` attributes of the `state` object with the resulting totals.
        The base implementation is a no-op that returns `state` and `system` unchanged.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        Tuple[State, System]
            A tuple containing the updated `State` object (with computed forces)
            and the `System` object.

        Notes
        -----
        `state` and `system` are marked for donation (``donate_argnames``), so
        their buffers may be invalidated by the call. Always continue with the
        returned objects rather than the ones that were passed in.
        """
        # Default behavior: no interactions are computed, so nothing is modified.
        return state, system

    @staticmethod
    @jax.jit
    def compute_potential_energy(state: "State", system: "System") -> jax.Array:
        """
        Compute the potential energy of each particle in the simulation.

        Subclasses should override this method to calculate, for every particle, the sum
        of all potential energies it takes part in, based on the current `state` and
        `system` configuration. The base implementation returns zeros.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        jax.Array
            Per-particle potential energies (not a scalar). The base implementation
            returns ``jnp.zeros_like(state.mass)``, i.e. an array with the same shape
            and dtype as ``state.mass``. Sum it to obtain the total potential energy.

        Example
        -------
        >>> potential_energy = system.collider.compute_potential_energy(state, system)
        >>> print(potential_energy.shape)  # same shape as state.mass
        >>> print(f"Total potential energy: {potential_energy.sum():.4f}")
        """
        # Default behavior: no interactions, hence zero energy for every particle.
        return jnp.zeros_like(state.mass)
# Register the base class itself under the empty name "". Its methods are
# no-op defaults (compute_force returns its inputs unchanged and
# compute_potential_energy returns zeros), so this presumably makes
# ``Collider.create("")`` yield a collider that computes no interactions
# (Factory.create semantics are not visible here — confirm).
Collider.register("")(Collider)
from .naive import NaiveSimulator
from .cell_list import CellList, DynamicCellList, MaterializedCellList, NeighborList
# from .sweep_and_prune import SweeAPrune
# Public names of ``jaxdem.colliders``: the base ``Collider`` interface plus the
# collider classes imported from the ``naive`` and ``cell_list`` submodules.
__all__ = [
    "Collider", "NaiveSimulator",
    "CellList", "DynamicCellList", "MaterializedCellList", "NeighborList",
]