# Source code for jaxdem.domains.free
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Unbounded (free) simulation domain."""
from __future__ import annotations
import jax
import jax.numpy as jnp
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Tuple
from . import Domain
if TYPE_CHECKING: # pragma: no cover
from ..state import State
from ..system import System
# [docs]  (Sphinx viewcode link marker left over from the HTML page; not part of the module)
@Domain.register("free")
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class FreeDomain(Domain):
    """
    Open simulation domain that imposes no boundary conditions.

    Particles in a `FreeDomain` are never wrapped, reflected or otherwise
    constrained. The "simulation box" exists only as the bounding box of
    the system.

    Notes
    -----
    - `shift` recomputes `anchor` and `box_size` from the current particle
      positions and radii, so the stored box always encloses every particle.
      Some hashing tools require the domain size.
    """

    @staticmethod
    @jax.jit
    def displacement(ri: jax.Array, rj: jax.Array, _: "System") -> jax.Array:
        r"""
        Return the displacement vector from particle ``j`` to particle ``i``.

        With no boundaries there is no wrapping or image convention, so the
        displacement is the plain difference of the two positions.

        Parameters
        ----------
        ri : jax.Array
            Position of the first particle, :math:`r_i`.
        rj : jax.Array
            Position of the second particle, :math:`r_j`.
        _ : System
            The system object. Unused by this domain.

        Returns
        -------
        jax.Array
            The difference :math:`r_i - r_j`.
        """
        separation = ri - rj
        return separation

    @staticmethod
    @jax.jit
    def shift(state: "State", system: "System") -> Tuple["State", "System"]:
        """
        Refit the domain's bounding box around all particles.

        The lower corner is the minimum of ``pos - rad`` and the upper corner
        is the maximum of ``pos + rad``, both reduced over the second-to-last
        axis of ``state.pos``. The state is returned untouched.

        Parameters
        ----------
        state : State
            The current state of the simulation. Only ``pos`` and ``rad``
            are read.
        system : System
            The current system configuration.

        Returns
        -------
        Tuple[State, System]
            The same `State` object that was passed in, and a copy of
            `System` whose `domain` has the new `anchor` (lower corner) and
            `box_size` (upper corner minus lower corner).
        """
        # Add a trailing axis so each radius broadcasts across that
        # particle's coordinates.
        radii = state.rad[..., None]
        lower_corner = jnp.min(state.pos - radii, axis=-2)
        upper_corner = jnp.max(state.pos + radii, axis=-2)

        # Both dataclasses are frozen, so build updated copies rather than mutate.
        fitted_domain = replace(
            system.domain,
            anchor=lower_corner,
            box_size=upper_corner - lower_corner,
        )
        return state, replace(system, domain=fitted_domain)
# Names exported by `from <this module> import *`; only the domain class is public.
__all__ = ["FreeDomain"]