# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Angular-velocity integrator based on the spiral scheme."""
from __future__ import annotations
import jax
import jax.numpy as jnp
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Tuple
from . import RotationIntegrator
from ..utils.quaternion import Quaternion
from ..utils.linalg import cross
if TYPE_CHECKING: # pragma: no cover
from ..state import State
from ..system import System
@partial(jax.jit, inline=True)
@partial(jax.named_call, name="spiral.omega_dot")
def omega_dot(w: jax.Array, torque: jax.Array, inertia: jax.Array) -> jax.Array:
    r"""Compute the time derivative of the angular velocity for diagonal inertia.

    Evaluates Euler's rigid-body equations for a diagonal inertia tensor,

    .. math::
        \dot{\boldsymbol{\omega}} = I^{-1}\left(\boldsymbol{\tau}
        - \boldsymbol{\omega} \times (I \boldsymbol{\omega})\right),

    which reduces to :math:`\dot{\omega} = \tau / I` in the planar case
    (``D == 1``), where the gyroscopic term vanishes.

    Parameters
    ----------
    w : jax.Array
        Angular velocity with shape ``(..., N, D)`` where ``D`` is ``1`` for planar
        simulations and ``3`` for spatial simulations.
    torque : jax.Array
        Torque acting on each body, expressed in the same (body) frame as ``w``
        and with the same shape as ``w``.
    inertia : jax.Array
        Diagonal of the inertia tensor (principal moments). It is multiplied
        element-wise with ``w`` and divides ``torque``, so it must broadcast
        against ``w``.

    Returns
    -------
    jax.Array
        :math:`\dot{\boldsymbol{\omega}}`, the angular acceleration consistent with the
        rigid-body equations of motion (same shape as ``w``).

    Raises
    ------
    ValueError
        If the trailing dimension ``D`` of ``w`` is neither 1 nor 3. Because
        ``D`` is a static shape, this is raised while tracing, not at runtime.
    """
    # Trailing dimension is static under jit, so this Python branch is
    # resolved once per input shape during tracing.
    D = w.shape[-1]
    if D == 1:
        # Planar rotation: omega x (I omega) is identically zero.
        return torque / inertia
    if D == 3:
        # Euler's equations: I dw/dt = tau - w x (I w).
        return (torque - cross(w, inertia * w)) / inertia
    raise ValueError(f"omega_dot supports D in {{1,3}}, got D={D}")
