Source code for jaxdem.collider
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Interface for defining colliders. Colliders perform contact detection and compute forces.
"""
import jax
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Tuple, TYPE_CHECKING
from .factory import Factory
if TYPE_CHECKING:
from .state import State
from .system import System
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class Collider(Factory["Collider"], ABC):
    r"""
    The base interface for defining how contact detection and force computations are performed in a simulation.

    Concrete subclasses of `Collider` implement the specific algorithms for calculating the interactions.

    Notes
    -----
    All abstract methods in `Collider` (and their implementations in subclasses)
    must be compatible with JAX transformations (`jax.jit`, `jax.vmap`, etc.).
    They are expected to work seamlessly in both 2D and 3D simulations.

    Self-interaction (i.e., calling the force/energy computation for `i=j`) is allowed,
    and the underlying `force_model` is responsible for correctly handling or
    ignoring this case.

    Example
    -------
    To define a custom collider, inherit from `Collider`, register it and implement its abstract methods:

    >>> @Collider.register("CustomCollider")
    ... @jax.tree_util.register_dataclass
    ... @dataclass(slots=True)
    ... class CustomCollider(Collider):
    ...     ...

    Then, instantiate it:

    >>> jaxdem.Collider.create("CustomCollider", **custom_collider_kw)
    """

    @staticmethod
    @abstractmethod
    @jax.jit
    def compute_force(state: "State", system: "System") -> Tuple["State", "System"]:
        """
        Abstract method to compute the total force acting on each particle in the simulation.

        Implementations should calculate inter-particle forces based on the current
        `state` and `system` configuration, then update the `accel` attribute of the
        `state` object with the resulting total acceleration for each particle.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        Tuple[State, System]
            A tuple containing the updated `State` object (with computed accelerations)
            and the `System` object.

        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by subclasses.

        Example
        -------
        This method is typically called internally by the `System`'s step function:

        >>> state, system = system.collider.compute_force(state, system)
        """
        # `NotImplemented` is a comparison sentinel, not an exception class;
        # `raise NotImplemented` would surface as a TypeError. Raise the
        # exception documented above instead.
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    @jax.jit
    def compute_potential_energy(state: "State", system: "System") -> jax.Array:
        """
        Abstract method to compute the potential energy of the system, resolved per particle.

        Implementations should calculate, for every particle, the potential energy
        arising from its interactions, based on the current `state` and `system`
        configuration.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        jax.Array
            A JAX array with one potential-energy entry per particle (leading axis
            of length ``state.N``), not a scalar. Reduce it (e.g. ``.sum()``) to
            obtain a single number; whether that sum equals the total potential
            energy depends on how the implementation and the force model
            apportion each pair's energy between the two particles.

        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by subclasses.

        Example
        -------
        >>> energy_per_particle = system.collider.compute_potential_energy(state, system)
        >>> print(f"Potential energy summed over particles: {float(energy_per_particle.sum()):.4f}")
        """
        # See `compute_force`: raise the real exception, not the sentinel.
        raise NotImplementedError
def _sum_over_partners(pair_fn, n: int) -> jax.Array:
    """
    Evaluate ``pair_fn(i, j)`` for every ordered pair ``i, j`` in ``range(n)`` and sum over ``j``.

    Both indices run over all ``n`` particles, so the diagonal ``i == j`` is
    included. The result has a leading axis of length ``n`` whose entry ``i``
    is ``sum_j pair_fn(i, j)``.

    Parameters
    ----------
    pair_fn : callable
        Function of two integer index arrays ``(i, j)`` returning a JAX array.
        It must be traceable by `jax.vmap`.
    n : int
        Number of particles. It must be a static Python int, because it sets the
        size of `jax.lax.iota`.

    Returns
    -------
    jax.Array
        Per-``i`` sums of ``pair_fn`` over all partners ``j``.
    """
    indices = jax.lax.iota(dtype=int, size=n)

    def row_sum(i):
        # Inner vmap evaluates particle i against every partner j; reduce over j.
        return jax.vmap(lambda j: pair_fn(i, j))(indices).sum(axis=0)

    return jax.vmap(row_sum)(indices)


@Collider.register("naive")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class NaiveSimulator(Collider):
    r"""
    Implementation that computes forces and potential energies using a naive :math:`O(N^2)` all-pairs interaction loop.

    Notes
    -----
    Due to its :math:`O(N^2)` complexity, `NaiveSimulator` is suitable for simulations
    with a relatively small number of particles. For larger systems, a more
    efficient spatial partitioning collider should be used. However, this collider should be the fastest
    option for small systems (:math:`<10^3` spheres).

    Both methods evaluate every ordered pair (i, j), including ``i == j``, as a
    nested `jax.vmap`. The pair results are therefore expressed as an
    :math:`N \times N` batch before the reduction over ``j``, so memory use may
    also grow as :math:`O(N^2)`.

    Example
    -------
    >>> collider = jaxdem.Collider.create("naive")
    >>> state, system = collider.compute_force(state, system)
    """

    @staticmethod
    @jax.jit
    def compute_force(state: "State", system: "System") -> Tuple["State", "System"]:
        r"""
        Computes the total force on each particle using a naive :math:`O(N^2)` all-pairs loop.

        For every particle ``i``, the forces ``system.force_model.force(i, j, state, system)``
        are summed over all ``j`` (including ``j == i``). The sum is divided by the
        particle's mass and *added* to ``state.accel``.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        Tuple[State, System]
            A tuple containing the updated `State` object with computed accelerations
            and the `System` object.
        """
        net_force = _sum_over_partners(
            lambda i, j: system.force_model.force(i, j, state, system), state.N
        )
        # `+=` accumulates: whatever acceleration is already stored in `state.accel` is kept.
        # `[:, None]` broadcasts one mass per particle across the force components
        # (assumes `state.mass` has shape (N,) — TODO confirm).
        state.accel += net_force / state.mass[:, None]
        return state, system

    @staticmethod
    @jax.jit
    def compute_potential_energy(state: "State", system: "System") -> jax.Array:
        r"""
        Computes the per-particle potential energy using a naive :math:`O(N^2)` all-pairs loop.

        For every particle ``i``, the energies ``system.force_model.energy(i, j, state, system)``
        are summed over all ``j`` (including ``j == i``).

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        jax.Array
            One entry per particle (leading axis of length ``state.N``), not a scalar.
            Each unordered pair (i, j) contributes to both entry ``i`` and entry ``j``.
            ``result.sum()`` therefore counts every pair twice, unless
            ``force_model.energy`` already returns half the pair energy
            (NOTE(review): confirm the force model's convention).
        """
        return _sum_over_partners(
            lambda i, j: system.force_model.energy(i, j, state, system), state.N
        )