# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project - https://github.com/cdelv/JaxDEM
"""Kernel typing + small helpers for JAX analysis.
In the JAX engine, kernels are *pure* functions that operate on arrays directly:
kernel(arrays, t0, t1, **kwargs) -> pytree of arrays
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
from typing import Any, Mapping, Protocol
from .bessel import j0 as j0_bessel
[docs]
class KernelFn(Protocol):
    """Structural type shared by all analysis kernels.

    A kernel is a pure function: it receives a mapping of named arrays and
    two time indices, and returns a pytree of arrays. Kernel-specific options
    (for example a wavenumber) are passed as keyword arguments.
    """

    def __call__(
        self,
        arrays: Mapping[str, jax.Array],
        t0: Any,
        t1: Any,
        **kwargs: Any,
    ) -> Any:
        """Evaluate the kernel between time indices ``t0`` and ``t1``."""
# ---- Example kernels (Option A layout: (T,N,...) or (T,S,N,...) ) ----
[docs]
def msd_kernel(arrays: Mapping[str, jax.Array], t0: Any, t1: Any) -> jax.Array:
    """Mean-squared displacement between frames ``t0`` and ``t1``, drift removed.

    ``arrays["pos"]`` is indexed by time first, so a single frame is either
    ``(N, d)`` (one system, result is a scalar ``()``) or ``(S, N, d)``
    (``S`` systems, result has shape ``(S,)``).
    """
    positions = arrays["pos"]
    disp = positions[t1] - positions[t0]
    # Remove the uniform drift: the displacement averaged over the particle
    # axis (second-to-last) is subtracted from every particle.
    disp = disp - disp.mean(axis=-2, keepdims=True)
    # Squared length per particle: (N,) or (S,N).
    sq_len = (disp * disp).sum(axis=-1)
    # Average over particles: () or (S,).
    return sq_len.mean(axis=-1)
def _spherical_j0(x: jax.Array) -> jax.Array:
    """Spherical Bessel function j0(x) = sin(x)/x, with j0(0) = 1.

    Implemented with the "double where" pattern. A single
    ``jnp.where(x == 0, 1.0, sin(x)/x)`` gives the right forward value, but
    under differentiation the unselected ``sin(x)/x`` branch is still
    differentiated. At x == 0 that derivative is 0/0 = NaN, and the NaN leaks
    into ``jax.grad`` (0 * NaN = NaN). Dividing by a sanitized argument keeps
    both the value and its gradient finite everywhere. The gradient at 0 is 0,
    which matches the analytic derivative.

    Args:
        x: Array of arguments; any shape.

    Returns:
        Array of the same shape with j0 evaluated elementwise.
    """
    is_zero = x == 0
    # Replace zeros by 1.0 *before* dividing, so the discarded branch is finite.
    safe_x = jnp.where(is_zero, 1.0, x)
    return jnp.where(is_zero, 1.0, jnp.sin(safe_x) / safe_x)
[docs]
def isf_self_isotropic_kernel(
    arrays: Mapping[str, jax.Array], t0: Any, t1: Any, *, k: Any
) -> jax.Array:
    """Isotropically averaged self intermediate scattering function.

    The angular average over wave-vector directions gives:

    - 2D: Fs(k, t) = < J0(k * |dr|) >
    - 3D: Fs(k, t) = < sin(k|dr|) / (k|dr|) >

    Displacements are used as-is (no drift subtraction).

    Output shapes:

    - frame (N,d):   ``()`` for scalar ``k``, ``(K,)`` for a 1-D ``k``
    - frame (S,N,d): ``(S,)`` for scalar ``k``, ``(S,K)`` for a 1-D ``k``

    Raises:
        ValueError: if the spatial dimension is neither 2 nor 3.
    """
    trajectory = arrays["pos"]
    displacement = trajectory[t1] - trajectory[t0]  # (N,d) or (S,N,d)
    d = int(displacement.shape[-1])
    distance = jnp.linalg.norm(displacement, axis=-1)  # (N,) or (S,N)
    wavenumbers = jnp.asarray(k)

    scalar_k = wavenumbers.ndim == 0
    if scalar_k:
        arg = distance * wavenumbers  # (N,) or (S,N)
    else:
        # Append a trailing K axis: (N,K) or (S,N,K).
        arg = distance[..., None] * wavenumbers

    if d == 2:
        profile = j0_bessel(arg)
    elif d == 3:
        profile = _spherical_j0(arg)
    else:
        raise ValueError(
            f"isf_self_isotropic_kernel only supports d=2 or d=3, got d={d}"
        )

    # The particle axis is last for scalar k, and second-to-last once the
    # K axis has been appended.
    particle_axis = -1 if scalar_k else -2
    return jnp.mean(profile, axis=particle_axis)
[docs]
def isf_self_kvecs_kernel(
    arrays: Mapping[str, jax.Array], t0: Any, t1: Any, *, kvecs: jax.Array
) -> jax.Array:
    """Self intermediate scattering function for explicit wave vectors.

    Computes Fs(k, t) = < cos(k · dr) >, averaged over particles for each row
    of ``kvecs`` (shape ``(K, d)``). Displacements are used as-is (no drift
    subtraction).

    Returns ``(K,)`` for a frame of shape ``(N, d)`` and ``(S, K)`` for a
    frame of shape ``(S, N, d)``.
    """
    trajectory = arrays["pos"]
    disp = trajectory[t1] - trajectory[t0]  # (N,d) or (S,N,d)
    # Dot every displacement with every wave vector -> (N,K) or (S,N,K).
    k_dot_r = jnp.einsum("...nd,kd->...nk", disp, kvecs)  # codespell:ignore nd
    # Average over the particle axis, which is second-to-last.
    return jnp.cos(k_dot_r).mean(axis=-2)
[docs]
def unwrap_angles_2d(q_w: jax.Array, q_xyz: jax.Array) -> jax.Array:
    """Turn a quaternion trajectory into an unwrapped cumulative angle.

    Args:
        q_w: Scalar quaternion part, shape ``(T, N, 1)``.
        q_xyz: Vector quaternion part, shape ``(T, N, 3)``.

    Returns:
        Cumulative rotation angle of shape ``(T, N)``. Frame 0 keeps its
        wrapped value, and later frames accumulate wrapped increments.

    Only the w and z components are read, so this assumes planar rotations
    about the z axis. The time axis is axis 0.
    """
    two_pi = 2 * jnp.pi
    # Rotation angle about z; q and -q (same rotation) differ here by ±2π.
    wrapped = 2.0 * jnp.arctan2(q_xyz[..., 2], q_w[..., 0])
    # Fold each frame-to-frame increment into [-π, π). This also removes the
    # ±2π jumps caused by the q / -q sign ambiguity.
    step = jnp.mod(jnp.diff(wrapped, axis=0) + jnp.pi, two_pi) - jnp.pi
    first = wrapped[0:1]
    return jnp.concatenate([first, first + jnp.cumsum(step, axis=0)], axis=0)
[docs]
def msad_kernel_2d(arrays: Mapping[str, jax.Array], t0: Any, t1: Any) -> jax.Array:
    """Mean-squared angular displacement between frames ``t0`` and ``t1``.

    Reads ``arrays["theta"]``, which is expected to hold an unwrapped
    cumulative angle indexed by time first (for example ``(T, N)``). The
    squared angle change is averaged over the last axis.
    """
    angles = arrays["theta"]
    delta = angles[t1] - angles[t0]
    return (delta * delta).mean(axis=-1)
[docs]
def isf_angular_kernel_2d(
    arrays: Mapping[str, jax.Array], t0: Any, t1: Any, *, theta_0: Any
) -> jax.Array:
    """Angular intermediate scattering function < cos(θ₀ · Δθ) >.

    ``Δθ`` is the change of ``arrays["theta"]`` between frames ``t0`` and
    ``t1``. A scalar ``theta_0`` averages over the last axis. A 1-D
    ``theta_0`` of length K pairs every Δθ with every mode and yields a
    trailing K axis.
    """
    angles = arrays["theta"]
    delta = angles[t1] - angles[t0]
    modes = jnp.asarray(theta_0)
    if modes.ndim != 0:
        # Append a mode axis, then average over the particle axis (now -2).
        return jnp.cos(delta[..., None] * modes).mean(axis=-2)
    return jnp.cos(modes * delta).mean(axis=-1)