Source code for jaxdem.utils.angles
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Utility functions to compute angles between vectors.
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
from .linalg import unit
[docs]
@jax.jit
def signed_angle(v1: jnp.ndarray, v2: jnp.ndarray) -> jnp.ndarray:
    r"""
    Signed angle of the rotation taking ``v1`` onto ``v2`` about :math:`\hat{z}`.

    The sign follows the right-hand rule around :math:`\hat{z}`. The result is
    ``arctan2(sin, cos)`` and therefore lies in :math:`[-\pi, \pi]`.

    Parameters
    ----------
    v1, v2 : jnp.ndarray
        Vectors stored along the last axis; leading axes are treated as batch
        dimensions. Only components 0 and 1 of the last axis enter the sine
        term, whereas the cosine term is the dot product over the whole last
        axis.

    Returns
    -------
    jnp.ndarray
        Signed angle in radians, with the last axis reduced away.
    """
    # ``unit`` comes from ``.linalg`` -- presumably it rescales to unit length
    # (TODO confirm, including how it treats zero-length input).
    a = unit(v1)
    b = unit(v2)
    # z-component of a x b: proportional to the sine of the signed angle.
    cross_z = a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]
    # a . b: proportional to the cosine of the angle.
    cos_term = jnp.vecdot(a, b)
    return jnp.arctan2(cross_z, cos_term)
[docs]
@jax.jit
def signed_angle_x(v1: jnp.ndarray) -> jnp.ndarray:
    r"""
    Signed angle of the rotation taking ``v1`` onto :math:`\hat{x}` about :math:`\hat{z}`.

    Computed as ``arctan2(-y, x)`` from components 0 (``x``) and 1 (``y``) of
    the last axis, so the result lies in :math:`[-\pi, \pi]`. The input is not
    normalized because ``arctan2`` only depends on the ratio of its arguments.

    Parameters
    ----------
    v1 : jnp.ndarray
        Vectors stored along the last axis (at least two components); leading
        axes are treated as batch dimensions.

    Returns
    -------
    jnp.ndarray
        Signed angle in radians, with the last axis removed.
    """
    x_comp = v1[..., 0]
    y_comp = v1[..., 1]
    # Negating y turns the angle "x-hat -> v1" into "v1 -> x-hat".
    return jnp.arctan2(-y_comp, x_comp)
[docs]
@jax.jit
def angle(v1: jax.Array, v2: jax.Array) -> jax.Array:
    r"""
    Unsigned angle between ``v1`` and ``v2``, in :math:`[0, \pi]`.

    Uses the half-angle form :math:`\theta = 2\,\mathrm{atan2}(\lVert a - b\rVert,
    \lVert a + b\rVert)` on the vectors ``a``, ``b`` returned by ``unit``. For
    unit-length ``a`` and ``b`` the two norms equal :math:`2\sin(\theta/2)` and
    :math:`2\cos(\theta/2)`.

    Parameters
    ----------
    v1, v2 : jax.Array
        Vectors stored along the last axis; leading axes are treated as batch
        dimensions.

    Returns
    -------
    jax.Array
        Angle in radians, with the last axis reduced away.
    """
    # ``unit`` comes from ``.linalg`` -- presumably it rescales to unit length
    # (TODO confirm, including how it treats zero-length input).
    a = unit(v1)
    b = unit(v2)
    diff_norm = jnp.linalg.norm(a - b, axis=-1)
    sum_norm = jnp.linalg.norm(a + b, axis=-1)
    half_angle = jnp.arctan2(diff_norm, sum_norm)
    return 2.0 * half_angle
[docs]
@jax.jit
def angle_x(v1: jax.Array) -> jax.Array:
    r"""
    Unsigned angle between ``v1`` and :math:`\hat{x}`, in :math:`[0, \pi]`.

    Builds the first basis vector of the same dimension and dtype as the
    output of ``unit(v1)`` and evaluates the half-angle form
    :math:`2\,\mathrm{atan2}(\lVert a - \hat{x}\rVert, \lVert a + \hat{x}\rVert)`.

    Parameters
    ----------
    v1 : jax.Array
        Vectors stored along the last axis; leading axes are treated as batch
        dimensions.

    Returns
    -------
    jax.Array
        Angle in radians, with the last axis reduced away.
    """
    # ``unit`` comes from ``.linalg`` -- presumably it rescales to unit length
    # (TODO confirm, including how it treats zero-length input).
    a = unit(v1)
    # x-hat = (1, 0, ..., 0); it broadcasts against any leading batch axes of ``a``.
    dim = a.shape[-1]
    x_hat = jnp.zeros(dim, dtype=a.dtype)
    x_hat = x_hat.at[0].set(1.0)
    diff_norm = jnp.linalg.norm(a - x_hat, axis=-1)
    sum_norm = jnp.linalg.norm(a + x_hat, axis=-1)
    half_angle = jnp.arctan2(diff_norm, sum_norm)
    return 2.0 * half_angle
# Public names exported on a star-import of this module.
__all__ = ["signed_angle", "signed_angle_x", "angle", "angle_x"]