# Source code for jaxdem.integrator

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Interface for defining time integrators. The time integrator performs one simulation step.
"""

import jax

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Tuple

from .factory import Factory
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .state import State
    from .system import System


@jax.tree_util.register_dataclass
@dataclass(slots=True)
class Integrator(Factory["Integrator"], ABC):
    """
    Abstract interface shared by every time-stepping scheme.

    Concrete integrators subclass this class and register themselves under a
    string key with ``Integrator.register``.

    Notes
    -----
    - Implementations must work for both 2D and 3D domains.
    - Implementations must be compatible with ``jax.jit``.

    Example
    -------
    A custom integrator subclasses :class:`Integrator` and provides the
    abstract methods:

    >>> @Integrator.register("myCustomIntegrator")
    ... @jax.tree_util.register_dataclass
    ... @dataclass(slots=True)
    ... class MyCustomIntegrator(Integrator):
    ...     ...
    """

    @staticmethod
    @abstractmethod
    @jax.jit
    def step(state: "State", system: "System") -> Tuple["State", "System"]:
        """
        Advance the simulation by a single time step.

        Each concrete subclass implements its own numerical update scheme here.

        Parameters
        ----------
        state : State
            Simulation state at the current time.
        system : System
            Configuration of the simulation system.

        Returns
        -------
        Tuple[State, System]
            The state and system after one integration step.

        Raises
        ------
        NotImplementedError
            Always; concrete subclasses must override this method.

        Example
        -------
        >>> state, system = system.integrator.step(state, system)
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    @jax.jit
    def initialize(state: "State", system: "System") -> Tuple["State", "System"]:
        """
        Prepare the state and system before the first call to ``step``.

        Some schemes (LeapFrog, for instance) need a dedicated start-up step;
        this method is the hook for it.

        Parameters
        ----------
        state : State
            Simulation state at the current time.
        system : System
            Configuration of the simulation system.

        Returns
        -------
        Tuple[State, System]
            The state and system after initialization.

        Raises
        ------
        NotImplementedError
            Always; concrete subclasses must override this method.

        Example
        -------
        >>> state, system = system.integrator.initialize(state, system)
        """
        raise NotImplementedError
@Integrator.register("euler")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class DirectEuler(Integrator):
    r"""
    Euler-type integrator registered under the key ``"euler"``.

    Notes
    -----
    - Each call to :meth:`step` performs, in order:

      1. Boundary handling via ``system.domain.shift``.
      2. Force/acceleration evaluation via ``system.collider.compute_force``.
      3. A velocity update from the current acceleration.
      4. A position update using the *newly updated* velocity.

    - Because the position update uses :math:`v(t + \Delta t)` rather than
      :math:`v(t)`, this is the semi-implicit (symplectic) Euler scheme, also
      known as Euler–Cromer, rather than the purely explicit forward Euler
      method.
    - Particles whose ``state.fixed`` flag is ``True`` keep their velocity and
      position unchanged by the integration step.
    """

    @staticmethod
    @jax.jit
    def step(state: "State", system: "System") -> Tuple["State", "System"]:
        r"""
        Advance the simulation state by one semi-implicit Euler time step.

        The update equations are

        .. math::

            & v(t + \Delta t) &= v(t) + \Delta t \, a(t) \\
            & r(t + \Delta t) &= r(t) + \Delta t \, v(t + \Delta t)

        where

        - :math:`r` is the particle position (:attr:`state.pos`),
        - :math:`v` is the particle velocity (:attr:`state.vel`),
        - :math:`a` is the particle acceleration (:attr:`state.accel`),
        - :math:`\Delta t` is the time step (:attr:`system.dt`).

        Parameters
        ----------
        state : State
            Current state of the simulation.
        system : System
            Simulation system configuration.

        Returns
        -------
        Tuple[State, System]
            The updated state and system after one time step.
        """
        # Apply the domain's boundary handling before evaluating forces.
        state, system = system.domain.shift(state, system)
        # Compute forces/accelerations; state.accel is read below.
        state, system = system.collider.compute_force(state, system)

        # Mask that zeroes the update for fixed particles. Assumes state.fixed
        # is a per-particle boolean/0-1 flag (TODO confirm); the trailing axis
        # lets it broadcast against vel/accel/pos. Computed once because
        # state.fixed does not change between the two updates.
        mobile = (1 - state.fixed)[..., None]
        state.vel += system.dt * state.accel * mobile
        # Deliberately uses the velocity updated on the previous line.
        state.pos += system.dt * state.vel * mobile
        return state, system

    @staticmethod
    @jax.jit
    def initialize(state: "State", system: "System") -> Tuple["State", "System"]:
        """
        No-op initialization: this integrator needs no start-up step.

        Parameters
        ----------
        state : State
            Current state of the simulation.
        system : System
            Simulation system configuration.

        Returns
        -------
        Tuple[State, System]
            The input `State` and `System`, unchanged.
        """
        return state, system