# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Utility functions to randomize particle orientations.
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
from typing import TYPE_CHECKING
from functools import partial
from .quaternion import Quaternion
if TYPE_CHECKING:
from ..state import State
[docs]
@jax.jit
@partial(jax.named_call, name="utils.randomize_orientations")
def randomize_orientations(state: State, key: jax.random.KeyArray) -> State:
"""
Randomize orientations for clumps (particles with repeated ``state.ID``),
leaving spheres unchanged.
"""
N = state.N
dim = state.dim # static at trace time (derived from shapes)
def _one(k, ID_i, w_i, xyz_i):
counts = jnp.bincount(ID_i, length=N)
is_clump_member = counts[ID_i] > 1 # (N,)
if dim == 2:
theta_by_id = jax.random.uniform(k, (N,), minval=0.0, maxval=2.0 * jnp.pi)
theta = theta_by_id[ID_i]
w_s = jnp.cos(0.5 * theta)[:, None]
xyz_s = jnp.stack(
[jnp.zeros_like(theta), jnp.zeros_like(theta), jnp.sin(0.5 * theta)],
axis=-1,
)
else: # dim == 3
q4_by_id = jax.random.normal(k, (N, 4))
q4_by_id = q4_by_id / jnp.linalg.norm(q4_by_id, axis=-1, keepdims=True) # uniform rotation
q4 = q4_by_id[ID_i] # same orientation for same clump ID
w_s = q4[:, 0:1]
xyz_s = q4[:, 1:4]
w_new = jnp.where(is_clump_member[:, None], w_s, w_i) # spheres: unchanged
xyz_new = jnp.where(is_clump_member[:, None], xyz_s, xyz_i) # spheres: unchanged
q_new = Quaternion.unit(Quaternion(w_new, xyz_new))
return q_new.w, q_new.xyz
# Match common batching conventions: vmap over axis 0 if present.
lead_ndim = state.ID.ndim - 1 # leading axes before particle axis N
if lead_ndim == 0:
w_new, xyz_new = _one(key, state.ID, state.q.w, state.q.xyz)
state.q = Quaternion(w_new, xyz_new)
elif lead_ndim == 1:
keys = jax.random.split(key, state.ID.shape[0])
w_new, xyz_new = jax.vmap(_one)(keys, state.ID, state.q.w, state.q.xyz)
state.q = Quaternion(w_new, xyz_new)
else:
# For stacked trajectories (or other multi-leading-dim states), flatten
# and then reshape back to preserve the original layout.
lead_shape = state.ID.shape[:-1]
ID = state.ID.reshape((-1, N))
w0 = state.q.w.reshape((-1, N, 1))
xyz0 = state.q.xyz.reshape((-1, N, 3))
keys = jax.random.split(key, ID.shape[0])
w_flat, xyz_flat = jax.vmap(_one)(keys, ID, w0, xyz0)
w_new = w_flat.reshape(lead_shape + (N, 1))
xyz_new = xyz_flat.reshape(lead_shape + (N, 3))
state.q = Quaternion(w_new, xyz_new)
return state