# Source code for jaxdem.forces.router
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Force model router selecting laws based on species pairs."""
from __future__ import annotations
import jax
from dataclasses import dataclass, field
from typing import Tuple
from . import ForceModel
from .law_combiner import LawCombiner
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class ForceRouter(ForceModel):
    """Static species-to-force lookup table.

    ``table[a][b]`` is the force law used between a particle of species ``a``
    and a particle of species ``b``. The nesting of ``table`` is fixed when the
    router is built, so its shape is known at trace time; only the *choice* of
    law is made from (traced) species ids, via :func:`jax.lax.switch`.

    Attributes
    ----------
    table
        Tuple of rows of force laws, indexed ``table[species_i][species_j]``.
        ``force``/``energy`` require it to be square and non-empty.
    required_material_properties
        Sorted union of ``required_material_properties`` over every law in
        ``table``. Recomputed in ``__post_init__``; any value passed to the
        constructor is overwritten.
    """

    table: Tuple[Tuple[ForceModel, ...], ...] = field(default=())
    # Marked static so register_dataclass keeps it as pytree metadata
    # (not a traced leaf).
    required_material_properties: Tuple[str, ...] = field(
        default=(), metadata={"static": True}
    )

    def __post_init__(self) -> None:
        """Collect the material properties needed by every law in ``table``."""
        req = {
            p
            for row in self.table
            for law in row
            for p in law.required_material_properties
        }
        # The dataclass is frozen, so assign through object.__setattr__.
        object.__setattr__(
            self, "required_material_properties", tuple(sorted(req))
        )

    @staticmethod
    def from_dict(
        S: int, mapping: dict[Tuple[int, int], ForceModel]
    ) -> ForceRouter:
        """Build a symmetric ``S x S`` router from a sparse ``{(a, b): law}`` map.

        Parameters
        ----------
        S
            Number of species; the resulting table is ``S x S``.
        mapping
            Maps species pairs to force laws. Each entry fills both
            ``table[a][b]`` and ``table[b][a]``; if both ``(a, b)`` and
            ``(b, a)`` are present, the one iterated last wins.

        Pairs not listed in ``mapping`` get an empty ``LawCombiner()``, which
        serves as the zero-force default.

        Raises
        ------
        IndexError
            If a species index in ``mapping`` is outside ``[0, S)``. Negative
            indices are rejected explicitly because Python list indexing would
            otherwise silently wrap them onto the wrong species.
        """
        empty = LawCombiner()  # zero-force default
        m = [[empty] * S for _ in range(S)]
        for (a, b), law in mapping.items():
            if not (0 <= a < S and 0 <= b < S):
                raise IndexError(
                    f"species pair ({a}, {b}) out of range for S={S} species"
                )
            m[a][b] = m[b][a] = law
        return ForceRouter(table=tuple(map(tuple, m)))

    @staticmethod
    def _flatten_table(table) -> Tuple[Tuple[ForceModel, ...], int]:
        """Flatten a square law table row-major; return ``(laws, S)``.

        Runs as plain Python at trace time (the table's nesting is static
        pytree structure), so these checks fail before anything is staged.
        """
        S = len(table)
        if S == 0:
            raise ValueError(
                "ForceRouter.table is empty; there is no force law to dispatch to."
            )
        if any(len(row) != S for row in table):
            raise ValueError(
                "ForceRouter.table must be square: one row and one column per species."
            )
        return tuple(law for row in table for law in row), S

    @staticmethod
    @jax.jit
    def force(i, j, state, system):
        """Evaluate the force law assigned to the species pair of ``i`` and ``j``.

        Species ids are traced values under ``jax.jit``, so the law cannot be
        picked with Python indexing (``int(tracer)`` fails). Instead the
        row-major index ``species_i * S + species_j`` selects a branch of
        :func:`jax.lax.switch`. Consequences:

        * every law in the table is traced, and all of their ``force`` outputs
          must have identical pytree structure, shapes and dtypes;
        * out-of-range species ids are clamped by ``lax.switch`` rather than
          raising.

        Returns whatever the selected law's ``force`` returns.
        """
        laws, S = ForceRouter._flatten_table(system.force_model.table)
        idx = state.species_id[i] * S + state.species_id[j]
        # NOTE(review): the selected law receives the same ``system``, so its
        # ``system.force_model`` is this router, not the law itself. A law that
        # reads its own configuration through ``system.force_model`` (as this
        # router does) will not find it — confirm for LawCombiner / nested
        # routers.
        return jax.lax.switch(idx, [law.force for law in laws], i, j, state, system)

    @staticmethod
    @jax.jit
    def energy(i, j, state, system):
        """Evaluate the energy law assigned to the species pair of ``i`` and ``j``.

        Dispatch works exactly like :meth:`force`: a traced row-major index
        ``species_i * S + species_j`` selects a branch of
        :func:`jax.lax.switch`. All laws' ``energy`` outputs must share
        structure, shape and dtype, and out-of-range species ids are clamped.

        Returns whatever the selected law's ``energy`` returns.
        """
        laws, S = ForceRouter._flatten_table(system.force_model.table)
        idx = state.species_id[i] * S + state.species_id[j]
        # NOTE(review): same caveat as ``force`` — sub-laws see this router as
        # ``system.force_model``.
        return jax.lax.switch(idx, [law.energy for law in laws], i, j, state, system)
# Public API of this module: only the species-pair router is exported.
__all__ = [
    "ForceRouter",
]