Source code for jaxdem.forces.law_combiner

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Composite force model that sums multiple force laws."""

from __future__ import annotations

import jax
import jax.numpy as jnp

from dataclasses import dataclass, field
from typing import Tuple
from functools import partial

from . import ForceModel


[docs] @ForceModel.register("lawcombiner") @jax.tree_util.register_dataclass @dataclass(slots=True) class LawCombiner(ForceModel): """Sum a tuple of elementary force laws.""" laws: Tuple["ForceModel", ...] = field(default=(), metadata={"static": True}) def __post_init__(self) -> None: object.__setattr__( self, "required_material_properties", tuple( sorted({p for lw in self.laws for p in lw.required_material_properties}) ), )
[docs] @staticmethod @jax.jit @partial(jax.named_call, name="LawCombiner.force") def force(i, j, state, system): force = jnp.zeros_like(state.pos[i]) torque = jnp.zeros_like(state.angVel[i]) for law in system.force_model.laws: f, t = law.force(i, j, state, system) force = force + f torque = torque + t return force, torque
[docs] @staticmethod @jax.jit @partial(jax.named_call, name="LawCombiner.energy") def energy(i, j, state, system): e = 0.0 for law in system.force_model.laws: e = e + law.energy(i, j, state, system) return e
__all__ = ["LawCombiner"]