@RotationIntegrator.register("spiral")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class Spiral(RotationIntegrator):
    """
    Non-leapfrog spiral integrator for orientations and angular velocities.

    The quaternion update follows the SPIRAL algorithm described in
    `del Valle et al. (2023) <https://doi.org/10.1016/j.cpc.2023.109077>`_;
    the body-frame angular velocity is advanced with a third-order
    strong-stability-preserving Runge–Kutta scheme (SSPRK3).
    """

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"), inline=True)
    @partial(jax.named_call, name="spiral.step_after_force")
    def step_after_force(state: "State", system: "System") -> Tuple["State", "System"]:
        r"""
        Advance orientations and angular velocities by a single time step.

        The lab-frame angular velocity and torque are rotated into the body
        (principal-axis) frame. The orientation quaternion is then advanced
        with the SPIRAL non-leapfrog update, and the body-frame angular
        velocity with SSPRK3. Finally, in 3D the angular velocity is rotated
        back to the lab frame using the *updated* quaternion.

        - SPIRAL quaternion update:

        .. math::
            q(t + \Delta t) = q(t) \cdot
            \exp\left(\frac{\Delta t}{2}\vec{\omega}\right) \cdot
            \exp\left(\frac{\Delta t^2}{4}\dot{\vec{\omega}}\right)

        Here :math:`\vec{\omega}` and :math:`\dot{\vec{\omega}}` are treated as
        purely imaginary quaternions (zero scalar part, vector part equal to the
        vector). The exponential map of a purely imaginary quaternion
        :math:`u` is

        .. math::
            e^{u} = \cos(|u|) + \frac{\vec{u}}{|u|}\sin(|u|)

        - SSPRK3 angular-velocity update, where :math:`\vec{\omega}` and
          :math:`\vec{\tau}` are the body-frame angular velocity and torque at
          the start of the call (the same torque is used for every stage):

        .. math::
            k_1 &= \Delta t\, \dot{\vec{\omega}}(\vec{\omega}, \vec{\tau}) \\
            k_2 &= \Delta t\, \dot{\vec{\omega}}(\vec{\omega} + k_1, \vec{\tau}) \\
            k_3 &= \Delta t\, \dot{\vec{\omega}}\left(\vec{\omega} + \tfrac{1}{4}(k_1 + k_2), \vec{\tau}\right) \\
            \vec{\omega}(t + \Delta t) &= \vec{\omega} + \tfrac{1}{6}(k_1 + k_2 + 4 k_3)

        The angular-velocity derivative follows the diagonal-inertia Euler
        equations:

        .. math::
            \dot{\vec{\omega}} = I^{-1}\left(\vec{\tau} - \vec{\omega} \times (I \vec{\omega})\right)

        Particles flagged in ``state.fixed`` keep their angular velocity and
        are treated as non-accelerating (:math:`\dot{\vec{\omega}} = 0`), so
        their orientation advances only with their current angular velocity.

        Parameters
        ----------
        state : State
            Current state of the simulation.
        system : System
            Simulation system configuration (provides ``dt``).

        Returns
        -------
        Tuple[State, System]
            The updated state and system after one time step.

        References
        ----------
        del Valle et al., SPIRAL: An efficient algorithm for the integration of
        the equation of rotational motion, https://doi.org/10.1016/j.cpc.2023.109077.

        Notes
        -----
        - This method donates ``state`` and ``system``.
        - TODO: make it work without padding the vectors to 3D in planar runs.
        """
        if state.dim == 2:
            # Planar runs store only the out-of-plane component; embed it as
            # (0, 0, w_z) so the 3D formulas below apply. Only the trailing
            # axis is padded, so any leading batch axes are preserved.
            pad = [(0, 0)] * (state.angVel.ndim - 1) + [(2, 0)]
            angVel_lab_3d = jnp.pad(state.angVel, pad, constant_values=0.0)
            torque_lab_3d = jnp.pad(state.torque, pad, constant_values=0.0)
        else:  # state.dim == 3
            angVel_lab_3d = state.angVel
            torque_lab_3d = state.torque

        # Work in the body frame, where the inertia tensor is diagonal.
        angVel = state.q.rotate_back(state.q, angVel_lab_3d)
        torque = state.q.rotate_back(state.q, torque_lab_3d)

        # 1 for free particles, 0 for fixed ones; broadcast over the vector axis.
        free = (1 - state.fixed)[..., None]

        # Fixed particles do not accelerate. Zeroing their w_dot keeps the
        # quaternion update consistent with their frozen angular velocity
        # (otherwise torque would still rotate them via the theta2 factor).
        w_dot = free * omega_dot(angVel, torque, state.inertia)

        w_norm = jnp.sqrt(jnp.sum(angVel * angVel, axis=-1, keepdims=True))
        w_dot_norm = jnp.sqrt(jnp.sum(w_dot * w_dot, axis=-1, keepdims=True))
        theta1 = system.dt * w_norm / 2
        theta2 = jnp.power(system.dt, 2) * w_dot_norm / 4

        # Guard the axis normalisation against 0/0. When a norm is zero its
        # angle is zero too, so sin(theta) = 0 and the vector part vanishes.
        w_norm = jnp.where(w_norm == 0, 1.0, w_norm)
        w_dot_norm = jnp.where(w_dot_norm == 0, 1.0, w_dot_norm)

        state.q @= Quaternion(
            jnp.cos(theta1),
            jnp.sin(theta1) * angVel / w_norm,
        ) @ Quaternion(
            jnp.cos(theta2),
            jnp.sin(theta2) * w_dot / w_dot_norm,
        )
        # Presumably renormalises q (Quaternion.unit) to counter round-off drift.
        state.q = state.q.unit(state.q)

        # SSPRK3 stages, all with the start-of-step body-frame torque.
        k1 = system.dt * w_dot
        k2 = system.dt * omega_dot(angVel + k1, torque, state.inertia)
        k3 = system.dt * omega_dot(angVel + 0.25 * (k1 + k2), torque, state.inertia)
        angVel += free * (k1 + k2 + 4.0 * k3) / 6.0

        if state.dim == 2:
            # Only the out-of-plane component is stored in planar runs.
            state.angVel = angVel[..., -1:]
        else:
            # The body-frame omega is now at t + dt, so use the updated q.
            state.angVel = state.q.rotate(state.q, angVel)  # to lab
        return state, system
# Public API of this module: only the ``Spiral`` integrator is exported
# by ``from ... import *``.
__all__ = ["Spiral"]