Source code for jaxdem.forces.law_combiner
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Composite force model that sums multiple force laws."""
from __future__ import annotations
import jax
from dataclasses import dataclass, field
from typing import Tuple
from . import ForceModel
[docs]
@ForceModel.register("lawcombiner")
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class LawCombiner(ForceModel):
"""Sum a tuple of elementary force laws."""
required_material_properties: Tuple[str, ...] = field(
default=(), metadata={"static": True}
)
laws: Tuple["ForceModel", ...] = field(default=(), metadata={"static": True})
def __post_init__(self) -> None:
object.__setattr__(
self,
"required_material_properties",
tuple(
sorted({p for lw in self.laws for p in lw.required_material_properties})
),
)
[docs]
@staticmethod
@jax.jit
def force(i, j, state, system):
return jax.tree.reduce(
lambda a, b: a + b,
tuple(law.force(i, j, state, system) for law in system.force_model.laws),
)
[docs]
@staticmethod
@jax.jit
def energy(i, j, state, system):
e = 0.0
for law in system.force_model.laws:
e = e + law.energy(i, j, state, system)
return e
__all__ = ["LawCombiner"]