# Source code for jaxdem.forces.router

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Force model router selecting laws based on species pairs."""

from __future__ import annotations

import jax

from dataclasses import dataclass, field
from typing import Tuple

from . import ForceModel
from .law_combiner import LawCombiner


@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class ForceRouter(ForceModel):
    """Static species-to-force lookup table.

    ``table[a][b]`` is the force law applied between a particle of species
    ``a`` and a particle of species ``b``. Law selection happens inside traced
    code via ``jax.lax.switch``. The species ids are never converted to Python
    ints, so ``force`` and ``energy`` work under ``jax.jit`` (and ``vmap``).
    """

    # ``S x S`` table of force laws indexed as ``table[species_i][species_j]``.
    # Not marked ``static``, so register_dataclass treats it as a pytree data field.
    table: Tuple[Tuple["ForceModel", ...], ...] = field(default=())
    # Sorted union of ``required_material_properties`` over every law in
    # ``table``. Always recomputed in ``__post_init__``; a passed-in value is
    # overwritten.
    required_material_properties: Tuple[str, ...] = field(
        default=(), metadata={"static": True}
    )

    def __post_init__(self) -> None:
        # Gather the material properties required by any routed law.
        req = {
            prop
            for row in self.table
            for law in row
            for prop in law.required_material_properties
        }
        # The dataclass is frozen, so bypass its __setattr__ to store the result.
        object.__setattr__(self, "required_material_properties", tuple(sorted(req)))

    @staticmethod
    def from_dict(S: int, mapping: dict[Tuple[int, int], ForceModel]) -> ForceRouter:
        """Build a symmetric router from a sparse ``{(a, b): law}`` mapping.

        Parameters
        ----------
        S : int
            Number of species; the resulting table is ``S x S``.
        mapping : dict[tuple[int, int], ForceModel]
            Laws for specific species pairs. Each entry is written to both
            ``table[a][b]`` and ``table[b][a]``. If both ``(a, b)`` and
            ``(b, a)`` are present, the entry iterated last wins.

        Returns
        -------
        ForceRouter
            Router whose unspecified pairs use ``LawCombiner()``.

        Raises
        ------
        IndexError
            If a species index lies outside ``[0, S)``. Negative indices are
            rejected instead of silently wrapping to the end of the table.
        """
        empty = LawCombiner()  # zero-force default
        m = [[empty for _ in range(S)] for _ in range(S)]
        for (a, b), law in mapping.items():
            for idx in (a, b):
                if not 0 <= idx < S:
                    raise IndexError(
                        f"species index {idx} out of range for S={S}"
                    )
            m[a][b] = m[b][a] = law
        return ForceRouter(table=tuple(tuple(r) for r in m))

    @staticmethod
    def _dispatch(method: str, i, j, state, system):
        """Call ``method`` (``"force"`` or ``"energy"``) of the law routed for ``(i, j)``.

        The table is flattened row-major, and ``jax.lax.switch`` selects the
        branch at ``species_id[i] * S + species_id[j]``. That index may be a
        tracer, which is why Python indexing with ``int()`` cannot be used here.
        ``lax.switch`` requires every branch to return the same pytree
        structure, shapes and dtypes. It also clamps out-of-range indices
        instead of raising.

        NOTE(review): each law receives the unmodified ``system``, so
        ``system.force_model`` is this router rather than the law itself (same
        as the original lookup). Confirm that routed laws do not look up their
        own parameters via ``system.force_model``.
        """
        table = system.force_model.table
        n_species = len(table)  # static: tuple length is part of the pytree structure
        branches = [getattr(law, method) for row in table for law in row]
        index = state.species_id[i] * n_species + state.species_id[j]
        return jax.lax.switch(index, branches, i, j, state, system)

    @staticmethod
    @jax.jit
    def force(i, j, state, system):
        """Return the selected law's ``force(i, j, state, system)``.

        The law is chosen by the species pair ``(species_id[i], species_id[j])``
        from ``system.force_model.table``.
        """
        return ForceRouter._dispatch("force", i, j, state, system)

    @staticmethod
    @jax.jit
    def energy(i, j, state, system):
        """Return the selected law's ``energy(i, j, state, system)``.

        The law is chosen by the species pair ``(species_id[i], species_id[j])``
        from ``system.force_model.table``.
        """
        return ForceRouter._dispatch("energy", i, j, state, system)
# Public API of this module: only the router class is exported by ``import *``.
__all__ = ["ForceRouter"]