Source code for jaxdem.forceRouter

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Interface for combining force laws and for defining the species forces matrix.
"""

import jax
import jax.numpy as jnp

from dataclasses import dataclass, field
from typing import Tuple

from .forces import ForceModel


@ForceModel.register("lawcombiner")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class LawCombiner(ForceModel):
    """
    Force model whose output is the sum of several elementary force laws.

    ``required_material_properties`` is recomputed in ``__post_init__`` as the
    sorted union of the properties required by every entry of ``laws``; any
    value passed for it explicitly is overwritten.
    """

    required_material_properties: Tuple[str, ...] = field(
        default=(), metadata={"static": True}
    )
    laws: Tuple["ForceModel", ...] = field(default=(), metadata={"static": True})

    def __post_init__(self):
        needed = set()
        for sub_law in self.laws:
            needed.update(sub_law.required_material_properties)
        # The dataclass is not frozen, but object.__setattr__ is kept so the
        # assignment also works if it ever becomes frozen.
        object.__setattr__(
            self, "required_material_properties", tuple(sorted(needed))
        )

    # TODO: replace the Python loops below with tree_map + reduce.
    @staticmethod
    @jax.jit
    def force(i, j, state, system):
        """
        Sum ``law.force(i, j, state, system)`` over the combined laws.

        Being a static method, this reads the laws from
        ``system.force_model.laws`` rather than from an instance, i.e. it
        expects this combiner to be the system's force model.
        NOTE(review): confirm callers never invoke it with a different
        ``system.force_model`` (e.g. when nested inside another model).

        Returns ``jnp.zeros_like(state.pos[0])`` when there are no laws.
        """
        acc = jnp.zeros_like(state.pos[0])
        for sub_law in system.force_model.laws:
            acc = acc + sub_law.force(i, j, state, system)
        return acc

    @staticmethod
    @jax.jit
    def energy(i, j, state, system):
        """
        Sum ``law.energy(i, j, state, system)`` over the combined laws.

        Reads the laws from ``system.force_model.laws`` (see ``force``).
        Returns ``0.0`` when there are no laws.
        """
        acc = 0.0
        for sub_law in system.force_model.laws:
            acc = acc + sub_law.energy(i, j, state, system)
        return acc
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class ForceRouter(ForceModel):
    """
    Static (S×S) table that maps species pairs to a ForceModel.

    ``table[a][b]`` is the law applied between a particle of species ``a`` and
    a particle of species ``b``. ``required_material_properties`` is recomputed
    in ``__post_init__`` as the sorted union over every law in the table.
    """

    table: Tuple[Tuple["ForceModel", ...], ...] = field(default=())
    required_material_properties: Tuple[str, ...] = field(
        default=(), metadata={"static": True}
    )

    def __post_init__(self):
        req = {
            p
            for row in self.table
            for law in row
            for p in law.required_material_properties
        }
        object.__setattr__(self, "required_material_properties", tuple(sorted(req)))

    @staticmethod
    def from_dict(S: int, mapping: dict[Tuple[int, int], ForceModel]):
        """
        Build a symmetric router from ``{(a, b): law}``.

        Each entry is written to both ``table[a][b]`` and ``table[b][a]``; if
        both ``(a, b)`` and ``(b, a)`` are given, the later one in iteration
        order wins. Unlisted pairs get an empty ``LawCombiner`` (zero force,
        zero energy).

        Raises:
            IndexError: if a species id is outside ``[0, S)``. Negative ids are
                rejected explicitly because Python list indexing would
                otherwise silently wrap them to the end of the table.
        """
        empty = LawCombiner()  # zero-force default
        m = [[empty for _ in range(S)] for _ in range(S)]
        for (a, b), law in mapping.items():
            if not (0 <= a < S and 0 <= b < S):
                raise IndexError(
                    f"species pair ({a}, {b}) out of range for S={S}"
                )
            m[a][b] = m[b][a] = law
        return ForceRouter(table=tuple(tuple(r) for r in m))

    @staticmethod
    def _law_fn(law, method):
        """
        Return a callable ``(i, j, state, system) -> value`` for ``law.<method>``.

        ``LawCombiner.force``/``energy`` read ``system.force_model.laws``;
        inside a router ``system.force_model`` is the router itself, which has
        no ``laws``. Combiners are therefore expanded here (recursively) and
        their sub-laws are called directly.
        """
        if not isinstance(law, LawCombiner):
            return getattr(law, method)

        parts = tuple(ForceRouter._law_fn(sub, method) for sub in law.laws)

        def combined(i, j, state, system):
            # Zero with the dtype of the positions so every lax.switch branch
            # returns the same type (assumes laws compute in pos dtype).
            if method == "force":
                out = jnp.zeros_like(state.pos[0])
            else:
                out = jnp.zeros((), dtype=state.pos.dtype)
            for fn in parts:
                out = out + fn(i, j, state, system)
            return out

        return combined

    @staticmethod
    def _route(method, i, j, state, system):
        """
        Evaluate ``method`` of the law assigned to the species pair of (i, j).

        Under ``jax.jit`` the species ids are traced, so the Python table
        cannot be indexed with ``int(...)``. Every table entry instead becomes
        a ``jax.lax.switch`` branch, selected by the row-major flattened index
        ``si * S + sj`` (assumes a square table). ``lax.switch`` clamps
        out-of-range indices.
        """
        table = system.force_model.table
        n_species = len(table)
        branches = [ForceRouter._law_fn(law, method) for row in table for law in row]
        index = state.species_id[i] * n_species + state.species_id[j]
        return jax.lax.switch(index, branches, i, j, state, system)

    @staticmethod
    @jax.jit
    def force(i, j, state, system):
        """Evaluate ``force`` of the law routed to the species pair of (i, j)."""
        return ForceRouter._route("force", i, j, state, system)

    @staticmethod
    @jax.jit
    def energy(i, j, state, system):
        """Evaluate ``energy`` of the law routed to the species pair of (i, j)."""
        return ForceRouter._route("energy", i, j, state, system)