Source code for jaxdem.domains.reflect

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Reflective boundary-condition domain."""

from __future__ import annotations

import jax
import jax.numpy as jnp

from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Tuple

from . import Domain

if TYPE_CHECKING:  # pragma: no cover
    from ..state import State
    from ..system import System


@Domain.register("reflect")
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class ReflectDomain(Domain):
    """
    Box-shaped `Domain` whose walls act as mirrors.

    Whenever a particle is found past one of the walls, its position is
    mirrored about that wall and the velocity component normal to the wall
    changes sign. Displacements between particles are plain vector
    differences (no periodic images).

    Notes
    -----
    - The walls sit at `anchor` and `anchor + box_size`.
    - The particle radius is taken into account, so the limits are applied to
      the particle centre at `anchor + rad` and `anchor + box_size - rad`.
    """

    @staticmethod
    @jax.jit
    def displacement(ri: jax.Array, rj: jax.Array, _: "System") -> jax.Array:
        r"""
        Return the displacement vector from particle *j* to particle *i*.

        A reflective box has no periodic images, so the result is the plain
        difference of the two position vectors.

        Parameters
        ----------
        ri : jax.Array
            Position vector of the first particle :math:`r_i`.
        rj : jax.Array
            Position vector of the second particle :math:`r_j`.
        _ : System
            The system object. Unused; present to keep the signature uniform.

        Returns
        -------
        jax.Array
            The direct displacement vector :math:`r_i - r_j`.
        """
        return ri - rj

    @staticmethod
    @jax.jit
    def shift(state: "State", system: "System") -> Tuple["State", "System"]:
        r"""
        Mirror out-of-bounds particles back into the box and flip their velocity.

        For every coordinate the penetration depth past the lower and upper
        limits is measured from the *incoming* position; the position is then
        moved by twice that depth, and the matching velocity component is
        negated.

        .. math::
            l &= a + R \\
            u &= a + B - R \\
            \delta_l &= \max(0,\; l - r) \\
            \delta_u &= \max(0,\; r - u) \\
            r' &= r + 2\delta_l - 2\delta_u \\
            v' &= \begin{cases} -v & \text{if } \delta_l > 0 \text{ or } \delta_u > 0 \\ v & \text{otherwise} \end{cases}

        where:
            - :math:`r` is the current particle position (:attr:`jaxdem.State.pos`)
            - :math:`v` is the current particle velocity (:attr:`jaxdem.State.vel`)
            - :math:`a` is the domain anchor (:attr:`Domain.anchor`)
            - :math:`B` is the domain box size (:attr:`Domain.box_size`)
            - :math:`R` is the particle radius (:attr:`jaxdem.State.rad`)
            - :math:`l` is the lower limit for the particle center
            - :math:`u` is the upper limit for the particle center

        Notes
        -----
        - Only one mirror step is applied per call, and both penetration
          depths come from the position on entry. A particle that overshoots
          by more than the free span :math:`u - l` is therefore not guaranteed
          to end up inside the box.
        - The velocity sign flips on every detected crossing, irrespective of
          the direction the particle is currently travelling.

        TODO: Ensure correctness when adding different types of shapes and
        angular velocity.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        Tuple[State, System]
            A copy of `state` with updated `pos` and `vel`, together with the
            unchanged `system`.
        """
        box = system.domain

        # Trailing axis lets the radii broadcast against the positions
        # (presumably pos is (N, dim) and rad is (N,) -- confirm with State).
        radius = state.rad[:, None]

        # Allowed interval for the particle centre along every axis.
        lower = box.anchor + radius
        upper = box.anchor + box.box_size - radius

        # Non-negative penetration depth past each limit; zero when inside.
        below = jnp.maximum(0.0, lower - state.pos)
        above = jnp.maximum(0.0, state.pos - upper)

        # Boolean mask of the components that crossed either limit.
        crossed = (below > 0) | (above > 0)

        # Mirror about the violated limit: move back by twice the penetration.
        mirrored_pos = state.pos + 2.0 * below - 2.0 * above

        # +1 where nothing happened, -1 where a limit was crossed.
        flipped_vel = state.vel * (1.0 - 2.0 * crossed)

        return replace(state, pos=mirrored_pos, vel=flipped_vel), system
# Names exported by ``from <this module> import *``.
__all__ = ["ReflectDomain"]