# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Simulation domains and boundary-condition implementations."""
from __future__ import annotations
import jax
import jax.numpy as jnp
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple
from ..factory import Factory
try: # Python 3.11+
from typing import Self
except ImportError: # pragma: no cover - fallback for older Python
from typing_extensions import Self # type: ignore
if TYPE_CHECKING: # pragma: no cover
from ..state import State
from ..system import System
[docs]
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class Domain(Factory, ABC):
    """
    The base interface for defining the simulation domain and the effect of its boundary conditions.

    The `Domain` class defines how:

    - Relative displacement vectors between particles are calculated.
    - Particles' positions are "shifted" or constrained to remain within the
      defined simulation boundaries based on the boundary condition type.

    Concrete domains must implement the two abstract static methods
    :meth:`displacement` and :meth:`shift`.

    Example
    -------
    To define a custom domain, inherit from `Domain` and implement its abstract methods:

    >>> @Domain.register("my_custom_domain")
    ... @jax.tree_util.register_dataclass
    ... @dataclass(slots=True, frozen=True)
    ... class MyCustomDomain(Domain):
    ...     ...
    """

    box_size: jax.Array
    """Length of the simulation domain along each dimension."""

    anchor: jax.Array
    """Anchor position (minimum coordinate) of the simulation domain."""

    periodic: ClassVar[bool] = False
    """Whether the domain enforces periodic boundary conditions."""

    @classmethod
    def Create(
        cls,
        dim: int,
        box_size: Optional[jax.Array] = None,
        anchor: Optional[jax.Array] = None,
    ) -> Self:
        """
        Default factory method for the Domain class.

        This method constructs a new Domain instance with a box-shaped domain
        of the given dimensionality. If `box_size` or `anchor` are not provided,
        they are initialized to default values. Both are converted to floating
        point arrays with ``jnp.asarray(..., dtype=float)``.

        Parameters
        ----------
        dim : int
            The dimensionality of the domain (e.g., 2, 3). Only consulted when
            `box_size` is not provided; otherwise the shape of `box_size`
            determines the shape of the domain arrays.
        box_size : jax.Array, optional
            The size of the domain along each dimension. If not provided,
            defaults to an array of ones with shape `(dim,)`.
        anchor : jax.Array, optional
            The anchor (origin) of the domain. If not provided,
            defaults to an array of zeros with the same shape as `box_size`.

        Returns
        -------
        Domain
            A new instance of the Domain subclass with the specified
            or default configuration.

        Raises
        ------
        AssertionError
            If `box_size` and `anchor` do not have the same shape.
        """
        if box_size is None:
            box_size = jnp.ones(dim, dtype=float)
        box_size = jnp.asarray(box_size, dtype=float)

        if anchor is None:
            anchor = jnp.zeros_like(box_size, dtype=float)
        anchor = jnp.asarray(anchor, dtype=float)

        # Raise explicitly instead of using a bare ``assert``: assert statements
        # are removed when Python runs with -O, which would silently disable
        # this check. AssertionError is kept so existing callers that catch it
        # keep working.
        if box_size.shape != anchor.shape:
            raise AssertionError(
                "box_size and anchor must have the same shape, got "
                f"{box_size.shape} and {anchor.shape}"
            )

        return cls(box_size=box_size, anchor=anchor)

    @staticmethod
    @abstractmethod
    @jax.jit
    def displacement(ri: jax.Array, rj: jax.Array, system: "System") -> jax.Array:
        r"""
        Computes the displacement vector between two particles :math:`r_i` and :math:`r_j`,
        considering the domain's boundary conditions.

        This base implementation is abstract and only raises
        ``NotImplementedError``.

        Parameters
        ----------
        ri : jax.Array
            Position vector of the first particle :math:`r_i`. Shape `(dim,)`.
        rj : jax.Array
            Position vector of the second particle :math:`r_j`. Shape `(dim,)`.
        system : System
            The configuration of the simulation, containing the `domain` instance.

        Returns
        -------
        jax.Array
            The displacement vector :math:`r_{ij} = r_i - r_j`,
            adjusted for boundary conditions. Shape `(dim,)`.

        Raises
        ------
        NotImplementedError
            Always, in this abstract base implementation.

        Example
        -------
        >>> rij = system.domain.displacement(ri, rj, system)
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    @jax.jit
    def shift(state: "State", system: "System") -> Tuple["State", "System"]:
        """
        Applies boundary conditions to particles state.

        This method updates the `state` based on the domain's rules, ensuring
        particles remain within the simulation box or handle interactions
        at boundaries appropriately (e.g., reflection, wrapping).

        This base implementation is abstract and only raises
        ``NotImplementedError``.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        Tuple[State, System]
            A tuple containing the updated `State` object adjusted by the boundary conditions and the `System` object.

        Raises
        ------
        NotImplementedError
            Always, in this abstract base implementation.

        Example
        -------
        >>> state, system = system.domain.shift(state, system)
        """
        raise NotImplementedError
# NOTE(review): these imports sit below the `Domain` class definition rather
# than at the top of the file; presumably the submodules import `Domain` from
# this package, so importing them earlier would be circular - confirm before
# moving them.
from .free import FreeDomain
from .periodic import PeriodicDomain
from .reflect import ReflectDomain
# Names exported by a star-import of this package.
__all__ = ["Domain", "FreeDomain", "PeriodicDomain", "ReflectDomain"]