Source code for jaxdem.utils.clumps

import jax
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation

from typing import Tuple, TYPE_CHECKING
from functools import partial

from . import Quaternion

if TYPE_CHECKING:  # pragma: no cover
    from ..state import State
    from ..materials import MaterialTable


@partial(jax.jit, static_argnames=("n", "dim"), inline=True)
def _generate_golden_lattice(n: int, dim: int = 2) -> jax.Array:
    if dim == 2:
        phi = 1.32471795724475
    else:
        phi = 1.22074408460576
    exponents = 1.0 + jax.lax.iota(int, size=dim)
    alphas = 1.0 / jnp.power(phi, exponents)
    ids = 1.0 + jax.lax.iota(int, size=n)
    return (0.5 + jnp.outer(ids, alphas)) % 1.0


[docs] @partial(jax.jit, static_argnames=("n_samples",)) def compute_clump_properties( state: "State", mat_table: "MaterialTable", n_samples: int = 50_000 ) -> "State": dim = state.dim clump_ids = jnp.arange(state.N) counts = jnp.bincount(state.ID, length=state.N) points_u = _generate_golden_lattice(n_samples, dim=state.dim) pos = state.pos def solve_monte_carlo(c_id: jax.Array) -> Tuple[jax.Array, ...]: is_in_clump = state.ID == c_id # --- Bounding Box & Points --- inf = jnp.inf local_min = pos - state.rad[:, None] local_max = pos + state.rad[:, None] clump_min_b = jnp.min(jnp.where(is_in_clump[:, None], local_min, inf), axis=0) clump_max_b = jnp.max(jnp.where(is_in_clump[:, None], local_max, -inf), axis=0) box_vol = jnp.prod(clump_max_b - clump_min_b) points = clump_min_b + points_u * (clump_max_b - clump_min_b) # --- Filter Logic --- eff_rad = jnp.where(is_in_clump, state.rad, 0.0) eff_densities = jnp.where(is_in_clump, mat_table.density[state.mat_id], 0.0) diff = points[:, None, :] - pos[None, :, :] dists_sq = jnp.vecdot(diff, diff) inside_mask = dists_sq < jnp.square(eff_rad[None, :]) densities_per_point = jnp.where(inside_mask, eff_densities[None, :], 0.0) rho = jnp.max(densities_per_point, axis=-1) # --- Mass & COM --- vol_per_sample = box_vol / n_samples total_mass = jnp.sum(rho) * vol_per_sample rho_r = points * rho[:, None] com = jnp.sum(rho_r, axis=0) * vol_per_sample / total_mass # --- Inertia & Orientation --- r_prime = points - com r_sq = jnp.sum(r_prime**2, axis=1) if dim == 3: term1 = jnp.sum( rho[:, None, None] * r_sq[:, None, None] * jnp.eye(3)[None, :, :], axis=0, ) term2 = jnp.einsum("n,ni,nj->ij", rho, r_prime, r_prime) I_tensor = (term1 - term2) * vol_per_sample I_tensor = 0.5 * (I_tensor + I_tensor.T) eigvals, eigvecs = jnp.linalg.eigh(I_tensor) rot = Rotation.from_matrix(eigvecs) q_xyzw = rot.as_quat() q_update = jnp.concatenate([q_xyzw[3:4], q_xyzw[:3]]) return total_mass, com, eigvals, q_update else: # 2D Case: Use Covariance Matrix to determine orientation Cov = jnp.einsum("n,ni,nj->ij", rho, r_prime, r_prime) * vol_per_sample eigvals_cov, eigvecs = jnp.linalg.eigh(Cov) # Convert 2D rotation matrix (eigvecs) to angle theta # Column 0 is the new X-axis theta = jnp.arctan2(eigvecs[1, 0], eigvecs[0, 0]) # Convert angle to Quaternion (rotation around Z) half_theta = theta / 2.0 q_update = jnp.array([jnp.cos(half_theta), 0.0, 0.0, jnp.sin(half_theta)]) # Scalar polar moment of inertia I_scalar = jnp.sum(rho * r_sq) * vol_per_sample I_res = I_scalar.reshape(1) return total_mass, com, I_res, q_update tm, cm, it, qt = jax.vmap(solve_monte_carlo)(clump_ids) is_clump = counts[state.ID] > 1 new_mass = jnp.where(is_clump, tm[state.ID], state.mass) new_com = jnp.where(is_clump[:, None], cm[state.ID], state.pos_c) new_inertia = jnp.where(is_clump[:, None], it[state.ID], state.inertia) new_q_arr = jnp.where( is_clump[:, None], qt[state.ID], jnp.concatenate([state.q.w, state.q.xyz], axis=-1), ) state.mass = new_mass state.pos_c = new_com state.inertia = new_inertia state.q = Quaternion(new_q_arr[..., 0:1], new_q_arr[..., 1:]) state.pos_p = state.q.rotate_back(state.q, pos - state.pos_c) return state