# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Simulation domains and boundary-condition implementations."""
from __future__ import annotations
import jax
import jax.numpy as jnp
from abc import ABC
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple
from functools import partial
from ..factory import Factory
try: # Python 3.11+
from typing import Self
except ImportError: # pragma: no cover
from typing_extensions import Self
if TYPE_CHECKING: # pragma: no cover
from ..state import State
from ..system import System
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class Domain(Factory, ABC):
    """
    The base interface for defining the simulation domain and the effect of its boundary conditions.

    The `Domain` class defines how:

    - Relative displacement vectors between particles are calculated.
    - Particles' positions are "shifted" or constrained to remain within the
      defined simulation boundaries based on the boundary condition type.

    The base implementation applies no boundary conditions: `displacement`
    returns the plain difference ``ri - rj``, and `apply` and `shift` return
    their inputs unchanged.

    Example
    -------
    To define a custom domain, inherit from `Domain`, register it, and override
    the methods whose behavior should differ (`displacement`, `apply`, `shift`,
    and/or the `periodic` property):

    >>> @Domain.register("my_custom_domain")
    ... @jax.tree_util.register_dataclass
    ... @dataclass(slots=True)
    ... class MyCustomDomain(Domain):
    ...     ...
    """

    box_size: jax.Array
    """Length of the simulation domain along each dimension."""

    anchor: jax.Array
    """Anchor position (minimum coordinate) of the simulation domain."""

    @property
    def periodic(self) -> bool:
        """Whether the domain enforces periodic boundary conditions (``False`` for the base class)."""
        return False

    @classmethod
    def Create(
        cls,
        dim: int,
        box_size: Optional[jax.Array] = None,
        anchor: Optional[jax.Array] = None,
    ) -> Self:
        """
        Default factory method for the Domain class.

        This method constructs a new Domain instance with a box-shaped domain
        of the given dimensionality. If `box_size` or `anchor` are not provided,
        they are initialized to default values. Both are converted to float
        JAX arrays with ``jnp.asarray``, so array-likes (lists, tuples, NumPy
        arrays) are accepted.

        Parameters
        ----------
        dim : int
            The dimensionality of the domain (e.g., 2, 3). Only used to build
            the default `box_size`; it is ignored when `box_size` is given.
        box_size : jax.Array, optional
            The size of the domain along each dimension. If not provided,
            defaults to an array of ones with shape `(dim,)`.
        anchor : jax.Array, optional
            The anchor (origin) of the domain. If not provided,
            defaults to an array of zeros with the same shape as `box_size`.

        Returns
        -------
        Domain
            A new instance of the Domain subclass with the specified
            or default configuration.

        Raises
        ------
        AssertionError
            If `box_size` and `anchor` do not have the same shape.
        """
        if box_size is None:
            box_size = jnp.ones(dim, dtype=float)
        box_size = jnp.asarray(box_size, dtype=float)

        if anchor is None:
            anchor = jnp.zeros_like(box_size, dtype=float)
        anchor = jnp.asarray(anchor, dtype=float)

        # Explicit check instead of a bare ``assert``: assertions are stripped
        # when Python runs with ``-O``, which would let mismatched shapes through
        # silently. AssertionError is kept so existing callers are unaffected.
        if box_size.shape != anchor.shape:
            raise AssertionError(
                "box_size and anchor must have the same shape, got "
                f"{box_size.shape} and {anchor.shape}."
            )
        return cls(box_size=box_size, anchor=anchor)

    @staticmethod
    @partial(jax.jit, inline=True)
    def displacement(ri: jax.Array, rj: jax.Array, system: "System") -> jax.Array:
        r"""
        Computes the displacement vector between two particles :math:`r_i` and :math:`r_j`,
        considering the domain's boundary conditions.

        The base implementation applies no boundary correction and simply
        returns ``ri - rj``; subclasses may override it to account for their
        boundaries.

        Parameters
        ----------
        ri : jax.Array
            Position vector of the first particle :math:`r_i`. Shape `(dim,)`.
        rj : jax.Array
            Position vector of the second particle :math:`r_j`. Shape `(dim,)`.
        system : System
            The configuration of the simulation, containing the `domain` instance.
            Unused by the base implementation.

        Returns
        -------
        jax.Array
            The displacement vector :math:`r_{ij} = r_i - r_j`,
            adjusted for boundary conditions. Shape `(dim,)`.

        Example
        -------
        >>> rij = system.domain.displacement(ri, rj, system)
        """
        return ri - rj

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"), inline=True)
    def apply(state: "State", system: "System") -> Tuple["State", "System"]:
        """
        Applies boundary conditions during the simulation step.

        This method updates the `state` based on the domain's rules, ensuring
        particles handle interactions at boundaries appropriately (e.g., reflection).
        The base implementation returns `state` and `system` unchanged.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        Tuple[State, System]
            A tuple containing the updated `State` object adjusted by the boundary
            conditions and the `System` object.

        Notes
        -----
        - Periodic boundary conditions don't require wrapping of the coordinates
          during time stepping, but reflective boundaries require changing
          positions and velocities. To wrap positions for periodic boundaries
          (e.g., so they are displayed correctly when saving, or for other
          algorithms that need in-box coordinates), use the `shift` method.
        - `state` and `system` are donated to the jitted call
          (``donate_argnames``), so the passed-in objects should not be reused
          afterwards; use the returned ones instead.

        Example
        -------
        >>> state, system = system.domain.apply(state, system)
        """
        return state, system

    @staticmethod
    @partial(jax.jit, inline=True)
    def shift(state: "State", system: "System") -> Tuple["State", "System"]:
        """
        Wraps or constrains particle positions according to the domain's boundary rules.

        This method updates the `state` based on the domain's rules, ensuring
        particles remain within the simulation box or handle interactions
        at boundaries appropriately (e.g., reflection, wrapping). The base
        implementation returns `state` and `system` unchanged.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        Tuple[State, System]
            A tuple containing the updated `State` object adjusted by the boundary
            conditions and the `System` object.

        Example
        -------
        >>> state, system = system.domain.shift(state, system)
        """
        return state, system
# Concrete domain implementations, re-exported from this package. These imports
# come after `Domain` is defined, presumably because the submodules import
# `Domain` from this package and would otherwise hit a circular import —
# NOTE(review): confirm against the submodules.
from .free import FreeDomain
from .periodic import PeriodicDomain
from .reflect import ReflectDomain
from .reflect_sphere import ReflectSphereDomain

# Public API of the package: the base interface plus the concrete domains above.
__all__ = [
"Domain",
"FreeDomain",
"PeriodicDomain",
"ReflectDomain",
"ReflectSphereDomain",
]