Source code for jaxdem.forces.wca

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project - https://github.com/cdelv/JaxDEM

from __future__ import annotations

import jax
import jax.numpy as jnp

from dataclasses import dataclass
from typing import TYPE_CHECKING, Tuple
from functools import partial

from . import ForceModel

if TYPE_CHECKING:  # pragma: no cover
    from ..state import State
    from ..system import System


def _wca_pair_terms(
    i: int, j: int, pos: jax.Array, state: State, system: System
) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    r"""Compute the pair quantities shared by ``WCA.force`` and ``WCA.energy``.

    Both methods use this one function for the geometry, the cutoff test and
    the self-pair exclusion. That way the force and the energy can never
    disagree about which pairs interact.

    Returns the tuple ``(rij, inv_r2, eps, sr6, sr12, mask)`` where

    - ``rij`` is ``system.domain.displacement(pos[i], pos[j], system)``;
    - ``inv_r2`` is :math:`1/r^2`, where a squared distance of exactly ``0``
      has been replaced by ``1`` so the division stays finite;
    - ``eps`` is ``system.mat_table.epsilon_eff[mat_id[i], mat_id[j]]``;
    - ``sr6`` / ``sr12`` are :math:`(\sigma/r)^6` / :math:`(\sigma/r)^{12}`
      with :math:`\sigma = R_i + R_j`;
    - ``mask`` is ``True`` iff :math:`r^2 < 2^{1/3}\sigma^2` (i.e.
      :math:`r < 2^{1/6}\sigma`) and ``i != j``.
    """
    mi, mj = state.mat_id[i], state.mat_id[j]
    eps = system.mat_table.epsilon_eff[mi, mj]
    sig = state.rad[i] + state.rad[j]

    rij = system.domain.displacement(pos[i], pos[j], system)
    r2 = jnp.sum(rij * rij, axis=-1)
    # A zero separation (e.g. the self pair i == j) would divide by zero below.
    # Substitute 1 so every intermediate stays finite. The self pair is
    # removed by the mask anyway.
    r2 = jnp.where(r2 == 0, jnp.ones_like(r2), r2)

    sig2 = sig * sig
    inv_r2 = 1.0 / r2
    sr2 = sig2 * inv_r2
    sr6 = sr2 * sr2 * sr2
    sr12 = sr6 * sr6

    # cutoff: r_c = 2^(1/6) sigma  =>  r_c^2 = 2^(1/3) sigma^2
    rc2 = (2.0 ** (1.0 / 3.0)) * sig2
    active = r2 < rc2
    not_self = j != i
    mask = active & not_self
    return rij, inv_r2, eps, sr6, sr12, mask


@ForceModel.register("wca")
@partial(jax.tree_util.register_dataclass, drop_fields=["required_material_properties"])
@dataclass(slots=True)
class WCA(ForceModel):
    r"""
    Weeks-Chandler-Andersen (WCA) purely repulsive Lennard-Jones interaction.

    Uses the material-pair parameter:

    - ``epsilon_eff[mi, mj]``: the energy scale :math:`\varepsilon_{ij}`.

    The length scale :math:`\sigma_{ij}` is derived from the particle radii
    (like ``spring.py``):

    .. math::
        \sigma_{ij} = R_i + R_j

    Potential, with cutoff :math:`r_c = 2^{1/6}\,\sigma_{ij}`:

    .. math::
        U(r) =
        \begin{cases}
            4\varepsilon_{ij}\left[\left(\dfrac{\sigma_{ij}}{r}\right)^{12}
            - \left(\dfrac{\sigma_{ij}}{r}\right)^{6}\right] + \varepsilon_{ij},
            & r < r_c \\
            0, & r \ge r_c
        \end{cases}

    Force returned by :meth:`force` (inside the cutoff, zero outside):

    .. math::
        \mathbf{F} = \frac{24\,\varepsilon_{ij}}{r^2}
        \left[2\left(\frac{\sigma_{ij}}{r}\right)^{12}
        - \left(\frac{\sigma_{ij}}{r}\right)^{6}\right]\mathbf{r}_{ij}

    where :math:`\mathbf{r}_{ij}` is
    ``system.domain.displacement(pos[i], pos[j], system)``. This is the force
    on particle :math:`i` provided ``displacement`` returns
    :math:`\mathbf{r}_i - \mathbf{r}_j`. That convention lives in the domain
    implementation.

    Self pairs (``i == j``) contribute neither force nor energy.
    """

    @staticmethod
    @partial(jax.jit, inline=True)
    @partial(jax.named_call, name="WCA.force")
    def force(
        i: int, j: int, pos: jax.Array, state: State, system: System
    ) -> Tuple[jax.Array, jax.Array]:
        """Return the WCA pair force and torque between particles ``i`` and ``j``.

        Parameters
        ----------
        i, j : int
            Indices of the two particles.
        pos : jax.Array
            Particle positions, indexed as ``pos[i]``.
        state : State
            Supplies ``mat_id``, ``rad`` and ``angVel``.
        system : System
            Supplies ``mat_table.epsilon_eff`` and ``domain.displacement``.

        Returns
        -------
        (force, torque)
            ``force`` is the vector from the class docstring. It is zero outside
            the cutoff and for ``i == j``. ``torque`` is always zero, shaped
            like ``state.angVel[i]``, because the interaction is central.
        """
        rij, inv_r2, eps, sr6, sr12, mask = _wca_pair_terms(i, j, pos, state, system)
        coeff = 24.0 * eps * inv_r2 * (2.0 * sr12 - sr6)
        f = (coeff[..., None] * rij) * mask[..., None]
        return f, jnp.zeros_like(state.angVel[i])

    @staticmethod
    @partial(jax.jit, inline=True)
    @partial(jax.named_call, name="WCA.energy")
    def energy(
        i: int, j: int, pos: jax.Array, state: State, system: System
    ) -> jax.Array:
        """Return the WCA pair energy between particles ``i`` and ``j``.

        Parameters are the same as for :meth:`force`. The value is
        ``4 eps [(sigma/r)^12 - (sigma/r)^6] + eps`` inside the cutoff. The
        ``+ eps`` shift makes it reach zero continuously at ``r_c``. It is
        ``0`` outside the cutoff and for ``i == j``.
        """
        _, _, eps, sr6, sr12, mask = _wca_pair_terms(i, j, pos, state, system)
        u = 4.0 * eps * (sr12 - sr6) + eps
        return u * mask

    @property
    def required_material_properties(self) -> Tuple[str, ...]:
        """Names of the material-table entries this model reads."""
        return ("epsilon_eff",)