# Source code for jaxdem.utils.packingUtils

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project - https://github.com/cdelv/JaxDEM
"""
Utility functions for calculating and changing the packing fraction.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
from dataclasses import replace
from functools import partial

from typing import TYPE_CHECKING, Tuple

from ..minimizers import minimize
from ..colliders import NeighborList

if TYPE_CHECKING:
    from ..state import State
    from ..system import System


@jax.jit
def compute_particle_volume(
    state: State,
) -> jax.Array:
    """Return the total particle volume.

    Entries of ``state.volume`` are grouped by ``state.clump_ID``; the largest
    value in each group is taken as that group's volume, and the per-group
    values are summed. Entries that share a clump ID therefore contribute once
    rather than once per entry.

    Parameters
    ----------
    state : State
        Supplies ``volume``, ``clump_ID`` and ``N`` (passed as ``num_segments``).

    Returns
    -------
    jax.Array
        Scalar sum of the per-clump volumes.
    """
    # NOTE: This is not the proper instantaneous volume for DPs.
    per_clump_volume = jax.ops.segment_max(
        state.volume, state.clump_ID, num_segments=state.N
    )
    # IDs that own no entries come back as the reduction's fill value
    # (negative for a max-reduction); clamping at zero makes them contribute
    # nothing to the total.
    return jnp.maximum(per_clump_volume, 0.0).sum()
@jax.jit
def compute_packing_fraction(state: State, system: System) -> jax.Array:
    """Return the packing fraction of the system.

    This is the total particle volume (from ``compute_particle_volume``)
    divided by the product of the entries of ``system.domain.box_size``.

    Parameters
    ----------
    state : State
        Particle state forwarded to ``compute_particle_volume``.
    system : System
        Supplies ``domain.box_size``.

    Returns
    -------
    jax.Array
        Scalar ratio of particle volume to box volume.
    """
    # NOTE: this assumes that the domain anchor is 0
    domain_volume = jnp.prod(system.domain.box_size)
    particle_volume = compute_particle_volume(state)
    return particle_volume / domain_volume
@jax.jit
def scale_to_packing_fraction(
    state: State, system: System, new_packing_fraction: float
) -> Tuple[State, System]:
    """Resize the domain to a cubic box that yields the requested packing fraction.

    The new box edge length is ``(particle_volume / new_packing_fraction) **
    (1 / state.dim)`` and is applied to every axis. Positions are moved by
    translating each group of entries that share a ``deformable_ID`` as a
    whole: the group's mean position is rescaled and the resulting offset is
    added to every member, so the extent of each group is left unchanged.

    The rescaling factor is computed per axis as ``new_box_size /
    old_box_size``. For a box whose edges are all equal this matches a single
    scalar factor; for a box with unequal edges it maps every axis onto the
    new box instead of using the first axis' ratio everywhere.

    Parameters
    ----------
    state : State
        Supplies ``pos_c``, ``deformable_ID``, ``N`` and ``dim``.
    system : System
        Supplies ``domain.box_size`` and ``collider``.
    new_packing_fraction : float
        Target ratio of particle volume to box volume.

    Returns
    -------
    Tuple[State, System]
        The state with shifted ``pos_c`` and the system with the resized
        domain. When the collider is a ``NeighborList`` its ``n_build_times``
        is reset to 0.
    """
    # NOTE: this assumes that the domain anchor is 0
    old_box_size = system.domain.box_size
    new_box_length = (compute_particle_volume(state) / new_packing_fraction) ** (
        1 / state.dim
    )
    new_box_size = jnp.ones_like(old_box_size) * new_box_length
    new_domain = replace(system.domain, box_size=new_box_size)

    # One factor per axis, so that [0, old_box_size[i]] maps onto
    # [0, new_box_length] along every axis, not only along axis 0.
    scale_factor = new_box_size / old_box_size

    # For spheres and clumps, we could just rescale the positions via
    # state.pos_c * scale_factor. For DPs, we need to scale the com positions.
    # Both behaviors are generalized by scaling the DP com positions, finding
    # the offset before and after the scaling, and applying that offset to
    # state.pos_c. This preserves the size of the DPs, clumps, and spheres.
    total_pos = jax.ops.segment_sum(
        state.pos_c, state.deformable_ID, num_segments=state.N
    )
    dp_counts = jax.ops.segment_sum(
        jnp.ones((state.N,), dtype=state.pos_c.dtype),
        state.deformable_ID,
        num_segments=state.N,
    )
    # Guard the division for IDs that own no entries (count of zero).
    dp_com = total_pos / jnp.maximum(dp_counts[:, None], 1.0)
    offset = dp_com * scale_factor - dp_com

    # Broadcast each group's offset back to its members and apply the shift.
    new_state = replace(state, pos_c=state.pos_c + offset[state.deformable_ID])
    new_system = replace(system, domain=new_domain)

    # Force a rebuild of the neighbor list if one is in use.
    if isinstance(new_system.collider, NeighborList):
        new_system = replace(
            new_system,
            collider=replace(new_system.collider, n_build_times=0),
        )
    return new_state, new_system