# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Angular-velocity integrator based on the spiral scheme."""
from __future__ import annotations
import jax
import jax.numpy as jnp
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Tuple
from . import RotationIntegrator
from ..utils.quaternion import Quaternion
from ..utils.linalg import cross
if TYPE_CHECKING: # pragma: no cover
from ..state import State
from ..system import System
@partial(jax.jit, inline=True)
@partial(jax.named_call, name="spiral.omega_dot")
def omega_dot(w: jax.Array, torque: jax.Array, inertia: jax.Array) -> jax.Array:
    r"""Compute the time derivative of the angular velocity for a diagonal inertia.

    Evaluates Euler's rigid-body equations in the principal-axis frame:

    .. math::
        \dot{\boldsymbol{\omega}} = I^{-1}\left(\boldsymbol{\tau}
        - \boldsymbol{\omega} \times (I \boldsymbol{\omega})\right)

    For ``D == 1`` (planar rotation) the gyroscopic term is dropped and the
    result is simply :math:`\tau / I`.

    Parameters
    ----------
    w : jax.Array
        Angular velocity with shape ``(..., N, D)`` where ``D`` is ``1`` for planar
        simulations and ``3`` for spatial simulations.
    torque : jax.Array
        Torque expressed in the same frame as ``w`` (same shape as ``w``). This is
        the raw torque, not an acceleration: the division by the inertia happens
        inside this function.
    inertia : jax.Array
        Diagonal of the inertia tensor (principal moments) with the same trailing
        dimension as ``w``.

    Returns
    -------
    jax.Array
        :math:`\dot{\boldsymbol{\omega}}`, the angular acceleration consistent with the
        rigid-body equations of motion (same shape as ``w``).

    Raises
    ------
    ValueError
        If the trailing dimension ``D`` of ``w`` is neither 1 nor 3. Because ``D`` is
        a static shape, this is raised while the function is being traced.
    """
    D = w.shape[-1]
    if D == 1:
        # Planar rotation: the gyroscopic term does not enter.
        return torque / inertia
    if D == 3:
        # Element-wise division is valid because the inertia tensor is diagonal.
        return (torque - cross(w, inertia * w)) / inertia
    raise ValueError(f"omega_dot supports D in {{1,3}}, got D={D}")
def _pad_to_3d(x: jax.Array) -> jax.Array:
    """Embed planar rotational data of shape ``(..., 1)`` as the last (z) component of a 3-vector.

    Only the trailing axis is padded, with two leading zeros, so any number of
    leading axes is supported.
    """
    pad_width = [(0, 0)] * (x.ndim - 1) + [(2, 0)]
    return jnp.pad(x, pad_width, constant_values=0.0)


def _lab_to_body_3d(state: "State") -> Tuple[jax.Array, jax.Array]:
    """Return ``(angVel, torque)`` of ``state`` as 3-vectors in the body frame.

    In 2D the stored scalar angular velocity and torque are first embedded as
    z-components. In 3D the stored vectors are used directly. Both are then
    rotated into the body frame with ``state.q.rotate_back``.
    """
    if state.dim == 2:
        angVel_lab = _pad_to_3d(state.angVel)
        torque_lab = _pad_to_3d(state.torque)
    else:  # state.dim == 3
        angVel_lab = state.angVel
        torque_lab = state.torque
    angVel = state.q.rotate_back(state.q, angVel_lab)  # to body
    torque = state.q.rotate_back(state.q, torque_lab)  # to body
    return angVel, torque


def _ssprk3_half_kick(
    angVel: jax.Array, torque: jax.Array, state: "State", dt: "jax.Array | float"
) -> jax.Array:
    """Advance the body-frame angular velocity by ``dt / 2`` with SSPRK3.

    The torque is held constant over the sub-step. Entries flagged in
    ``state.fixed`` receive no increment.
    """
    dt_2 = dt / 2
    k1 = dt_2 * omega_dot(angVel, torque, state.inertia)
    k2 = dt_2 * omega_dot(angVel + k1, torque, state.inertia)
    k3 = dt_2 * omega_dot(angVel + 0.25 * (k1 + k2), torque, state.inertia)
    # Mask out the update for fixed entries (assumes state.fixed is a 0/1 or bool flag).
    free = (1 - state.fixed)[..., None]
    return angVel + free * (k1 + k2 + 4.0 * k3) / 6.0


