# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""FIRE energy minimizer.
Reference: https://doi.org/10.1103/PhysRevLett.97.170201
"""
from __future__ import annotations
import dataclasses
import jax
import jax.numpy as jnp
from dataclasses import dataclass, replace
from functools import partial
from typing import TYPE_CHECKING, Tuple, cast
from . import LinearMinimizer, RotationMinimizer
from ..integrators import LinearIntegrator, RotationIntegrator
from ..integrators.velocity_verlet_spiral import omega_dot
from ..utils.quaternion import Quaternion
if TYPE_CHECKING: # pragma: no cover
from ..state import State
from ..system import System
@partial(jax.jit, inline=True)
@partial(jax.named_call, name="fire._control_update")
def _fire_control_update(
*,
dt: jax.Array,
alpha: jax.Array,
N_good: jax.Array,
N_bad: jax.Array,
dt_min: jax.Array,
dt_max: jax.Array,
alpha_init: jax.Array,
f_inc: jax.Array,
f_dec: jax.Array,
f_alpha: jax.Array,
N_min: jax.Array,
N_bad_max: jax.Array,
mask_free: jax.Array,
power: jax.Array,
) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
"""Shared FIRE control-law update.
Returns (dt, alpha, N_good, N_bad, dt_reverse, velocity_scale).
Notes
-----
- This is shared between linear and rotational FIRE so both can use identical
dt/alpha/counter update logic.
- `velocity_scale` is per-particle and is used to zero (or keep) velocities.
"""
def _active_branch(carry):
dt, alpha, N_good, N_bad = carry
def downhill(_):
N_good2 = N_good + 1
N_bad2 = jnp.zeros_like(N_bad)
dt2 = jnp.where(N_good2 > N_min, jnp.minimum(dt * f_inc, dt_max), dt)
alpha2 = jnp.where(N_good2 > N_min, alpha * f_alpha, alpha)
dt_reverse2 = jnp.array(0.0, dtype=dt.dtype)
velocity_scale2 = mask_free
return dt2, alpha2, N_good2, N_bad2, dt_reverse2, velocity_scale2
def uphill(_):
N_good2 = jnp.zeros_like(N_good)
N_bad2 = N_bad + 1
dt_candidate = jnp.maximum(dt * f_dec, dt_min)
alpha2 = alpha_init
dt_reverse2 = -dt_candidate
velocity_scale2 = jnp.zeros_like(mask_free)
done = N_bad2 > N_bad_max
dt2 = jnp.where(done, 0.0, dt_candidate)
return dt2, alpha2, N_good2, N_bad2, dt_reverse2, velocity_scale2
return jax.lax.cond(power > 0.0, downhill, uphill, operand=None)
def _inactive_branch(carry):
dt, alpha, N_good, N_bad = carry
dt_reverse2 = jnp.array(0.0, dtype=dt.dtype)
velocity_scale2 = jnp.zeros_like(mask_free)
alpha2 = jnp.zeros_like(alpha)
return dt, alpha2, N_good, N_bad, dt_reverse2, velocity_scale2
dt, alpha, N_good, N_bad, dt_reverse, velocity_scale = jax.lax.cond(
dt != 0.0,
_active_branch,
_inactive_branch,
operand=(dt, alpha, N_good, N_bad),
)
return dt, alpha, N_good, N_bad, dt_reverse, velocity_scale
[docs]
@LinearMinimizer.register("linearfire")
@LinearIntegrator.register("linearfire")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class LinearFIRE(LinearMinimizer):
    """FIRE energy minimizer for linear DOFs.

    Notes
    -----
    - Adaptive FIRE state (``dt``, ``alpha``, counters, etc.) lives on this
      integrator dataclass and is updated *functionally* via :class:`System`.
    - No FIRE-specific fields are stored on :class:`System` or :class:`State`.
    - Each step is a FIRE 2.0 update wrapped around velocity Verlet:
      1. the control-law update;
      2. a half-step position correction after an uphill step;
      3. a half kick;
      4. the FIRE velocity mixing;
      5. a full position drift;
      6. in :meth:`step_after_force`, the second half kick.
    - If a :class:`RotationFIRE` rotation integrator is present and both sides
      allow it (``attempt_couple``), this integrator is the coupling *master*.
      It evaluates the FIRE power over linear and rotational DOFs and copies
      its updated control state to the rotational partner every step.
    """

    # Hyperparameters (as JAX arrays)
    alpha_init: jax.Array  # initial mixing factor
    f_inc: jax.Array  # dt increase factor
    f_dec: jax.Array  # dt decrease factor
    f_alpha: jax.Array  # mixing factor decrease factor
    N_min: jax.Array  # minimum number of downhill steps before increasing dt
    N_bad_max: jax.Array  # maximum number of uphill steps before stopping
    dt_max_scale: jax.Array  # maximum dt scale relative to System.dt
    dt_min_scale: jax.Array  # minimum dt scale relative to System.dt

    # Adaptive state (updated every step, also as JAX arrays)
    dt: jax.Array  # current FIRE time step; 0 once the minimizer has stopped
    dt_min: jax.Array  # lower bound for dt (System.dt * dt_min_scale)
    dt_max: jax.Array  # upper bound for dt (System.dt * dt_max_scale)
    alpha: jax.Array  # current mixing factor
    N_good: jax.Array  # consecutive downhill (power > 0) steps
    N_bad: jax.Array  # consecutive uphill (power <= 0) steps

    # FIRE coupling flags (JAX bool scalars; used via lax.cond/where inside jit)
    attempt_couple: jax.Array
    coupled: jax.Array
    is_master: jax.Array

    # Shared per-step outputs (stored so a coupled partner can consume them)
    dt_reverse: jax.Array
    velocity_scale: jax.Array

    @classmethod
    def Create(
        cls,
        alpha_init: float = 0.1,
        f_inc: float = 1.1,
        f_dec: float = 0.5,
        f_alpha: float = 0.99,
        N_min: int = 5,
        N_bad_max: int = 10,
        dt_max_scale: float = 10.0,
        dt_min_scale: float = 1e-3,
        attempt_couple: bool = True,
    ) -> "LinearFIRE":
        """Create a LinearFIRE minimizer with JAX array parameters.

        The adaptive state (``dt``, ``alpha``, counters) is zero until
        :meth:`initialize` seeds it from the :class:`System`.

        Parameters
        ----------
        alpha_init : float, optional
            Initial mixing factor. Default is 0.1.
        f_inc : float, optional
            Time step increase factor. Default is 1.1.
        f_dec : float, optional
            Time step decrease factor. Default is 0.5.
        f_alpha : float, optional
            Mixing factor decrease factor. Default is 0.99.
        N_min : int, optional
            Minimum number of downhill steps before increasing dt. Default is 5.
        N_bad_max : int, optional
            Maximum number of uphill steps before stopping. Default is 10.
        dt_max_scale : float, optional
            Maximum dt scale relative to System.dt. Default is 10.0.
        dt_min_scale : float, optional
            Minimum dt scale relative to System.dt. Default is 1e-3.
        attempt_couple : bool, optional
            Whether to couple with a :class:`RotationFIRE` rotation integrator
            during :meth:`initialize`. Coupling only happens if the partner
            also allows it. Default is True.

        Returns
        -------
        LinearFIRE
            A new minimizer instance with JAX array parameters.
        """
        return cls(
            alpha_init=jnp.array(alpha_init),
            f_inc=jnp.array(f_inc),
            f_dec=jnp.array(f_dec),
            f_alpha=jnp.array(f_alpha),
            N_min=jnp.array(N_min),
            N_bad_max=jnp.array(N_bad_max),
            dt_max_scale=jnp.array(dt_max_scale),
            dt_min_scale=jnp.array(dt_min_scale),
            # Initialize adaptive state to zero arrays
            dt=jnp.array(0.0),
            dt_min=jnp.array(0.0),
            dt_max=jnp.array(0.0),
            alpha=jnp.array(0.0),
            N_good=jnp.array(0),
            N_bad=jnp.array(0),
            # Coupling defaults (set during initialize if a partner exists)
            attempt_couple=jnp.array(attempt_couple),
            coupled=jnp.array(False),
            is_master=jnp.array(True),
            dt_reverse=jnp.array(0.0),
            velocity_scale=jnp.array(0.0),
        )

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"))
    @partial(jax.named_call, name="LinearFIRE.step_before_force")
    def step_before_force(state: "State", system: "System") -> Tuple["State", "System"]:
        """FIRE update and first half of the velocity-Verlet step.

        1. Updates the FIRE control state from the power ``P = sum(F . v)``.
           The power also includes ``sum(torque . angVel)`` when this
           integrator is the coupled master.
        2. Applies the FIRE 2.0 correction ``x += v * dt_reverse / 2``, where
           ``dt_reverse`` is ``-dt`` after an uphill step and 0 otherwise,
           then rescales velocities.
        3. Performs the first half kick and the FIRE velocity mixing.
        4. Drifts positions by a full ``dt``.
        """
        fire = system.linear_integrator
        mask_free = 1 - state.fixed
        free = mask_free[..., None]

        # FIRE power. As coupled master, use the global power over linear and
        # rotational DOFs so both integrators share one downhill/uphill signal.
        power_lin = jnp.sum(state.force * state.vel)
        power_global = power_lin + jnp.sum(state.torque * state.angVel)
        coupled_master = jnp.logical_and(fire.coupled, fire.is_master)
        power = jnp.where(coupled_master, power_global, power_lin)

        dt_min = fire.dt_min
        dt_max = fire.dt_max
        dt, alpha, N_good, N_bad, dt_reverse, velocity_scale = _fire_control_update(
            dt=fire.dt,
            alpha=fire.alpha,
            N_good=fire.N_good,
            N_bad=fire.N_bad,
            dt_min=dt_min,
            dt_max=dt_max,
            alpha_init=fire.alpha_init,
            f_inc=fire.f_inc,
            f_dec=fire.f_dec,
            f_alpha=fire.f_alpha,
            N_min=fire.N_min,
            N_bad_max=fire.N_bad_max,
            mask_free=mask_free,
            power=power,
        )

        # FIRE 2.0 correction: step back by half a step after going uphill,
        # then stop (or keep) the particles.
        state.pos_c += state.vel * free * dt_reverse / 2.0
        state.vel *= velocity_scale[..., None]

        # Velocity Verlet: first half-kick.
        accel = state.force / state.mass[..., None]
        state.vel += accel * free * dt / 2.0

        # FIRE mixing, per particle: v <- (1 - alpha) v + alpha |v| F / |F|.
        # Particles with (numerically) zero force get no force contribution.
        vel_norm = jnp.sqrt(jnp.sum(state.vel**2, axis=-1))
        force_norm = jnp.sqrt(jnp.sum(state.force**2, axis=-1))
        mix_mask = (force_norm > 1e-16) * mask_free
        mixing_ratio = vel_norm / (force_norm + 1e-16) * alpha * mix_mask
        state.vel = state.vel * (1.0 - alpha) * free + state.force * mixing_ratio[..., None]

        # Re-apply velocity scaling so stopped particles do not drift this step.
        state.vel *= velocity_scale[..., None]

        # Velocity Verlet drift over the full step: x(t + dt) = x(t) + dt * v(t + dt/2).
        # Together with the two dt/2 kicks this is a complete step of size dt,
        # matching the full-angle drift of the rotational integrator.
        state.pos_c += state.vel * free * dt

        # Write back updated FIRE state into the System integrator.
        # If coupled, also synchronize the partner integrator's control state so it can
        # consume dt/alpha/counters/dt_reverse/velocity_scale in its step_before_force.
        new_fire = replace(
            fire,
            dt=dt,
            dt_min=dt_min,
            dt_max=dt_max,
            alpha=alpha,
            N_good=N_good,
            N_bad=N_bad,
            dt_reverse=dt_reverse,
            velocity_scale=velocity_scale,
        )
        system = dataclasses.replace(system, linear_integrator=new_fire)

        if isinstance(system.rotation_integrator, RotationFIRE):
            rot_fire = cast(RotationFIRE, system.rotation_integrator)
            do_sync = jnp.logical_and(coupled_master, rot_fire.coupled)

            def _sync(sys):
                return dataclasses.replace(
                    sys,
                    rotation_integrator=replace(
                        rot_fire,
                        dt=dt,
                        dt_min=dt_min,
                        dt_max=dt_max,
                        alpha=alpha,
                        N_good=N_good,
                        N_bad=N_bad,
                        dt_reverse=dt_reverse,
                        velocity_scale=velocity_scale,
                    ),
                )

            system = jax.lax.cond(do_sync, _sync, lambda s: s, system)

        return state, system

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"))
    @partial(jax.named_call, name="LinearFIRE.step_after_force")
    def step_after_force(state: "State", system: "System") -> Tuple["State", "System"]:
        """Second velocity-Verlet half-kick using the adaptive FIRE ``dt``."""
        dt = system.linear_integrator.dt
        free = (1 - state.fixed)[..., None]
        accel = state.force / state.mass[..., None]
        state.vel += accel * free * dt / 2.0
        return state, system

    @staticmethod
    @jax.jit
    @partial(jax.named_call, name="LinearFIRE.initialize")
    def initialize(state: "State", system: "System") -> Tuple["State", "System"]:
        """Initialize FIRE state from the System and current forces.

        - Zeroes linear velocities and evaluates forces once.
        - Seeds ``dt`` from ``System.dt``, with ``dt_min``/``dt_max`` from the
          scale factors, and resets ``alpha`` and the counters.
        - If a :class:`RotationFIRE` partner exists and both sides allow it,
          couples the two integrators with this one as master. The master's
          hyperparameters and control state are copied to the partner.
        """
        fire = system.linear_integrator

        # Zero initial velocities and compute forces once.
        state.vel *= 0.0
        state, system = system.force_manager.apply(state, system)
        state, system = system.collider.compute_force(state, system)

        # Calculate the initial parameters.
        dt0 = system.dt
        mask_free = 1 - state.fixed
        fire = replace(
            fire,
            dt=dt0,
            dt_min=dt0 * fire.dt_min_scale,
            dt_max=dt0 * fire.dt_max_scale,
            alpha=fire.alpha_init,
            # Keep the counters' array dtype from Create instead of storing
            # weak-typed Python ints in the pytree.
            N_good=jnp.zeros_like(fire.N_good),
            N_bad=jnp.zeros_like(fire.N_bad),
            dt_reverse=jnp.array(0.0, dtype=dt0.dtype),
            velocity_scale=mask_free,
        )

        # Attempt to couple to rotational FIRE if present.
        if isinstance(system.rotation_integrator, RotationFIRE):
            rot_fire0 = cast(RotationFIRE, system.rotation_integrator)
            do_couple = jnp.logical_and(fire.attempt_couple, rot_fire0.attempt_couple)

            def _couple(_):
                fire2 = replace(fire, coupled=jnp.array(True), is_master=jnp.array(True))
                rot_fire2 = replace(
                    rot_fire0,
                    # Hyperparams (sync to master)
                    alpha_init=fire2.alpha_init,
                    f_inc=fire2.f_inc,
                    f_dec=fire2.f_dec,
                    f_alpha=fire2.f_alpha,
                    N_min=fire2.N_min,
                    N_bad_max=fire2.N_bad_max,
                    dt_max_scale=fire2.dt_max_scale,
                    dt_min_scale=fire2.dt_min_scale,
                    # Adaptive state
                    dt=fire2.dt,
                    dt_min=fire2.dt_min,
                    dt_max=fire2.dt_max,
                    alpha=fire2.alpha,
                    N_good=fire2.N_good,
                    N_bad=fire2.N_bad,
                    # Coupling flags / shared outputs
                    coupled=jnp.array(True),
                    is_master=jnp.array(False),
                    dt_reverse=fire2.dt_reverse,
                    velocity_scale=mask_free,
                )
                return fire2, rot_fire2

            def _no_couple(_):
                # Keep the same output PyTree types/shapes as the coupled branch.
                rot_fire2 = replace(
                    rot_fire0,
                    dt_reverse=jnp.array(0.0, dtype=dt0.dtype),
                    velocity_scale=mask_free,
                )
                return replace(fire, coupled=jnp.array(False), is_master=jnp.array(True)), rot_fire2

            fire2, rot_fire2 = jax.lax.cond(do_couple, _couple, _no_couple, operand=None)
            system = dataclasses.replace(system, linear_integrator=fire2, rotation_integrator=rot_fire2)
        else:
            fire2 = replace(fire, coupled=jnp.array(False), is_master=jnp.array(True))
            system = dataclasses.replace(system, linear_integrator=fire2)

        return state, system
[docs]
@RotationMinimizer.register("rotationfire")
@RotationIntegrator.register("rotationfire")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class RotationFIRE(RotationMinimizer):
    """FIRE energy minimizer for rotation DOFs.

    Notes
    -----
    - Adaptive FIRE state (``dt``, ``alpha``, counters, etc.) lives on this
      integrator dataclass and is updated *functionally* via :class:`System`.
    - No FIRE-specific fields are stored on :class:`System` or :class:`State`.
    - Angular velocities and torques are handled in the body frame. In 2D the
      scalar angular velocity/torque is padded to the z-component of a 3D
      vector.
    - When coupled to a :class:`LinearFIRE` master, this integrator is a
      *follower*: it skips its own control-law update and reuses the
      ``dt``/``alpha``/counters/``dt_reverse``/``velocity_scale`` that the
      master stored on it.
    """

    # Hyperparameters (as JAX arrays)
    alpha_init: jax.Array  # initial mixing factor
    f_inc: jax.Array  # dt increase factor
    f_dec: jax.Array  # dt decrease factor
    f_alpha: jax.Array  # mixing factor decrease factor
    N_min: jax.Array  # minimum number of downhill steps before increasing dt
    N_bad_max: jax.Array  # maximum number of uphill steps before stopping
    dt_max_scale: jax.Array  # maximum dt scale relative to System.dt
    dt_min_scale: jax.Array  # minimum dt scale relative to System.dt

    # Adaptive state (updated every step, also as JAX arrays)
    dt: jax.Array  # current FIRE time step; 0 once the minimizer has stopped
    dt_min: jax.Array  # lower bound for dt (System.dt * dt_min_scale)
    dt_max: jax.Array  # upper bound for dt (System.dt * dt_max_scale)
    alpha: jax.Array  # current mixing factor
    N_good: jax.Array  # consecutive downhill (power > 0) steps
    N_bad: jax.Array  # consecutive uphill (power <= 0) steps

    # FIRE coupling flags (JAX bool scalars; used via lax.cond/where inside jit)
    attempt_couple: jax.Array
    coupled: jax.Array
    is_master: jax.Array

    # Shared per-step outputs (stored so a coupled partner can consume them)
    dt_reverse: jax.Array
    velocity_scale: jax.Array

    @classmethod
    def Create(
        cls,
        alpha_init: float = 0.1,
        f_inc: float = 1.1,
        f_dec: float = 0.5,
        f_alpha: float = 0.99,
        N_min: int = 5,
        N_bad_max: int = 10,
        dt_max_scale: float = 10.0,
        dt_min_scale: float = 1e-3,
        attempt_couple: bool = True,
    ) -> "RotationFIRE":
        """Create a RotationFIRE minimizer with JAX array parameters.

        The adaptive state (``dt``, ``alpha``, counters) is zero until
        :meth:`initialize` seeds it from the :class:`System`.

        Parameters
        ----------
        alpha_init : float, optional
            Initial mixing factor. Default is 0.1.
        f_inc : float, optional
            Time step increase factor. Default is 1.1.
        f_dec : float, optional
            Time step decrease factor. Default is 0.5.
        f_alpha : float, optional
            Mixing factor decrease factor. Default is 0.99.
        N_min : int, optional
            Minimum number of downhill steps before increasing dt. Default is 5.
        N_bad_max : int, optional
            Maximum number of uphill steps before stopping. Default is 10.
        dt_max_scale : float, optional
            Maximum dt scale relative to System.dt. Default is 10.0.
        dt_min_scale : float, optional
            Minimum dt scale relative to System.dt. Default is 1e-3.
        attempt_couple : bool, optional
            Whether to couple with a :class:`LinearFIRE` linear integrator
            during :meth:`initialize`. Coupling only happens if the partner
            also allows it. When coupled, this integrator follows the linear
            master's control state. Default is True.

        Returns
        -------
        RotationFIRE
            A new minimizer instance with JAX array parameters.
        """
        return cls(
            alpha_init=jnp.array(alpha_init),
            f_inc=jnp.array(f_inc),
            f_dec=jnp.array(f_dec),
            f_alpha=jnp.array(f_alpha),
            N_min=jnp.array(N_min),
            N_bad_max=jnp.array(N_bad_max),
            dt_max_scale=jnp.array(dt_max_scale),
            dt_min_scale=jnp.array(dt_min_scale),
            # Initialize adaptive state to zero arrays
            dt=jnp.array(0.0),
            dt_min=jnp.array(0.0),
            dt_max=jnp.array(0.0),
            alpha=jnp.array(0.0),
            N_good=jnp.array(0),
            N_bad=jnp.array(0),
            # Coupling defaults (set during initialize if a partner exists)
            attempt_couple=jnp.array(attempt_couple),
            coupled=jnp.array(False),
            is_master=jnp.array(False),
            dt_reverse=jnp.array(0.0),
            velocity_scale=jnp.array(0.0),
        )

    @staticmethod
    def _rotate_orientation(q: Quaternion, angVel: jax.Array, step: jax.Array) -> Quaternion:
        """Rotate ``q`` by the angle ``step * |angVel|`` about ``angVel``.

        Parameters
        ----------
        q : Quaternion
            Current orientation.
        angVel : jax.Array
            Angular velocity in the body frame.
        step : jax.Array
            Time span of the rotation; may be negative to rotate backwards.

        Returns
        -------
        Quaternion
            ``q`` right-multiplied by the increment, then renormalized. The
            increment is ``(cos(phi/2), sin(phi/2) * n)`` with
            ``phi = step * |w|`` and ``n = w / |w|``. Zero angular velocity
            yields the identity increment.
        """
        w_norm = jnp.sqrt(jnp.sum(angVel * angVel, axis=-1, keepdims=True))
        half_angle = step * w_norm / 2
        safe_norm = jnp.where(w_norm == 0, 1.0, w_norm)
        q @= Quaternion(
            jnp.cos(half_angle),
            jnp.sin(half_angle) * angVel / safe_norm,
        )
        return q.unit(q)

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"))
    @partial(jax.named_call, name="RotationFIRE.step_before_force")
    def step_before_force(state: "State", system: "System") -> Tuple["State", "System"]:
        """FIRE update and first half of the velocity-Verlet step.

        1. Updates (or, as a coupled follower, reuses) the FIRE control state.
        2. Applies the FIRE 2.0 correction: a rotation back by half a step,
           i.e. by the angle ``dt_reverse * |w| / 2``, where ``dt_reverse`` is
           ``-dt`` after an uphill step and 0 otherwise.
        3. Performs the first angular half kick and the FIRE mixing of angular
           velocity and torque.
        4. Rotates the orientation by the full step angle ``dt * |w|``.
        """
        fire = system.rotation_integrator
        dt = fire.dt
        dt_min = fire.dt_min
        dt_max = fire.dt_max
        alpha = fire.alpha
        N_good = fire.N_good
        N_bad = fire.N_bad
        mask_free = 1 - state.fixed
        free = mask_free[..., None]

        # Work in 3D: in 2D the scalar angular velocity / torque is the z-component.
        if state.dim == 2:
            angVel_lab_3d = jnp.pad(state.angVel, ((0, 0), (2, 0)), constant_values=0.0)
            torque_lab_3d = jnp.pad(state.torque, ((0, 0), (2, 0)), constant_values=0.0)
        else:  # state.dim == 3
            angVel_lab_3d = state.angVel
            torque_lab_3d = state.torque

        # Rotate angular velocities and torques to the body frame.
        angVel = state.q.rotate_back(state.q, angVel_lab_3d)
        torque = state.q.rotate_back(state.q, torque_lab_3d)

        # A coupled follower reuses the control state that LinearFIRE's
        # step_before_force synced into this integrator; otherwise run the
        # FIRE control law on the rotational power alone.
        follower = jnp.logical_and(fire.coupled, jnp.logical_not(fire.is_master))

        def _consume(_):
            return dt, alpha, N_good, N_bad, fire.dt_reverse, fire.velocity_scale

        def _update(_):
            power = jnp.sum(torque * angVel)
            return _fire_control_update(
                dt=dt,
                alpha=alpha,
                N_good=N_good,
                N_bad=N_bad,
                dt_min=dt_min,
                dt_max=dt_max,
                alpha_init=fire.alpha_init,
                f_inc=fire.f_inc,
                f_dec=fire.f_dec,
                f_alpha=fire.f_alpha,
                N_min=fire.N_min,
                N_bad_max=fire.N_bad_max,
                mask_free=mask_free,
                power=power,
            )

        dt, alpha, N_good, N_bad, dt_reverse, velocity_scale = jax.lax.cond(
            follower, _consume, _update, operand=None
        )

        # FIRE 2.0 correction: rotate back by HALF a step (angle dt_reverse*|w|/2),
        # mirroring the linear x += v * dt_reverse / 2.
        state.q = RotationFIRE._rotate_orientation(state.q, angVel, dt_reverse / 2)

        # Stop (or keep) rotations as requested by the control law.
        angVel *= velocity_scale[..., None]

        # Velocity Verlet: first angular half-kick.
        dt_2 = dt / 2
        k1 = dt_2 * omega_dot(angVel, torque, state.inertia)
        k2 = dt_2 * omega_dot(angVel + k1, torque, state.inertia)
        k3 = dt_2 * omega_dot(angVel + 0.25 * (k1 + k2), torque, state.inertia)
        angVel += free * (k1 + k2 + 4.0 * k3) / 6.0

        # FIRE mixing, per particle: w <- (1 - alpha) w + alpha |w| T / |T|.
        # Particles with (numerically) zero torque get no torque contribution.
        ang_vel_norm = jnp.sqrt(jnp.sum(angVel * angVel, axis=-1))
        torque_norm = jnp.sqrt(jnp.sum(torque * torque, axis=-1))
        mix_mask = (torque_norm > 1e-16) * mask_free
        mixing_ratio = ang_vel_norm / (torque_norm + 1e-16) * alpha * mix_mask
        angVel = angVel * (1.0 - alpha) * free + torque * mixing_ratio[..., None]

        # Re-apply velocity scaling so stopped particles do not rotate this step.
        angVel *= velocity_scale[..., None]

        # Drift: rotate the orientation by the full step angle dt * |w|.
        state.q = RotationFIRE._rotate_orientation(state.q, angVel, dt)

        # Rotate angular velocity back to the lab frame and store it.
        if state.dim == 2:
            state.angVel = angVel[..., -1:]
        else:
            state.angVel = state.q.rotate(state.q, angVel)

        # Write back updated FIRE state into the System integrator.
        new_fire = replace(
            fire,
            dt=dt,
            dt_min=dt_min,
            dt_max=dt_max,
            alpha=alpha,
            N_good=N_good,
            N_bad=N_bad,
            dt_reverse=dt_reverse,
            velocity_scale=velocity_scale,
        )
        system = dataclasses.replace(system, rotation_integrator=new_fire)
        return state, system

    @staticmethod
    @partial(jax.jit, donate_argnames=("state", "system"))
    @partial(jax.named_call, name="RotationFIRE.step_after_force")
    def step_after_force(state: "State", system: "System") -> Tuple["State", "System"]:
        """Second angular half-kick of the velocity-Verlet step using the adaptive ``dt``."""
        dt = system.rotation_integrator.dt

        # Pad to 3D if needed.
        if state.dim == 2:
            angVel_lab_3d = jnp.pad(state.angVel, ((0, 0), (2, 0)), constant_values=0.0)
            torque_lab_3d = jnp.pad(state.torque, ((0, 0), (2, 0)), constant_values=0.0)
        else:
            angVel_lab_3d = state.angVel
            torque_lab_3d = state.torque

        # Rotate angular velocities and torques to the body frame.
        angVel = state.q.rotate_back(state.q, angVel_lab_3d)
        torque = state.q.rotate_back(state.q, torque_lab_3d)

        # Update angular velocities.
        dt_2 = dt / 2
        k1 = dt_2 * omega_dot(angVel, torque, state.inertia)
        k2 = dt_2 * omega_dot(angVel + k1, torque, state.inertia)
        k3 = dt_2 * omega_dot(angVel + (k1 + k2) / 4, torque, state.inertia)
        angVel += (1 - state.fixed)[..., None] * (k1 + k2 + 4.0 * k3) / 6.0

        # Rotate angular velocities back to the lab frame and save in state.
        if state.dim == 2:
            state.angVel = angVel[..., -1:]
        else:
            state.angVel = state.q.rotate(state.q, angVel)  # to lab
        return state, system

    @staticmethod
    @jax.jit
    @partial(jax.named_call, name="RotationFIRE.initialize")
    def initialize(state: "State", system: "System") -> Tuple["State", "System"]:
        """Initialize FIRE state from the System and current forces.

        - Zeroes angular velocities and evaluates forces once.
        - Seeds ``dt`` from ``System.dt``, with ``dt_min``/``dt_max`` from the
          scale factors, and resets ``alpha`` and the counters.
        - If a :class:`LinearFIRE` partner exists and both sides allow it,
          couples as a follower: hyperparameters and control state are copied
          from the linear master, which is marked ``coupled``/``is_master``.
        """
        fire = system.rotation_integrator

        # Zero initial velocities and compute forces once.
        state.angVel *= 0.0
        state, system = system.force_manager.apply(state, system)
        state, system = system.collider.compute_force(state, system)

        # Calculate the initial parameters.
        dt0 = system.dt
        mask_free = 1 - state.fixed
        fire = replace(
            fire,
            dt=dt0,
            dt_min=dt0 * fire.dt_min_scale,
            dt_max=dt0 * fire.dt_max_scale,
            alpha=fire.alpha_init,
            # Keep the counters' array dtype from Create instead of storing
            # weak-typed Python ints in the pytree.
            N_good=jnp.zeros_like(fire.N_good),
            N_bad=jnp.zeros_like(fire.N_bad),
            dt_reverse=jnp.array(0.0, dtype=dt0.dtype),
            velocity_scale=mask_free,
        )

        # Attempt to couple to linear FIRE if present.
        if isinstance(system.linear_integrator, LinearFIRE):
            lin_fire0 = cast(LinearFIRE, system.linear_integrator)
            do_couple = jnp.logical_and(fire.attempt_couple, lin_fire0.attempt_couple)

            def _couple(_):
                # Ensure the linear integrator is marked as master/coupled.
                lin_fire2 = replace(lin_fire0, coupled=jnp.array(True), is_master=jnp.array(True))
                fire2 = replace(
                    fire,
                    # Hyperparams + adaptive state from master
                    alpha_init=lin_fire2.alpha_init,
                    f_inc=lin_fire2.f_inc,
                    f_dec=lin_fire2.f_dec,
                    f_alpha=lin_fire2.f_alpha,
                    N_min=lin_fire2.N_min,
                    N_bad_max=lin_fire2.N_bad_max,
                    dt_max_scale=lin_fire2.dt_max_scale,
                    dt_min_scale=lin_fire2.dt_min_scale,
                    dt=lin_fire2.dt,
                    dt_min=lin_fire2.dt_min,
                    dt_max=lin_fire2.dt_max,
                    alpha=lin_fire2.alpha,
                    N_good=lin_fire2.N_good,
                    N_bad=lin_fire2.N_bad,
                    dt_reverse=lin_fire2.dt_reverse,
                    velocity_scale=mask_free,
                    coupled=jnp.array(True),
                    is_master=jnp.array(False),
                )
                return fire2, lin_fire2

            def _no_couple(_):
                # Keep the same output PyTree types/shapes as the coupled branch.
                lin_fire2 = replace(
                    lin_fire0,
                    dt_reverse=jnp.array(0.0, dtype=dt0.dtype),
                    velocity_scale=mask_free,
                )
                return replace(fire, coupled=jnp.array(False), is_master=jnp.array(False)), lin_fire2

            fire2, lin_fire2 = jax.lax.cond(do_couple, _couple, _no_couple, operand=None)
            system = dataclasses.replace(system, rotation_integrator=fire2, linear_integrator=lin_fire2)
        else:
            fire2 = replace(fire, coupled=jnp.array(False), is_master=jnp.array(False))
            system = dataclasses.replace(system, rotation_integrator=fire2)

        return state, system