Source code for jaxdem.utils.quaternion

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Quaternion type and helpers for representing and applying particle orientations.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from dataclasses import dataclass
from functools import partial

from .linalg import unit, cross


[docs] @jax.tree_util.register_dataclass @dataclass class Quaternion: """ Quaternion representing the orientation of a particle. Stores the rotation body to lab. """ w: jax.Array # (..., N, 1) xyz: jax.Array # (..., N, 3)
[docs] @staticmethod @partial(jax.named_call, name="Quaternion.create") def create(w: ArrayLike | None = None, xyz: ArrayLike | None = None) -> Quaternion: if w is None: w = jnp.ones((1, 1), dtype=float) w = jnp.asarray(w, dtype=float) if w.ndim == 0: w = w[None] if xyz is None: xyz = jnp.zeros((1, 3), dtype=float) xyz = jnp.asarray(xyz, dtype=float) return Quaternion(w, xyz)
[docs] @staticmethod @partial(jax.jit, inline=True) @partial(jax.named_call, name="Quaternion.unit") def unit(q: Quaternion) -> Quaternion: qvec = unit(jnp.concatenate([q.w, q.xyz], axis=-1)) w, xyz = jnp.split(qvec, [1], axis=-1) return Quaternion(w, xyz)
[docs] @staticmethod @partial(jax.jit, inline=True) @partial(jax.named_call, name="Quaternion.conj") def conj(q: Quaternion) -> Quaternion: return Quaternion(q.w, -q.xyz)
[docs] @staticmethod @partial(jax.jit, inline=True) @partial(jax.named_call, name="Quaternion.inv") def inv(q: Quaternion) -> Quaternion: q = Quaternion.conj(q) return Quaternion.unit(q)
[docs] @staticmethod @partial(jax.jit, inline=True) @partial(jax.named_call, name="Quaternion.rotate") def rotate(q: Quaternion, v: jax.Array) -> jax.Array: """ Rotates a vector v from the body reference frame to the lab reference frame. """ dim = v.shape[-1] if dim == 2: angle = 2.0 * jnp.arctan2(q.xyz[..., -1], q.w[..., 0]) c, s = jnp.cos(angle), jnp.sin(angle) q_w = q.w[..., 0] q_z = q.xyz[..., -1] c = jnp.square(q_w) - jnp.square(q_z) s = 2.0 * q_w * q_z x, y = v[..., 0], v[..., 1] return jnp.stack([c * x - s * y, s * x + c * y], axis=-1) if dim == 3: T = cross(q.xyz, v) B = cross(q.xyz, T) return v + 2 * (q.w * T + B) return v
[docs] @staticmethod @partial(jax.jit, inline=True) @partial(jax.named_call, name="Quaternion.rotate_back") def rotate_back(q: Quaternion, v: jax.Array) -> jax.Array: """ Rotates a vector v from the lab reference frame to the body reference frame. """ q = Quaternion.conj(q) return Quaternion.rotate(q, v)
@partial(jax.jit, inline=True) @partial(jax.named_call, name="Quaternion.__matmul__") def __matmul__(self, other: Quaternion) -> Quaternion: # q @ r w1, w2 = self.w, other.w xyz1, xyz2 = self.xyz, other.xyz w = w1 * w2 - jnp.vecdot(xyz1, xyz2)[..., None] xyz = w1 * xyz2 + w2 * xyz1 + cross(xyz1, xyz2) return Quaternion(w, xyz)