# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Interface for defining force laws and their corresponding potential energy.
"""
import jax
import jax.numpy as jnp
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Tuple
from .factory import Factory
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .state import State
from .system import System
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class ForceModel(Factory["ForceModel"], ABC):
    """
    Abstract base class for defining inter-particle force laws and their potential energies.

    Concrete subclasses implement specific force and energy models, such as
    linear springs, Hertzian contacts, etc.

    Notes
    -----
    - Implementations should be JIT-compilable.
    - The :meth:`force` and :meth:`energy` methods should correctly handle the
      case where `i` and `j` refer to the same particle (i.e., `i == j`).
      There is no guarantee that self-interaction calls will not occur.

    Example
    -------
    To define a custom force model, inherit from :class:`ForceModel` and implement
    its abstract methods:

    >>> @ForceModel.register("myCustomForce")
    ... @jax.tree_util.register_dataclass
    ... @dataclass(slots=True)
    ... class MyCustomForce(ForceModel):
    ...     ...
    """

    required_material_properties: Tuple[str, ...] = field(
        default=(), metadata={"static": True}
    )
    """
    A static tuple of strings specifying the material properties required by this force model.
    These properties (e.g., 'young_eff', 'restitution') must be present in the
    :attr:`System.mat_table` for the model to function correctly. This is used
    for validation.
    """

    # NOTE(review): static (pytree metadata) fields are expected by JAX to be
    # hashable. `@dataclass(slots=True)` with the default eq=True/frozen=False
    # sets `__hash__ = None`, so ForceModel instances placed in `laws` would be
    # unhashable -- confirm composite models work under jit before relying on it.
    laws: Tuple["ForceModel", ...] = field(default=(), metadata={"static": True})
    """
    A static tuple of other :class:`ForceModel` instances that compose this force model.
    This allows for creating composite force models (e.g., a total force being
    the sum of a spring force and a damping force).
    """

    @staticmethod
    @abstractmethod
    @jax.jit
    def force(i: int, j: int, state: "State", system: "System") -> jax.Array:
        """
        Compute the force vector acting on particle :math:`i` due to particle :math:`j`.

        Parameters
        ----------
        i : int
            Index of the first particle (on which the force is acting).
        j : int
            Index of the second particle (which is exerting the force).
        state : State
            Current state of the simulation.
        system : System
            Simulation system configuration.

        Returns
        -------
        jax.Array
            Force vector acting on particle :math:`i` due to particle :math:`j`. Shape `(dim,)`.

        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by subclasses.

        Example
        -------
        This method is called internally by a `Collider` when computing
        total forces:

        >>> force_on_particle_0_from_1 = system.force_model.force(0, 1, state, system)
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    @jax.jit
    def energy(i: int, j: int, state: "State", system: "System") -> jax.Array:
        """
        Compute the potential energy of the interaction between particle :math:`i` and particle :math:`j`.

        Parameters
        ----------
        i : int
            Index of the first particle.
        j : int
            Index of the second particle.
        state : State
            Current state of the simulation.
        system : System
            Simulation system configuration.

        Returns
        -------
        jax.Array
            Scalar JAX array representing the potential energy of the interaction
            between particles :math:`i` and :math:`j`.

        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by subclasses.

        Example
        -------
        This method is typically called internally by a `Collider` when computing
        the total potential energy of the system:

        >>> energy_0_1 = system.force_model.energy(0, 1, state, system)
        """
        raise NotImplementedError
def _spring_pair_geometry(
    i: int, j: int, state: "State", system: "System"
) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """
    Gather the pair quantities shared by :meth:`SpringForce.force` and :meth:`SpringForce.energy`.

    Keeping this in one place guarantees that the force and the energy are
    evaluated from exactly the same stiffness, displacement and distance.

    Returns
    -------
    k : jax.Array
        Stiffness ``system.mat_table.young_eff[mat_id[i], mat_id[j]]``.
    rij : jax.Array
        Displacement returned by
        ``system.domain.displacement(state.pos[i], state.pos[j], system)``.
        Its last axis is assumed to be the spatial dimension.
    r : jax.Array
        Regularized center distance ``sqrt(|rij|^2 + eps)``.
    sigma : jax.Array
        Sum of radii ``state.rad[i] + state.rad[j]``.
    """
    mi, mj = state.mat_id[i], state.mat_id[j]
    k = system.mat_table.young_eff[mi, mj]
    rij = system.domain.displacement(state.pos[i], state.pos[j], system)
    r2 = jnp.sum(rij * rij, axis=-1)
    # Machine epsilon keeps the square root (and divisions by r) finite when
    # the two positions coincide, e.g. for self-interaction calls.
    r = jnp.sqrt(r2 + jnp.finfo(state.pos.dtype).eps)
    sigma = state.rad[i] + state.rad[j]
    return k, rij, r, sigma


