Source code for jaxdem.forces

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Force-law interfaces and concrete implementations."""

from __future__ import annotations

import jax

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Tuple

from ..factory import Factory

if TYPE_CHECKING:  # pragma: no cover
    from ..state import State
    from ..system import System


@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class ForceModel(Factory, ABC):
    """
    Abstract base class for defining inter-particle force laws and their
    corresponding potential energies.

    Concrete subclasses implement specific force and energy models, such as
    linear springs, Hertzian contacts, etc.

    Notes
    -----
    - The :meth:`force` and :meth:`energy` methods should correctly handle the
      case where `i` and `j` refer to the same particle (i.e., `i == j`).
      There is no guarantee that self-interaction calls will not occur.

    Example
    -------
    To define a custom force model, inherit from :class:`ForceModel` and
    implement its abstract methods:

    >>> @ForceModel.register("myCustomForce")
    ... @jax.tree_util.register_dataclass
    ... @dataclass(slots=True, frozen=True)
    ... class MyCustomForce(ForceModel):
    ...     ...
    """

    required_material_properties: Tuple[str, ...] = field(
        default=(), metadata={"static": True}
    )
    """
    A static tuple of strings specifying the material properties required by
    this force model.

    These properties (e.g., 'young_eff', 'restitution', ...) must be present in
    the :attr:`System.mat_table` for the model to function correctly. This is
    used for validation.
    """

    laws: Tuple["ForceModel", ...] = field(default=(), metadata={"static": True})
    """
    A static tuple of other :class:`ForceModel` instances that compose this
    force model.

    This allows for creating composite force models (e.g., a total force being
    the sum of a spring force and a damping force).
    """

    # Decorators apply bottom-up: the function is wrapped by ``jax.jit``,
    # flagged as abstract, and finally exposed as a static method.
    @staticmethod
    @abstractmethod
    @jax.jit
    def force(i: int, j: int, state: "State", system: "System") -> jax.Array:
        """
        Compute the force vector acting on particle :math:`i` due to particle
        :math:`j`.

        Parameters
        ----------
        i : int
            Index of the first particle (on which the force is acting).
        j : int
            Index of the second particle (which is exerting the force).
        state : State
            Current state of the simulation.
        system : System
            Simulation system configuration.

        Returns
        -------
        jax.Array
            Force vector acting on particle :math:`i` due to particle
            :math:`j`. Shape `(dim,)`.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override the method.
        """
        raise NotImplementedError

    # Same decorator stack as ``force``: jitted, abstract, static.
    @staticmethod
    @abstractmethod
    @jax.jit
    def energy(i: int, j: int, state: "State", system: "System") -> jax.Array:
        """
        Compute the potential energy of the interaction between particle
        :math:`i` and particle :math:`j`.

        Parameters
        ----------
        i : int
            Index of the first particle.
        j : int
            Index of the second particle.
        state : State
            Current state of the simulation.
        system : System
            Simulation system configuration.

        Returns
        -------
        jax.Array
            Scalar JAX array representing the potential energy of the
            interaction between particles :math:`i` and :math:`j`.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override the method.
        """
        raise NotImplementedError
# NOTE(review): these imports sit below the ForceModel definition, presumably
# because the submodules import ForceModel from this package and would hit a
# circular import otherwise — confirm before moving them to the top of the file.
from .law_combiner import LawCombiner
from .router import ForceRouter
from .spring import SpringForce

# Names exported by this package.
__all__ = ["ForceModel", "LawCombiner", "ForceRouter", "SpringForce"]