Source code for jaxdem.minimizers.gradient_descent

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Basic gradient-descent energy minimizer."""

from __future__ import annotations

import jax
import jax.numpy as jnp

from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Tuple

from . import LinearMinimizer, RotationMinimizer
from ..integrators import LinearIntegrator, RotationIntegrator
from ..integrators.velocity_verlet_spiral import omega_dot
from ..utils.quaternion import Quaternion

if TYPE_CHECKING:  # pragma: no cover
    from ..state import State
    from ..system import System


@LinearMinimizer.register("lineargradientdescent")
@LinearIntegrator.register("lineargradientdescent")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class LinearGradientDescent(LinearMinimizer):
    """Translational gradient-descent minimizer.

    Each call to :meth:`step_after_force` displaces every non-fixed particle
    along the force currently acting on it, scaled by ``learning_rate``.
    """

    # Step size (gamma) multiplying the force; stored as a JAX array so the
    # minimizer is a PyTree leaf that can live inside the System.
    learning_rate: jax.Array

    @classmethod
    def Create(cls, learning_rate: float = 1e-3) -> "LinearGradientDescent":
        """Build a LinearGradientDescent minimizer backed by JAX array parameters.

        Parameters
        ----------
        learning_rate : float, optional
            Step size used by the gradient-descent update. Defaults to 1e-3.

        Returns
        -------
        LinearGradientDescent
            Minimizer whose ``learning_rate`` is stored as a JAX array.
        """
        step_size = jnp.array(learning_rate)
        return cls(learning_rate=step_size)

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"))
    @partial(jax.named_call, name="LinearGradientDescent.step_after_force")
    def step_after_force(state: "State", system: "System") -> Tuple["State", "System"]:
        r"""Apply one translational gradient-descent step.

        The step size is read from the LinearGradientDescent instance held in
        ``system.linear_integrator``, so all parameters live inside the System
        PyTree rather than in external mutable state. Positions evolve as:

        .. math::
            r_{t+1} = r_{t} + \gamma F_{t}

        Particles flagged in ``state.fixed`` receive a zero displacement.
        """
        step_size = system.linear_integrator.learning_rate
        # 1 for movable particles, 0 for fixed ones; trailing axis broadcasts
        # over the spatial components of the force.
        free = (1 - state.fixed)[..., None]
        displacement = step_size * state.force * free
        state.pos_c = state.pos_c + displacement
        return state, system
@RotationMinimizer.register("rotationgradientdescent")
@RotationIntegrator.register("rotationgradientdescent")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class RotationGradientDescent(RotationMinimizer):
    """Rotational gradient-descent minimizer.

    Each call to :meth:`step_after_force` rotates every non-fixed particle
    about the axis given by its body-frame angular acceleration due to the
    torque, with an angle proportional to ``learning_rate``.
    """

    # Step size (gamma) multiplying the angular acceleration; stored as a JAX
    # array so the minimizer is a PyTree leaf that can live inside the System.
    learning_rate: jax.Array

    @classmethod
    def Create(cls, learning_rate: float = 1e-3) -> "RotationGradientDescent":
        """Create a RotationGradientDescent minimizer with JAX array parameters.

        Parameters
        ----------
        learning_rate : float, optional
            Learning rate for gradient descent updates. Default is 1e-3.

        Returns
        -------
        RotationGradientDescent
            A new minimizer instance with JAX array parameters.
        """
        return cls(learning_rate=jnp.array(learning_rate))

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"))
    @partial(jax.named_call, name="RotationGradientDescent.step_after_force")
    def step_after_force(state: "State", system: "System") -> Tuple["State", "System"]:
        r"""Gradient-descent update of the orientations using the integrator's learning rate.

        The learning rate is stored on the RotationGradientDescent dataclass
        attached to ``system.rotation_integrator``, so no mutable state is
        kept outside the System PyTree. The update equation is:

        .. math::
            q_{t+1} = q_{t} \otimes e^{u_t}, \qquad
            u_t = \tfrac{1}{2}\,\gamma\, I^{-1} \tau_t

        where :math:`\tau_t` is the torque expressed in the body frame, and
        :math:`I^{-1}\tau_t` is the angular acceleration evaluated by
        ``omega_dot`` at zero angular velocity. :math:`u_t` is a purely
        imaginary quaternion (zero scalar part, vector part equal to the
        vector above), whose exponential map is

        .. math::
            e^{u} = \cos(|u|) + \frac{\vec{u}}{|u|}\sin(|u|)

        and which reduces to the identity quaternion when :math:`|u| = 0`.
        The resulting orientation is re-normalized to counter floating-point
        drift. Particles flagged in ``state.fixed`` are not rotated.
        """
        gd = system.rotation_integrator
        lr = gd.learning_rate

        # Lift torques to 3D. In 2D the torque carries only a z component
        # (assumed shape (N, 1) -- TODO confirm), so prepend zero x/y parts.
        if state.dim == 2:
            torque_lab_3d = jnp.pad(state.torque, ((0, 0), (2, 0)), constant_values=0.0)
        else:  # state.dim == 3
            torque_lab_3d = state.torque

        # Rotate torques from the lab frame into the body frame.
        torque = state.q.rotate_back(state.q, torque_lab_3d)

        # Angular acceleration due to the torques only: the angular velocity
        # argument is an exact zero, so there is no velocity-dependent term.
        alpha = omega_dot(jnp.zeros_like(torque), torque, state.inertia)

        # Half-angle rotation vector; fixed particles get k = 0 (no rotation).
        k = 0.5 * lr * alpha * (1 - state.fixed)[..., None]
        k_norm = jnp.sqrt(jnp.sum(k * k, axis=-1, keepdims=True))
        # Guard only the division: cos/sin use the true norm so k = 0 gives
        # the exact identity quaternion (1, 0, 0, 0) instead of a scaled one.
        safe_norm = jnp.where(k_norm == 0, 1.0, k_norm)

        # Orientation update (right multiplication: body-frame rotation).
        state.q @= Quaternion(
            jnp.cos(k_norm),
            jnp.sin(k_norm) * k / safe_norm,
        )

        # Re-normalize the quaternion to counter floating-point drift.
        state.q = state.q.unit(state.q)
        return state, system