# Source code for jaxdem.utils.dynamicsRoutines

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project - https://github.com/cdelv/JaxDEM
"""
Jit-compiled routines for controlling temperature and density via basic rescaling.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
from dataclasses import replace
from functools import partial
from typing import TYPE_CHECKING, Callable, Optional, Tuple, cast

from .thermal import compute_temperature, scale_to_temperature, set_temperature
from .packingUtils import compute_packing_fraction, scale_to_packing_fraction

if TYPE_CHECKING:  # pragma: no cover
    from ..state import State
    from ..system import System

# Schedule signature:
#   k: 1..K (rescale event index)
#   K: total number of rescale events over the protocol
#   start: initial value
#   target: final value
# returns: setpoint to apply at event k
ScheduleFn = Callable[
    [jax.Array, jax.Array, jax.Array | float, jax.Array | float], jax.Array
]


def _linear_schedule(
    k: jax.Array, K: jax.Array, start: jax.Array | float, target: jax.Array | float
) -> jax.Array:
    Kf = jnp.maximum(K.astype(float), 1.0)
    alpha = k.astype(float) / Kf
    start_arr = jnp.asarray(start, dtype=float)
    target_arr = jnp.asarray(target, dtype=float)
    return start_arr + alpha * (target_arr - start_arr)


def _resolve_target(
    start: jax.Array | float,
    *,
    target: Optional[float],
    delta: Optional[float],
) -> Tuple[bool, jax.Array]:
    """Turn an optional absolute target / relative delta into a final setpoint.

    Returns ``(enabled, final_target)``:
      * neither given  -> ``(False, start)`` (control disabled);
      * ``target`` given -> ``(True, target)``;
      * ``delta`` given  -> ``(True, start + delta)``.

    Raises:
        ValueError: if both ``target`` and ``delta`` are given. This is a
            Python-side check, so under ``jit`` it fires at trace time.
    """
    has_target = target is not None
    has_delta = delta is not None

    if has_target and has_delta:
        raise ValueError("Provide either target=... or delta=..., not both.")
    if has_target:
        return True, jnp.asarray(target, dtype=float)

    base = jnp.asarray(start, dtype=float)
    if has_delta:
        return True, base + jnp.asarray(delta, dtype=float)
    return False, base


def _zero_velocities(state: State, can_rotate: bool) -> State:
    """Return a copy of ``state`` with linear and angular velocities zeroed.

    ``can_rotate`` does not affect the result: angular velocities are always
    set to zero. The parameter is kept so call sites can pass the same
    rotation flag they use for the other thermal helpers.
    """
    # dataclasses.replace builds a new state instead of mutating the input.
    return replace(
        state,
        vel=jnp.zeros_like(state.vel),
        angVel=jnp.zeros_like(state.angVel),
    )


def _maybe_init_temperature_if_zero(
    state: State,
    *,
    enabled: bool,
    start_setpoint: jax.Array,
    can_rotate: bool,
    subtract_drift: bool,
    k_B: float,
    seed: int,
) -> State:
    """Give a zero-temperature system initial velocities before control starts.

    When ``enabled`` is False the state is returned untouched. Otherwise the
    current temperature is measured, and:

    * if it is positive, the state is returned unchanged;
    * if it is <= 0 and ``start_setpoint`` <= 0, velocities are zeroed;
    * if it is <= 0 and ``start_setpoint`` > 0, velocities are set with
      ``set_temperature`` at ``start_setpoint`` using ``seed``.
    """
    if not enabled:
        return state

    current_T = compute_temperature(state, can_rotate, subtract_drift, k_B)

    def _keep(_: None) -> State:
        return state

    def _zeroed(_: None) -> State:
        return _zero_velocities(state, can_rotate)

    def _thermalized(_: None) -> State:
        return set_temperature(
            state,
            cast(float, start_setpoint),
            can_rotate,
            subtract_drift,
            seed=seed,  # must not be None inside jit
            k_B=k_B,
        )

    def _initialize(_: None) -> State:
        # A zero setpoint needs no sampling: zero the velocities deterministically.
        return jax.lax.cond(
            start_setpoint <= 0.0, _zeroed, _thermalized, operand=None
        )

    return jax.lax.cond(current_T <= 0.0, _initialize, _keep, operand=None)


def _controlled_steps_chunk(
    state: State,
    system: System,
    *,
    n: int,
    unroll: int,
    # protocol context (to keep schedules consistent across chunking / rollout)
    step0: jax.Array,  # step_count at protocol start
    total_n: int,  # total integration steps in the whole protocol (not just this chunk)
    rescale_every: int,
    # temperature control config
    temp_enabled: bool,
    T_start: jax.Array | float,
    T_target: jax.Array | float,
    temperature_schedule: Optional[ScheduleFn],
    can_rotate: bool,
    subtract_drift: bool,
    k_B: float,
    # density control config
    dens_enabled: bool,
    pf_start: jax.Array | float,
    pf_target: jax.Array | float,
    density_schedule: Optional[ScheduleFn],
    pf_min: float,
) -> Tuple[State, System]:
    """Run ``n`` integration steps, applying rescaling hooks on schedule.

    After each step, if ``system.step_count`` is a multiple of
    ``rescale_every``, the enabled hooks run in this order:

    1. Temperature: velocities are rescaled to the scheduled setpoint. The
       setpoint is clamped at 0, and a setpoint of 0 zeroes the velocities
       instead of calling ``scale_to_temperature``.
    2. Density: the system is rescaled to the scheduled packing fraction.
       Setpoints <= 0 are replaced by ``pf_min``.

    Schedules are evaluated against the whole protocol, not this chunk.
    ``K`` counts the multiples of ``rescale_every`` in
    ``(step0, step0 + total_n]``, and the event index ``k`` is derived from
    the absolute step count. Splitting a protocol into several chunks that
    share ``step0``/``total_n`` therefore yields the same setpoints as one
    long call.

    Args:
        state, system: Simulation state and system to advance.
        n: Number of integration steps in this chunk.
        unroll: Forwarded to ``jax.lax.scan``.
        step0: ``system.step_count`` at the start of the whole protocol.
        total_n: Integration steps in the whole protocol.
        rescale_every: Rescale period in integration steps. Values <= 0
            disable rescaling and delegate to ``system.step``.
        temp_enabled, dens_enabled: Python bools. A disabled hook is not
            traced at all.
        T_start, T_target, pf_start, pf_target: Schedule end points.
        temperature_schedule, density_schedule: Optional ``ScheduleFn``
            overrides; ``_linear_schedule`` is used when None.
        can_rotate, subtract_drift, k_B: Forwarded to the thermal helpers.
        pf_min: Replacement for non-positive packing-fraction setpoints.

    Returns:
        The ``(state, system)`` pair after ``n`` steps.
    """
    f = rescale_every
    if f <= 0:
        # No rescaling at all; just delegate to the plain stepping path.
        return system.step(state, system, n=n)

    schedule_T = (
        _linear_schedule if temperature_schedule is None else temperature_schedule
    )
    schedule_pf = _linear_schedule if density_schedule is None else density_schedule

    # Event bookkeeping is anchored at the protocol start, so chunks agree.
    base_event = step0 // f
    # Total number of rescale events over the *entire* protocol.
    K = ((step0 + total_n) // f) - base_event

    def rescale_temperature(st: State, k: jax.Array) -> State:
        T_set = jnp.maximum(schedule_T(k, K, T_start, T_target), 0.0)
        # A zero setpoint would make scale_to_temperature compute 0/0;
        # zero the velocities directly instead.
        return jax.lax.cond(
            T_set <= 0.0,
            lambda: _zero_velocities(st, can_rotate),
            lambda: scale_to_temperature(
                st, cast(float, T_set), can_rotate, subtract_drift, k_B=k_B
            ),
        )

    def rescale_density(
        st: State, sys: System, k: jax.Array
    ) -> Tuple[State, System]:
        pf_set = schedule_pf(k, K, pf_start, pf_target)
        # Non-positive packing fractions are replaced by pf_min to avoid NaNs.
        pf_set = jnp.where(pf_set <= 0.0, pf_min, pf_set)
        return scale_to_packing_fraction(st, sys, pf_set)

    def apply_rescale(carry: Tuple[State, System]) -> Tuple[State, System]:
        st, sys = carry
        # Rescale-event index: 1 at the first event after step0, K at the last.
        k = (sys.step_count // f) - base_event
        # temp_enabled / dens_enabled are Python bools: resolve them at trace
        # time so disabled hooks are never traced or compiled.
        if temp_enabled:
            st = rescale_temperature(st, k)
        if dens_enabled:
            st, sys = rescale_density(st, sys, k)
        return st, sys

    any_control = temp_enabled or dens_enabled

    @partial(jax.named_call, name="dynamicsRoutines._controlled_steps_chunk.body")
    def body(carry: Tuple[State, System], _: None) -> Tuple[Tuple[State, System], None]:
        st, sys = carry

        # --- intended to match the System._steps body; keep the two in sync ---
        sys = replace(sys, time=sys.time + sys.dt, step_count=sys.step_count + 1)

        st, sys = sys.domain.apply(st, sys)
        st, sys = sys.linear_integrator.step_before_force(st, sys)
        st, sys = sys.rotation_integrator.step_before_force(st, sys)

        st, sys = sys.collider.compute_force(st, sys)
        st, sys = sys.force_manager.apply(st, sys)

        st, sys = sys.linear_integrator.step_after_force(st, sys)
        st, sys = sys.rotation_integrator.step_after_force(st, sys)
        # -----------------------------------------------------------------

        if any_control:
            st, sys = jax.lax.cond(
                (sys.step_count % f) == 0, apply_rescale, lambda c: c, (st, sys)
            )
        return (st, sys), None

    (state, system), _ = jax.lax.scan(
        body, (state, system), xs=None, length=n, unroll=unroll
    )
    return state, system


@partial(
    jax.jit,
    static_argnames=(
        "n",
        "unroll",
        "rescale_every",
        "temperature_schedule",
        "density_schedule",
        "can_rotate",
        "subtract_drift",
        "pf_min",
        "init_temp_seed",
    ),
    donate_argnames=("state", "system"),
)
def control_nvt_density(
    state: State,
    system: System,
    *,
    n: int,
    rescale_every: int,
    # temperature control: choose one of (temperature_target, temperature_delta) or neither
    temperature_target: Optional[float] = None,
    temperature_delta: Optional[float] = None,
    # density control: choose one of (packing_fraction_target, packing_fraction_delta) or neither
    packing_fraction_target: Optional[float] = None,
    packing_fraction_delta: Optional[float] = None,
    # dynamics/thermo params
    can_rotate: bool = True,
    subtract_drift: bool = True,
    k_B: float = 1.0,
    # schedule overrides (must be JIT-static callables)
    temperature_schedule: Optional[ScheduleFn] = None,
    density_schedule: Optional[ScheduleFn] = None,
    # safety controls
    pf_min: float = 1e-12,
    init_temp_seed: int = 0,
    unroll: int = 2,
) -> Tuple[State, System]:
    """Advance ``n`` integration steps while steering temperature and/or density.

    Whenever ``system.step_count`` is divisible by ``rescale_every`` (counted
    in integration steps), velocities are rescaled to the current temperature
    setpoint and/or the system is rescaled to the current packing-fraction
    setpoint. Setpoints move from the starting value to the final value along
    ``temperature_schedule`` / ``density_schedule`` (linear when None) over
    all rescale events of this call.

    Each quantity is controlled by giving either an absolute ``*_target`` or
    a relative ``*_delta``. Giving neither leaves it uncontrolled; giving both
    raises ``ValueError`` at trace time.

    If temperature control is enabled and the initial temperature is <= 0,
    the velocities are first initialized at the final temperature target,
    seeded with ``init_temp_seed``, or zeroed if that target is <= 0.

    Notes
    -----
    - With ``rescale_every <= 0`` the step loop performs no rescaling. The
      zero-temperature initialization above still applies.
    - ``temperature_schedule`` / ``density_schedule`` are static jit
      arguments and must be hashable callables.
    - ``state`` and ``system`` are donated to the jitted computation.
    """
    initial_step = system.step_count

    # Measure the starting point and resolve the requested end points.
    T_now = compute_temperature(state, can_rotate, subtract_drift, k_B)
    pf_now = compute_packing_fraction(state, system)
    temp_enabled, T_final = _resolve_target(
        T_now, target=temperature_target, delta=temperature_delta
    )
    dens_enabled, pf_final = _resolve_target(
        pf_now, target=packing_fraction_target, delta=packing_fraction_delta
    )

    # A warm system keeps its current temperature as the initialization
    # setpoint. A cold one (T <= 0) cannot be rescaled (0/0), so it is seeded
    # at the final target instead.
    init_setpoint = jax.lax.cond(
        T_now > 0.0, lambda _: T_now, lambda _: T_final, operand=None
    )
    state = _maybe_init_temperature_if_zero(
        state,
        enabled=temp_enabled,
        start_setpoint=init_setpoint,
        can_rotate=can_rotate,
        subtract_drift=subtract_drift,
        k_B=k_B,
        seed=init_temp_seed,
    )

    # Re-measure: the schedules must start from the possibly re-initialized state.
    T_begin = compute_temperature(state, can_rotate, subtract_drift, k_B)
    pf_begin = compute_packing_fraction(state, system)

    new_state, new_system = _controlled_steps_chunk(
        state,
        system,
        n=n,
        unroll=unroll,
        step0=initial_step,
        total_n=n,
        rescale_every=rescale_every,
        temp_enabled=temp_enabled,
        T_start=T_begin,
        T_target=T_final,
        temperature_schedule=temperature_schedule,
        can_rotate=can_rotate,
        subtract_drift=subtract_drift,
        k_B=k_B,
        dens_enabled=dens_enabled,
        pf_start=pf_begin,
        pf_target=pf_final,
        density_schedule=density_schedule,
        pf_min=pf_min,
    )
    return new_state, new_system


@partial(
    jax.jit,
    static_argnames=(
        "n",
        "stride",
        "unroll",
        "rescale_every",
        "temperature_schedule",
        "density_schedule",
        "can_rotate",
        "subtract_drift",
        "pf_min",
        "init_temp_seed",
    ),
    donate_argnames=("state", "system"),
)
def control_nvt_density_rollout(
    state: State,
    system: System,
    *,
    n: int,  # number of saved frames
    stride: int = 1,  # integration steps between frames (like System.trajectory_rollout)
    rescale_every: int = 1,
    temperature_target: Optional[float] = None,
    temperature_delta: Optional[float] = None,
    packing_fraction_target: Optional[float] = None,
    packing_fraction_delta: Optional[float] = None,
    can_rotate: bool = True,
    subtract_drift: bool = True,
    k_B: float = 1.0,
    temperature_schedule: Optional[ScheduleFn] = None,
    density_schedule: Optional[ScheduleFn] = None,
    pf_min: float = 1e-12,
    init_temp_seed: int = 0,
    unroll: int = 2,
) -> Tuple[State, System, Tuple[State, System]]:
    """Run ``n`` frames of ``stride`` steps each under control, recording every frame.

    This is the rollout counterpart of ``control_nvt_density`` and is modeled
    on ``System.trajectory_rollout``. Targets, deltas, schedules and safety
    options mean the same as there.

    The schedules span the whole rollout: ``n * stride`` integration steps
    counted from the initial ``system.step_count``. Setpoints therefore
    progress continuously across frames instead of restarting each frame.

    Returns:
        ``(state, system, trajectory)``, where ``trajectory`` is the
        ``(state, system)`` pair after each frame, stacked by
        ``jax.lax.scan`` along a leading axis of length ``n``.
    """
    origin_step = system.step_count

    T_measured = compute_temperature(state, can_rotate, subtract_drift, k_B)
    pf_measured = compute_packing_fraction(state, system)
    temp_enabled, T_goal = _resolve_target(
        T_measured, target=temperature_target, delta=temperature_delta
    )
    dens_enabled, pf_goal = _resolve_target(
        pf_measured, target=packing_fraction_target, delta=packing_fraction_delta
    )

    # Keep the current temperature when positive; otherwise seed at the goal.
    init_setpoint = jax.lax.cond(
        T_measured > 0.0, lambda _: T_measured, lambda _: T_goal, operand=None
    )
    state = _maybe_init_temperature_if_zero(
        state,
        enabled=temp_enabled,
        start_setpoint=init_setpoint,
        can_rotate=can_rotate,
        subtract_drift=subtract_drift,
        k_B=k_B,
        seed=init_temp_seed,
    )

    # Schedules start from the (possibly re-initialized) state.
    T_begin = compute_temperature(state, can_rotate, subtract_drift, k_B)
    pf_begin = compute_packing_fraction(state, system)

    # Every frame shares the same protocol context (step0 / total_n), so
    # the schedule is globally consistent across frames.
    # NOTE(review): when batch_size > 1 only the frame function is vmapped.
    # origin_step, T_begin, pf_begin, etc. are computed on the full batched
    # state outside the vmap and captured unbatched. Confirm that
    # compute_temperature / _maybe_init_temperature_if_zero behave as
    # intended for batched states.
    advance_frame = partial(
        _controlled_steps_chunk,
        n=stride,
        unroll=unroll,
        step0=origin_step,
        total_n=n * stride,
        rescale_every=rescale_every,
        temp_enabled=temp_enabled,
        T_start=T_begin,
        T_target=T_goal,
        temperature_schedule=temperature_schedule,
        can_rotate=can_rotate,
        subtract_drift=subtract_drift,
        k_B=k_B,
        dens_enabled=dens_enabled,
        pf_start=pf_begin,
        pf_target=pf_goal,
        density_schedule=density_schedule,
        pf_min=pf_min,
    )

    def frame_step(
        carry: Tuple[State, System], _: None
    ) -> Tuple[Tuple[State, System], Tuple[State, System]]:
        frame_state, frame_system = advance_frame(*carry)
        return (frame_state, frame_system), (frame_state, frame_system)

    if state.batch_size > 1:
        frame_step = jax.vmap(frame_step, in_axes=(0, None))

    (state, system), traj = jax.lax.scan(
        frame_step, (state, system), xs=None, length=n
    )
    return state, system, traj