Source code for jaxdem.utils.linalg

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Utility functions to help with linear algebra.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp


@jax.jit
def unit(v: jax.Array) -> jax.Array:
    """
    Normalize vectors along the last axis.

    v: (..., D)
    returns: (..., D), unit vectors; zeros map to zeros.

    Zero vectors are returned unchanged (i.e. as zeros), and the function is
    safe to differentiate at zero: the gradient there is finite rather than
    NaN.
    """
    # ``vecdot`` reduces the last axis, giving shape (...,). Re-insert a
    # trailing singleton axis so the scale factor broadcasts per-vector
    # against ``v`` of shape (..., D) instead of across the D components.
    norm2 = jnp.vecdot(v, v)[..., None]
    # Double-where pattern: give rsqrt a harmless value where the norm is
    # zero. A plain ``where`` around ``rsqrt(norm2)`` still evaluates
    # rsqrt(0) = inf, whose backward pass produces 0 * inf = NaN gradients.
    # At those entries v is zero, so scaling by rsqrt(1) = 1 keeps the
    # result at zero.
    safe_norm2 = jnp.where(norm2 == 0, 1.0, norm2)
    return v * jax.lax.rsqrt(safe_norm2)