def _body_to_lab(state: "State", angVel: jax.Array) -> jax.Array:
    """Convert a body-frame 3-vector angular velocity to the layout stored on ``state``.

    In 2D only the z-component (shape ``(..., 1)``) is kept. In 3D the vector is
    rotated to the lab frame with the current ``state.q``.
    """
    if state.dim == 2:
        return angVel[..., -1:]
    return state.q.rotate(state.q, angVel)  # to lab


@RotationIntegrator.register("verletspiral")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class VelocityVerletSpiral(RotationIntegrator):
    """
    Leapfrog SPIRAL integrator for angular velocities, adapted to Velocity Verlet.

    The angular velocity is kicked in two half-steps (before and after the force
    evaluation) with SSPRK3 applied to Euler's rigid-body equations in the body
    frame. The orientation is drifted once per step using the SPIRAL quaternion
    exponential map. In 2D the scalar angular velocity and torque are treated as
    the z-component of a 3-vector.

    The implementation follows the velocity update described in
    `del Valle et al. (2023) <https://doi.org/10.1016/j.cpc.2023.109077>`_.
    """

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"), inline=True)
    @partial(jax.named_call, name="VelocityVerletSpiral.step_before_force")
    def step_before_force(state: "State", system: "System") -> Tuple["State", "System"]:
        r"""
        Advance the rotational state by the first half of a Velocity Verlet step.

        The body-frame angular velocity is kicked from :math:`t` to
        :math:`t + \Delta t/2` with a third-order strong-stability-preserving
        Runge–Kutta scheme (SSPRK3). The orientation is then drifted from
        :math:`t` to :math:`t + \Delta t` with the SPIRAL leapfrog update.

        - SPIRAL orientation update:

        .. math::
            q(t + \Delta t) = q(t) \otimes \exp\left(\frac{\Delta t}{2}\,\vec{\omega}(t + \Delta t/2)\right)

        Here :math:`\otimes` is the quaternion product. The angular velocity is
        treated as a purely imaginary quaternion, i.e. its scalar part is zero and
        its vector part is :math:`\vec{\omega}`. The exponential map of a purely
        imaginary quaternion :math:`u` is

        .. math::
            e^{u} = \cos(|u|) + \frac{\vec{u}}{|u|}\sin(|u|)

        - SSPRK3 angular-velocity update:

        .. math::
            & \vec{\omega}(t + \Delta t/2) = \vec{\omega}(t) + \frac{1}{6}(k_1 + k_2 + 4k_3) \\
            & k_1 = \frac{\Delta t}{2}\; \dot{\vec{\omega}}\left(\vec{\omega}(t), \vec{\tau}(t)\right) \\
            & k_2 = \frac{\Delta t}{2}\; \dot{\vec{\omega}}\left(\vec{\omega}(t) + k_1, \vec{\tau}(t)\right) \\
            & k_3 = \frac{\Delta t}{2}\; \dot{\vec{\omega}}\left(\vec{\omega}(t) + (k_1 + k_2)/4, \vec{\tau}(t)\right)

        The angular-velocity derivative depends on the torque and the angular velocity:

        .. math::
            \dot{\vec{\omega}} = I^{-1}\left(\vec{\tau} - \vec{\omega} \times (I \vec{\omega})\right)

        Parameters
        ----------
        state : State
            Current state of the simulation.
        system : System
            Simulation system configuration.

        Returns
        -------
        Tuple[State, System]
            The updated state and system after this half-step. The state holds
            the angular velocity at :math:`t + \Delta t/2` and the orientation at
            :math:`t + \Delta t`.

        References
        ----------
        del Valle et al., SPIRAL: An efficient algorithm for the integration of the
        equation of rotational motion, https://doi.org/10.1016/j.cpc.2023.109077.

        Notes
        -----
        - This method donates state and system.
        - Entries flagged in ``state.fixed`` keep their angular velocity. Their
          orientation is still advanced with that angular velocity.
        - The quaternion step is written so that gradients stay finite for
          particles with zero angular velocity.
        """
        angVel, torque = _lab_to_body_3d(state)
        angVel = _ssprk3_half_kick(angVel, torque, state, system.dt)

        # SPIRAL drift with u = (dt/2) * omega: |u| = dt*|omega|/2 and u/|u| = omega/|omega|.
        w_norm2 = jnp.sum(angVel * angVel, axis=-1, keepdims=True)
        is_still = w_norm2 == 0
        # Double-`where` guard: substitute a dummy value *before* the sqrt. The
        # derivative of sqrt at 0 is infinite, so guarding only afterwards still
        # yields NaN gradients for non-rotating particles. Forward values are
        # unchanged: cos(0) = 1 and sin(0) * omega / 1 = 0.
        w_norm = jnp.sqrt(jnp.where(is_still, 1.0, w_norm2))
        theta = jnp.where(is_still, 0.0, system.dt * w_norm / 2)
        state.q @= Quaternion(
            jnp.cos(theta),
            jnp.sin(theta) * angVel / w_norm,
        )
        state.q = state.q.unit(state.q)

        # Rotating with the updated q gives the same lab vector as with the old q,
        # because the drift is a rotation about omega itself.
        state.angVel = _body_to_lab(state, angVel)
        return state, system

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"), inline=True)
    @partial(jax.named_call, name="VelocityVerletSpiral.step_after_force")
    def step_after_force(state: "State", system: "System") -> Tuple["State", "System"]:
        r"""
        Advance the angular velocity by the second half of a Velocity Verlet step.

        Using the torque computed at the new configuration, the body-frame angular
        velocity is kicked from :math:`t + \Delta t/2` to :math:`t + \Delta t` with
        SSPRK3. The orientation is not modified.

        .. math::
            & \vec{\omega}(t + \Delta t) = \vec{\omega}(t + \Delta t/2) + \frac{1}{6}(k_1 + k_2 + 4k_3) \\
            & k_1 = \frac{\Delta t}{2}\; \dot{\vec{\omega}}\left(\vec{\omega}(t + \Delta t/2), \vec{\tau}(t + \Delta t)\right) \\
            & k_2 = \frac{\Delta t}{2}\; \dot{\vec{\omega}}\left(\vec{\omega}(t + \Delta t/2) + k_1, \vec{\tau}(t + \Delta t)\right) \\
            & k_3 = \frac{\Delta t}{2}\; \dot{\vec{\omega}}\left(\vec{\omega}(t + \Delta t/2) + (k_1 + k_2)/4, \vec{\tau}(t + \Delta t)\right)

        The angular-velocity derivative depends on the torque and the angular velocity:

        .. math::
            \dot{\vec{\omega}} = I^{-1}\left(\vec{\tau} - \vec{\omega} \times (I \vec{\omega})\right)

        Parameters
        ----------
        state : State
            Current state of the simulation.
        system : System
            Simulation system configuration.

        Returns
        -------
        Tuple[State, System]
            The updated state and system, with the angular velocity at
            :math:`t + \Delta t`.

        References
        ----------
        del Valle et al., SPIRAL: An efficient algorithm for the integration of the
        equation of rotational motion, https://doi.org/10.1016/j.cpc.2023.109077.

        Notes
        -----
        - This method donates state and system.
        - Entries flagged in ``state.fixed`` keep their angular velocity.
        """
        angVel, torque = _lab_to_body_3d(state)
        angVel = _ssprk3_half_kick(angVel, torque, state, system.dt)
        state.angVel = _body_to_lab(state, angVel)
        return state, system
# Public API of this module: only the integrator class is exported by ``import *``.
__all__ = [
    "VelocityVerletSpiral",
]