Source code for jaxdem.domains.reflect_sphere
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Reflective boundary-condition domain."""
from __future__ import annotations
import jax
from dataclasses import dataclass
from typing import TYPE_CHECKING, Tuple
from functools import partial
from . import Domain
if TYPE_CHECKING: # pragma: no cover
from ..state import State
from ..system import System
[docs]
@Domain.register("reflectsphere")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class ReflectSphereDomain(Domain):
    """
    A `Domain` implementation that enforces reflective boundary conditions only for spheres.

    This is a dedicated, sphere-only variant kept separate for performance reasons.
    Particles that attempt to move beyond the defined `box_size` have their
    positions reflected back into the box and their velocities reversed in the
    direction normal to the boundary.

    Notes
    -----
    - The reflection occurs at the boundaries defined by `anchor` and
      `anchor + box_size`, shrunk by each particle's radius so that the sphere
      surface (not just its centre) stays inside the box.
    """

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"), inline=True)
    @partial(jax.named_call, name="ReflectSphereDomain.apply")
    def apply(state: "State", system: "System") -> Tuple["State", "System"]:
        r"""
        Applies reflective boundary conditions to particles.

        Each position component of every particle centre is checked against the
        radius-shrunk domain bounds. A component lying past a bound is mirrored
        back across that bound, and the matching velocity component is reversed.

        .. math::
            l &= a + R \\
            u &= a + B - R \\
            \delta_l &= \max(0,\, l - r) \\
            \delta_u &= \max(0,\, r - u) \\
            r' &= r + 2\,\delta_l - 2\,\delta_u \\
            v' &= \begin{cases} -v & \text{if } \delta_l > 0 \text{ or } \delta_u > 0 \\ v & \text{otherwise} \end{cases}

        where:

        - :math:`r` is the current particle position (:attr:`jaxdem.State.pos`)
        - :math:`v` is the current particle velocity (:attr:`jaxdem.State.vel`)
        - :math:`a` is the domain anchor (:attr:`Domain.anchor`)
        - :math:`B` is the domain box size (:attr:`Domain.box_size`)
        - :math:`R` is the particle radius (:attr:`jaxdem.State.rad`)
        - :math:`l` is the lower boundary for the particle center
        - :math:`u` is the upper boundary for the particle center
        - :math:`\delta_l`, :math:`\delta_u` are the penetration depths past
          the lower and upper boundary

        For a single violated bound this reduces to :math:`r' = 2l - r` (lower)
        or :math:`r' = 2u - r` (upper). Both depths are measured from the
        position *before* reflection. A particle that overshoots a bound by more
        than :math:`u - l` in a single step can therefore still end up outside
        the box.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        Tuple[State, System]
            The updated `State` object with reflected positions and velocities,
            and the `System` object.

        Notes
        -----
        - This method donates `state` and `system`, so the input buffers should
          not be reused after the call.
        - Only works for states with *ONLY* spheres. The overshoot is measured on
          `state.pos`, but the correction is applied to `state.pos_c`. This
          presumes the two coincide for spheres (TODO confirm).
        - Not yet validated for other shapes or for angular velocity. Angular
          velocity is currently left untouched.
        """
        pos = state.pos
        # Radius as a column so it broadcasts against the per-component bounds.
        rad = state.rad[:, None]
        lo = system.domain.anchor + rad
        hi = system.domain.anchor + system.domain.box_size - rad

        # Penetration depths, clamped to zero where the bound is respected.
        # Multiplying by the comparison mask equals max(0, x) for finite x and
        # avoids an extra jnp dependency in this hot path.
        over_lo = lo - pos
        over_lo *= over_lo > 0
        over_hi = pos - hi
        over_hi *= over_hi > 0

        # Boolean mask of violated components, and the matching -1/+1 velocity factor.
        hit = (over_lo > 0) | (over_hi > 0)
        sign = 1.0 - 2.0 * hit

        # Mirror across the violated bound(s):
        # r + 2 (l - r) = 2l - r  and  r - 2 (r - u) = 2u - r.
        state.pos_c += 2.0 * (over_lo - over_hi)
        state.vel *= sign
        # TODO: decide whether angular velocity should also be reversed on a
        # wall hit (previously sketched as `state.angVel *= sign`).
        return state, system
# Public API of this module: only the sphere-specific reflective domain is exported.
__all__ = ["ReflectSphereDomain"]