# Source code for jaxdem.domains

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Simulation domains and boundary-condition implementations."""

from __future__ import annotations

import jax
import jax.numpy as jnp

from abc import ABC
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple
from functools import partial

from ..factory import Factory

try:  # Python 3.11+
    from typing import Self
except ImportError:  # pragma: no cover
    from typing_extensions import Self

if TYPE_CHECKING:  # pragma: no cover
    from ..state import State
    from ..system import System


@jax.tree_util.register_dataclass
@dataclass(slots=True)
class Domain(Factory, ABC):
    """
    The base interface for defining the simulation domain and the effect of
    its boundary conditions.

    The `Domain` class defines how:

    - Relative displacement vectors between particles are calculated.
    - Particles' positions are "shifted" or constrained to remain within the
      defined simulation boundaries based on the boundary condition type.

    The implementations provided here are no-op defaults: `displacement`
    returns the plain difference ``ri - rj`` and `apply` / `shift` return
    their arguments unchanged.

    Example
    -------
    To define a custom domain, inherit from `Domain` and override the methods
    whose default behavior should change:

    >>> @Domain.register("my_custom_domain")
    ... @jax.tree_util.register_dataclass
    ... @dataclass(slots=True)
    ... class MyCustomDomain(Domain):
    ...     ...
    """

    box_size: jax.Array
    """Length of the simulation domain along each dimension."""

    anchor: jax.Array
    """Anchor position (minimum coordinate) of the simulation domain."""

    @property
    def periodic(self) -> bool:
        """Whether the domain enforces periodic boundary conditions."""
        return False

    @classmethod
    def Create(
        cls,
        dim: int,
        box_size: Optional[jax.Array] = None,
        anchor: Optional[jax.Array] = None,
    ) -> Self:
        """
        Default factory method for the Domain class.

        This method constructs a new Domain instance with a box-shaped domain
        of the given dimensionality. If `box_size` or `anchor` are not
        provided, they are initialized to default values.

        Parameters
        ----------
        dim : int
            The dimensionality of the domain (e.g., 2, 3). Only used to build
            the default `box_size`; it is ignored when `box_size` is given.
        box_size : jax.Array, optional
            The size of the domain along each dimension. If not provided,
            defaults to an array of ones with shape `(dim,)`.
        anchor : jax.Array, optional
            The anchor (origin) of the domain. If not provided, defaults to an
            array of zeros with the same shape as `box_size`.

        Returns
        -------
        Domain
            A new instance of the Domain subclass with the specified or
            default configuration.

        Raises
        ------
        AssertionError
            If `box_size` and `anchor` do not have the same shape.
        """
        if box_size is None:
            box_size = jnp.ones(dim, dtype=float)
        box_size = jnp.asarray(box_size, dtype=float)

        if anchor is None:
            anchor = jnp.zeros_like(box_size, dtype=float)
        anchor = jnp.asarray(anchor, dtype=float)

        # Raise explicitly instead of using a bare `assert`: assert statements
        # are removed under `python -O`, which would let mismatched shapes
        # through. AssertionError is kept so existing callers are unaffected.
        if box_size.shape != anchor.shape:
            raise AssertionError(
                "box_size and anchor must have the same shape, got "
                f"{box_size.shape} and {anchor.shape}"
            )

        return cls(box_size=box_size, anchor=anchor)

    @staticmethod
    @partial(jax.jit, inline=True)
    def displacement(ri: jax.Array, rj: jax.Array, system: "System") -> jax.Array:
        r"""
        Computes the displacement vector between two particles :math:`r_i` and
        :math:`r_j`, considering the domain's boundary conditions.

        The base implementation applies no boundary correction and returns
        the plain difference.

        Parameters
        ----------
        ri : jax.Array
            Position vector of the first particle :math:`r_i`. Shape `(dim,)`.
        rj : jax.Array
            Position vector of the second particle :math:`r_j`. Shape `(dim,)`.
        system : System
            The configuration of the simulation, containing the `domain`
            instance. Unused by the base implementation.

        Returns
        -------
        jax.Array
            The displacement vector :math:`r_{ij} = r_i - r_j`, adjusted for
            boundary conditions. Shape `(dim,)`.

        Example
        -------
        >>> rij = system.domain.displacement(ri, rj, system)
        """
        return ri - rj

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"), inline=True)
    def apply(state: "State", system: "System") -> Tuple["State", "System"]:
        """
        Applies boundary conditions during the simulation step.

        This method updates the `state` based on the domain's rules, ensuring
        particles handle interactions at boundaries appropriately (e.g.,
        reflection). The base implementation returns both arguments unchanged.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        Tuple[State, System]
            A tuple containing the updated `State` object adjusted by the
            boundary conditions and the `System` object.

        Note
        -----
        - Periodic boundary conditions don't require wrapping of the
          coordinates during time stepping, but reflective boundaries require
          changing positions and velocities. To wrap positions for periodic
          boundaries so they are displayed correctly when saving, and for
          other algorithms, use the `shift` method.
        - This method donates `state` and `system`.

        Example
        -------
        >>> state, system = system.domain.apply(state, system)
        """
        return state, system

    @staticmethod
    @partial(jax.jit, inline=True)
    def shift(state: "State", system: "System") -> Tuple["State", "System"]:
        """
        This method updates the `state` based on the domain's rules, ensuring
        particles remain within the simulation box or handle interactions at
        boundaries appropriately (e.g., reflection, wrapping). The base
        implementation returns both arguments unchanged.

        Parameters
        ----------
        state : State
            The current state of the simulation.
        system : System
            The configuration of the simulation.

        Returns
        -------
        Tuple[State, System]
            A tuple containing the updated `State` object adjusted by the
            boundary conditions and the `System` object.

        Example
        -------
        >>> state, system = system.domain.shift(state, system)
        """
        return state, system
# Concrete domain implementations are imported only after `Domain` is defined.
# NOTE(review): presumably each sub-module imports `Domain` from this package,
# so moving these imports to the top would be circular — confirm before moving.
from .free import FreeDomain
from .periodic import PeriodicDomain
from .reflect import ReflectDomain
from .reflect_sphere import ReflectSphereDomain

# Public names exported by this package.
__all__ = [
    "Domain",
    "FreeDomain",
    "PeriodicDomain",
    "ReflectDomain",
    "ReflectSphereDomain",
]