Source code for jaxdem.minimizers

"""Energy-minimizer interfaces and implementations."""

from __future__ import annotations

import jax

from abc import ABC
from dataclasses import dataclass
from typing import TYPE_CHECKING, Tuple

from ..factory import Factory
from ..integrators import Integrator

if TYPE_CHECKING:  # pragma: no cover
    from ..state import State
    from ..system import System


# ADD AN IS_CONVERGED METHOD TO THE MINIMIZER BASE CLASS OR MAYBE A BASE CLASS FOR CONVERGENCE CHECKS (i.e. PE OR PRESSURE)


[docs] @jax.tree_util.register_dataclass @dataclass(slots=True) class Minimizer(Integrator, ABC): """ Abstract base class for energy minimizers. Notes ----- - `Minimizer` subclasses the generic `Integrator` interface, so it can be plugged in anywhere an `Integrator` is expected (e.g., as `System.linear_integrator`). - The default implementations of `step_before_force`, `step_after_force`, and `initialize` are inherited from `Integrator` and act as no-ops. - Concrete minimizers should typically override `step_after_force` to update the state based on the current forces in an energy-decreasing way. """
[docs] class LinearMinimizer(Minimizer): """ Namespace for translation/linear-state minimizers. Concrete minimizers (e.g., GradientDescent) should subclass this to signal that they operate on linear kinematics. """
[docs] class RotationMinimizer(Minimizer): """ Namespace for rotational-state minimizers. Concrete minimizers that relax orientations / angular DOFs should subclass this. """
from .gradient_descent import LinearGradientDescent, RotationGradientDescent from .fire import LinearFIRE, RotationFIRE from .routines import minimize __all__ = ["Minimizer", "LinearMinimizer", "RotationMinimizer", "LinearGradientDescent", "RotationGradientDescent", "LinearFIRE", "RotationFIRE", "minimize"]