# Source code for jaxdem.utils.jamming

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Jamming routines.
https://doi.org/10.1103/PhysRevE.68.011306
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
from dataclasses import replace
from functools import partial

from typing import TYPE_CHECKING, Tuple

from ..minimizers import minimize

if TYPE_CHECKING:
    from ..state import State
    from ..system import System

def _rescale_to_packing_fraction(state, system, packing_fraction, dim):
    """Return ``(state, system)`` affinely rescaled to ``packing_fraction``.

    The packing fraction is ``sum(state.volume) / prod(box_size)``. Every box
    side and every particle position ``pos_c`` is multiplied by the same factor
    ``(phi_current / packing_fraction) ** (1 / dim)``. This preserves the aspect
    ratio of non-cubic boxes. For a cubic box the new side length is
    ``(sum(volume) / packing_fraction) ** (1 / dim)``.
    """
    box_size = system.domain.box_size
    current_pf = jnp.sum(state.volume) / jnp.prod(box_size)
    scale_factor = (current_pf / packing_fraction) ** (1 / dim)
    new_domain = replace(system.domain, box_size=box_size * scale_factor)
    new_system = replace(system, domain=new_domain)
    new_state = replace(state, pos_c=state.pos_c * scale_factor)
    return new_state, new_system


@partial(jax.jit, static_argnames=["n_minimization_steps", "n_jamming_steps"])
def bisection_jam(
    state: State,
    system: System,
    n_minimization_steps: int = 1000000,
    pe_tol: float = 1e-16,
    pe_diff_tol: float = 1e-16,
    n_jamming_steps: int = 10000,
    packing_fraction_tolerance: float = 1e-10,
    packing_fraction_increment: float = 1e-3,
) -> Tuple[State, System, jax.Array, jax.Array]:
    """
    Find the nearest jammed state for a given state and system.

    Uses a bisection search on the packing fraction with state reversion.
    Each iteration minimizes the current configuration.

    * If the minimized potential energy exceeds ``pe_tol``, the state is
      considered jammed. Its packing fraction becomes the upper bound, the
      configuration is reverted to the last unjammed one, and the next packing
      fraction is the midpoint of the bracket.
    * Otherwise the configuration is stored as the last unjammed one and its
      packing fraction becomes the lower bound. The next packing fraction is
      the bracket midpoint if an upper bound is known, or the current value
      plus ``packing_fraction_increment`` if not.

    The box and positions are then affinely rescaled to the next packing
    fraction. The search stops once an upper bound exists and
    ``|phi_high / phi_low - 1| < packing_fraction_tolerance``, or after
    ``n_jamming_steps`` iterations.

    Parameters
    ----------
    state : State
        The state to jam.
    system : System
        The system to jam.
    n_minimization_steps : int, optional
        Maximum number of steps per minimization. Should be large; typically 1e6.
    pe_tol : float, optional
        Potential-energy threshold above which a minimized state counts as
        jammed. Should be very small; typically 1e-16.
    pe_diff_tol : float, optional
        Tolerance for the change in potential energy across subsequent
        minimization steps. Should be very small; typically 1e-16.
    n_jamming_steps : int, optional
        Maximum number of iterations of the jamming (increment/bisection) loop.
    packing_fraction_tolerance : float, optional
        Relative tolerance on the packing-fraction bracket used to declare
        convergence. Typically 1e-10.
    packing_fraction_increment : float, optional
        Packing-fraction step used while no jammed state has been found.
        Typically 1e-3. Larger increments are faster in the unjammed region
        but make minimizing the earliest detected jammed states much slower.

    Returns
    -------
    Tuple[State, System, jax.Array, jax.Array]
        ``(state, system, packing_fraction, potential_energy)``.

        * ``state`` and ``system`` are the last unjammed, energy-minimized
          configuration. If the initial state is already jammed, they are the
          minimized initial configuration.
        * ``packing_fraction`` is the packing fraction chosen for the next
          (not executed) iteration. On convergence it lies within the final
          bracket.
        * ``potential_energy`` is the value reported by the most recent
          minimization, which may be the jammed probe rather than the returned
          state.
    """
    # The search cannot proceed if the initial state is already jammed.
    state, system, n_steps, final_pe = minimize(
        state,
        system,
        max_steps=n_minimization_steps,
        pe_tol=pe_tol,
        pe_diff_tol=pe_diff_tol,
        initialize=True,
    )
    is_initially_jammed = final_pe > pe_tol

    def print_warning():
        jax.debug.print(
            "Warning: Initial state is already jammed (PE={pe} > tol={tol}). Skipping.",
            pe=final_pe,
            tol=pe_tol,
        )
        return None

    jax.lax.cond(is_initially_jammed, print_warning, lambda: None)
    jax.debug.print("Initial minimization took {n_steps} steps.", n_steps=n_steps)

    dim = state.dim
    initial_packing_fraction = jnp.sum(state.volume) / jnp.prod(system.domain.box_size)

    init_carry = (
        0,  # iteration
        is_initially_jammed,  # done flag (an initially jammed state skips the loop)
        state, system,  # current state/system
        state, system,  # last unjammed state/system
        initial_packing_fraction,  # current packing fraction
        initial_packing_fraction,  # low packing fraction
        -1.0,  # high packing fraction (-1.0 == no jammed state found yet)
        final_pe,  # potential energy of the most recent minimization
    )

    def cond_fun(carry):
        iteration, done = carry[0], carry[1]
        return (iteration < n_jamming_steps) & (~done)

    def body_fun(carry):
        (i, _, state, system, last_state, last_system, pf, pf_low, pf_high, _) = carry

        # Minimize at the current packing fraction.
        state, system, n_steps, final_pe = minimize(
            state,
            system,
            max_steps=n_minimization_steps,
            pe_tol=pe_tol,
            pe_diff_tol=pe_diff_tol,
            initialize=True,
        )
        is_jammed = final_pe > pe_tol

        def jammed_branch(_):
            # Jammed: pf becomes the upper bound; revert to the last unjammed
            # configuration and bisect.
            new_pf_high = pf
            new_pf = (new_pf_high + pf_low) / 2.0
            return new_pf, pf_low, new_pf_high, last_state, last_system, last_state, last_system

        def unjammed_branch(_):
            # Unjammed: remember this configuration and make pf the lower bound.
            # Bisect if an upper bound is known (pf_high > 0), else keep
            # increasing the packing fraction.
            new_pf_low = pf
            new_pf = jnp.where(
                pf_high > 0,
                (pf_high + new_pf_low) / 2.0,
                new_pf_low + packing_fraction_increment,
            )
            return new_pf, new_pf_low, pf_high, state, system, state, system

        (
            new_pf,
            new_pf_low,
            new_pf_high,
            new_state,
            new_system,
            new_last_state,
            new_last_system,
        ) = jax.lax.cond(is_jammed, jammed_branch, unjammed_branch, operand=None)

        # Check if the packing-fraction bracket has converged, and print.
        ratio = new_pf_high / new_pf_low
        converged = (jnp.abs(ratio - 1.0) < packing_fraction_tolerance) & (new_pf_high > 0)
        jax.debug.print(
            "Step: {i} - phi={pf}, PE={pe} after {n_steps} steps",
            i=i + 1,
            pf=pf,
            pe=final_pe,
            n_steps=n_steps,
        )

        # Affinely rescale the box and positions to the next packing fraction.
        next_state, next_system = _rescale_to_packing_fraction(
            new_state, new_system, new_pf, dim
        )
        return (
            i + 1,
            converged,
            next_state,
            next_system,
            new_last_state,
            new_last_system,
            new_pf,
            new_pf_low,
            new_pf_high,
            final_pe,
        )

    final_carry = jax.lax.while_loop(cond_fun, body_fun, init_carry)
    (_, _, _, _, last_state, last_system, final_pf, _, _, final_pe) = final_carry
    return last_state, last_system, final_pf, final_pe