# Source code for jaxdem.domains.periodic
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Periodic boundary-condition domain."""
from __future__ import annotations
import jax
import jax.numpy as jnp
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Tuple, ClassVar
from . import Domain
if TYPE_CHECKING: # pragma: no cover
from ..state import State
from ..system import System
[docs]
@Domain.register("periodic")
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class PeriodicDomain(Domain):
    """
    `Domain` implementation with periodic boundary conditions.

    A particle that leaves the simulation box through one side comes back in
    through the opposite side, and separations between particles are measured
    with the minimum image convention.

    Notes
    -----
    - The class-level flag `periodic` is set to ``True`` for this domain.
    """

    periodic: ClassVar[bool] = True

    @staticmethod
    @jax.jit
    def displacement(ri: jax.Array, rj: jax.Array, system: "System") -> jax.Array:
        """
        Return the minimum image displacement from :math:`r_j` to :math:`r_i`.

        The raw difference :math:`r_i - r_j` is shifted by a whole number of
        box lengths so that the result is the shortest vector joining the two
        particles, possibly across a periodic boundary.

        Parameters
        ----------
        ri : jax.Array
            Position of the first particle, :math:`r_i`.
        rj : jax.Array
            Position of the second particle, :math:`r_j`.
        system : System
            Simulation configuration; only ``system.domain.box_size`` is read.

        Returns
        -------
        jax.Array
            The minimum image displacement vector:

            .. math::
                & r_{ij} = r_i - r_j \\\\
                & r_{ij} = r_{ij} - B \\cdot \\text{floor}(r_{ij}/B + 1/2)

            where:

            - :math:`B` is the domain box size (:attr:`Domain.box_size`)

            The domain anchor does not appear because it cancels in the
            difference :math:`(r_i - a) - (r_j - a)`.
        """
        box = system.domain.box_size
        separation = ri - rj
        # floor(x + 0.5) picks the nearest periodic image along each component.
        image_shift = jnp.floor(separation / box + 0.5)
        return separation - box * image_shift

    @staticmethod
    @jax.jit
    def shift(state: "State", system: "System") -> Tuple["State", "System"]:
        """
        Wrap particle positions back into the primary simulation box.

        .. math::
            r = r - B \\cdot \\text{floor}((r - a)/B) \\\\

        where:

        - :math:`a` is the domain anchor (:attr:`Domain.anchor`)
        - :math:`B` is the domain box size (:attr:`Domain.box_size`)

        Parameters
        ----------
        state : State
            Current simulation state; its ``pos`` field is wrapped.
        system : System
            Simulation configuration providing ``domain.anchor`` and
            ``domain.box_size``.

        Returns
        -------
        Tuple[State, System]
            A copy of `state` whose ``pos`` holds the wrapped positions,
            together with `system`, which is returned as received.
        """
        domain = system.domain
        box = domain.box_size
        # Whole number of box lengths separating each coordinate from the
        # primary cell, measured from the anchor.
        cell_index = jnp.floor((state.pos - domain.anchor) / box)
        wrapped = state.pos - box * cell_index
        return replace(state, pos=wrapped), system
# Public API of this module: `PeriodicDomain` is the only exported name.
__all__ = ["PeriodicDomain"]