@ForceModel.register("spring")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class SpringForce(ForceModel):
    r"""
    A `ForceModel` implementation for a linear spring-like interaction between particles.

    Notes
    -----
    - The effective Young's modulus (:math:`k_{eff,\; ij}`) is retrieved from
      :attr:`System.mat_table.young_eff` based on the material IDs of the interacting
      particles and is used as the spring stiffness.
    - Both the force and the energy are zero when :math:`i = j`.
    - Machine epsilon is added to the squared distance (:math:`r^2`) before taking the
      square root to prevent division by zero or NaN issues when particles are
      perfectly co-located.

    The penetration :math:`\delta` (overlap) between two particles :math:`i` and :math:`j` is:

    .. math::
        \delta = \max\left(0, (R_i + R_j) - r\right)

    where :math:`R_i` and :math:`R_j` are the radii of particles :math:`i` and :math:`j` respectively,
    and :math:`r = ||r_{ij}||` is the distance between their centers.

    The scalar overlap :math:`s` is defined as:

    .. math::
        s = \max\left(0, \frac{R_i + R_j}{r} - 1\right) = \frac{\delta}{r}

    The force :math:`F_{ij}` acting on particle :math:`i` due to particle :math:`j` is:

    .. math::
        F_{ij} = k_{eff,\; ij} \cdot s \cdot r_{ij} = k_{eff,\; ij} \, \delta \, \frac{r_{ij}}{r}

    The potential energy :math:`E_{ij}` of the interaction is:

    .. math::
        E_{ij} = \frac{1}{2} k_{eff,\; ij} \, \delta^2

    so that :math:`-\partial E_{ij} / \partial r = k_{eff,\; ij} \, \delta = ||F_{ij}||`,
    i.e. the energy is the potential of the force above. Here :math:`k_{eff,\; ij}`
    is the effective Young's modulus for the particle pair.

    Example
    -------
    To use :class:`SpringForce` in a simulation, specify it as the `force_model`
    when creating your :class:`System`:

    >>> system = jaxdem.System.create(dim=3, force_model_type="spring", force_model_kw={})

    For this force model, the typical :class:`MaterialMatchmaker` type is ``"harmonic"``.
    """

    required_material_properties: Tuple[str, ...] = field(
        default=("young_eff",), metadata={"static": True}
    )

    @staticmethod
    @jax.jit
    def force(i: int, j: int, state: "State", system: "System") -> jax.Array:
        """
        Compute linear spring-like interaction force acting on particle :math:`i` due to particle :math:`j`.

        Returns zero when :math:`i = j`.

        Parameters
        ----------
        i : int
            Index of the first particle.
        j : int
            Index of the second particle.
        state : State
            Current state of the simulation.
        system : System
            Simulation system configuration.

        Returns
        -------
        jax.Array
            Force vector acting on particle :math:`i` due to particle :math:`j`.
        """
        k, rij, r, sigma = _spring_pair_geometry(i, j, state, system)
        s = jnp.maximum(0.0, sigma / r - 1.0)
        # Zero self-interactions explicitly instead of relying on rij == 0
        # cancelling the very large s produced by r = sqrt(eps).
        s = jnp.where(i != j, s, 0.0)
        # Trailing axis lets the per-pair scalar scale the displacement vector
        # (identical to a scalar product for a single pair).
        return (k * s)[..., None] * rij

    @staticmethod
    @jax.jit
    def energy(i: int, j: int, state: "State", system: "System") -> jax.Array:
        """
        Compute linear spring-like interaction potential energy between particle :math:`i` and particle :math:`j`.

        The energy is :math:`\\frac{1}{2} k \\delta^2` with the overlap
        :math:`\\delta = \\max(0, R_i + R_j - r)`, which is the potential of the
        force returned by :meth:`force`. Returns zero when :math:`i = j`.

        Parameters
        ----------
        i : int
            Index of the first particle.
        j : int
            Index of the second particle.
        state : State
            Current state of the simulation.
        system : System
            Simulation system configuration.

        Returns
        -------
        jax.Array
            Scalar JAX array representing the potential energy of the interaction
            between particles :math:`i` and :math:`j`.
        """
        k, _, r, sigma = _spring_pair_geometry(i, j, state, system)
        # Use the overlap length delta (not the dimensionless s = delta / r) so
        # that the energy is dimensionally consistent with the force k * delta.
        delta = jnp.maximum(0.0, sigma - r)
        # For i == j, delta would be the full R_i + R_j; a particle has no
        # interaction energy with itself.
        delta = jnp.where(i != j, delta, 0.0)
        return 0.5 * k * delta**2