jaxdem.minimizers.fire

FIRE (Fast Inertial Relaxation Engine) energy minimizers.

Reference: https://doi.org/10.1103/PhysRevLett.97.170201
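For orientation, the core FIRE loop from the reference above can be sketched in a few lines of JAX. This is a minimal standalone sketch of the published algorithm on a hypothetical two-dimensional quadratic energy, not the jaxdem implementation. `energy` and `fire_minimize` are illustrative names. The sketch assumes unit mass and a plain Python loop, and it omits the dt_min floor and the N_bad_max stopping rule that the classes below expose.

```python
import jax
import jax.numpy as jnp


def energy(x):
    # Hypothetical anisotropic quadratic bowl with its minimum at (1, -2).
    return 0.5 * (x[0] - 1.0) ** 2 + 2.0 * (x[1] + 2.0) ** 2


# Force is the negative energy gradient; jit keeps each call cheap.
force = jax.jit(jax.grad(lambda x: -energy(x)))


def fire_minimize(x, dt=0.05, dt_max=0.5, alpha_init=0.1, f_inc=1.1,
                  f_dec=0.5, f_alpha=0.99, N_min=5, steps=1000):
    v = jnp.zeros_like(x)
    alpha, n_good = alpha_init, 0
    f = force(x)
    for _ in range(steps):
        # Velocity-Verlet step with unit mass.
        v = v + 0.5 * dt * f
        x = x + dt * v
        f = force(x)
        v = v + 0.5 * dt * f
        # FIRE: measure the power and blend v toward the force direction.
        power = jnp.vdot(f, v)
        v_norm = jnp.linalg.norm(v)
        f_norm = jnp.linalg.norm(f)
        v = (1.0 - alpha) * v + alpha * v_norm * f / (f_norm + 1e-12)
        if power > 0:
            # Downhill: after more than N_min such steps, speed up.
            n_good += 1
            if n_good > N_min:
                dt = min(dt * f_inc, dt_max)
                alpha *= f_alpha
        else:
            # Uphill: shrink dt, stop the motion, and reset the mixing.
            n_good = 0
            dt *= f_dec
            v = jnp.zeros_like(x)
            alpha = alpha_init
    return x


x_min = fire_minimize(jnp.array([5.0, 3.0]))
```

Zeroing the velocity on every uphill step is what removes kinetic energy. Between those resets the trajectory is ordinary inertial dynamics, so the method needs only forces and no line search.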

Classes

LinearFIRE(alpha_init, f_inc, f_dec, ...)

FIRE energy minimizer for linear DOFs.

RotationFIRE(alpha_init, f_inc, f_dec, ...)

FIRE energy minimizer for rotation DOFs.

class jaxdem.minimizers.fire.LinearFIRE(alpha_init: Array, f_inc: Array, f_dec: Array, f_alpha: Array, N_min: Array, N_bad_max: Array, dt_max_scale: Array, dt_min_scale: Array, dt: Array, dt_min: Array, dt_max: Array, alpha: Array, N_good: Array, N_bad: Array, attempt_couple: Array, coupled: Array, is_master: Array, dt_reverse: Array, velocity_scale: Array)

Bases: LinearMinimizer

FIRE energy minimizer for linear DOFs.

Notes

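  • Adaptive FIRE state (dt, alpha, step counters, etc.) lives on this integrator dataclass and is updated functionally. The step methods return a new System that carries the updated values instead of mutating anything in place.

  • No FIRE-specific fields are stored on System or State.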
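The functional-update pattern described in these notes can be illustrated with NamedTuples, which JAX treats as pytrees. In this sketch `ToyFIREState`, `ToySystem`, and `record_downhill_step` are hypothetical stand-ins, not jaxdem classes. The only point is that an update builds a new container and leaves the old one untouched, which is what lets such state pass through `jax.jit`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyFIREState(NamedTuple):
    # Hypothetical subset of the adaptive fields listed below.
    dt: jax.Array
    alpha: jax.Array
    N_good: jax.Array


class ToySystem(NamedTuple):
    # Hypothetical container: the minimizer state travels inside the system.
    minimizer: ToyFIREState


@jax.jit
def record_downhill_step(system: ToySystem) -> ToySystem:
    m = system.minimizer
    # Build new records with _replace instead of mutating in place.
    new_m = m._replace(N_good=m.N_good + 1)
    return system._replace(minimizer=new_m)


sys0 = ToySystem(ToyFIREState(dt=jnp.asarray(0.01),
                              alpha=jnp.asarray(0.1),
                              N_good=jnp.asarray(0)))
sys1 = record_downhill_step(sys0)
```

Because `sys0` is never modified, the same input can be replayed, differentiated, or vmapped without hidden side effects.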

alpha_init: jax.Array
f_inc: jax.Array
f_dec: jax.Array
f_alpha: jax.Array
N_min: jax.Array
N_bad_max: jax.Array
dt_max_scale: jax.Array
dt_min_scale: jax.Array
dt: jax.Array
dt_min: jax.Array
dt_max: jax.Array
alpha: jax.Array
N_good: jax.Array
N_bad: jax.Array
attempt_couple: jax.Array
coupled: jax.Array
is_master: jax.Array
dt_reverse: jax.Array
velocity_scale: jax.Array
classmethod Create(alpha_init: float = 0.1, f_inc: float = 1.1, f_dec: float = 0.5, f_alpha: float = 0.99, N_min: int = 5, N_bad_max: int = 10, dt_max_scale: float = 10.0, dt_min_scale: float = 0.001, attempt_couple: bool = True) → LinearFIRE

Create a LinearFIRE minimizer with JAX array parameters.

Parameters:
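  • alpha_init (float, optional) – Initial value of the velocity-mixing factor alpha. Default is 0.1.

  • f_inc (float, optional) – Factor by which dt is multiplied when it is increased. Default is 1.1.

  • f_dec (float, optional) – Factor by which dt is multiplied when it is decreased. Default is 0.5.

  • f_alpha (float, optional) – Factor by which the mixing factor alpha is multiplied when it is decreased. Default is 0.99.

  • N_min (int, optional) – Minimum number of downhill steps before dt is increased. Default is 5.

  • N_bad_max (int, optional) – Maximum number of uphill steps allowed before the minimization stops. Default is 10.

  • dt_max_scale (float, optional) – Upper bound on the adaptive dt, expressed as a multiple of System.dt. Default is 10.0.

  • dt_min_scale (float, optional) – Lower bound on the adaptive dt, expressed as a multiple of System.dt. Default is 1e-3.

  • attempt_couple (bool, optional) – Whether to attempt coupling with another FIRE minimizer (tracked by the coupled and is_master fields). Default is True.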

Returns:

A new minimizer instance with JAX array parameters.

Return type:

LinearFIRE
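The Create parameters govern how dt and alpha adapt. The sketch below is a minimal, jit-compatible version of that rule. It assumes the textbook FIRE logic: grow dt and shrink alpha after more than N_min consecutive downhill steps, and shrink dt and reset alpha on an uphill step. It also assumes a hypothetical System.dt of 0.01 when setting the dt bounds. `FireParams` and `fire_adapt` are illustrative names. Counting uphill steps cumulatively for the N_bad_max check is an assumption about how that limit is applied.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FireParams(NamedTuple):
    # Defaults mirror the documented Create() defaults; dt bounds assume
    # a hypothetical System.dt of 0.01 (10.0 * 0.01 and 1e-3 * 0.01).
    f_inc: float = 1.1
    f_dec: float = 0.5
    f_alpha: float = 0.99
    alpha_init: float = 0.1
    N_min: int = 5
    N_bad_max: int = 10
    dt_min: float = 1e-5
    dt_max: float = 0.1


@jax.jit
def fire_adapt(power, dt, alpha, n_good, n_bad, p: FireParams):
    downhill = power > 0
    # Consecutive downhill counter; uphill counter is cumulative (assumed).
    n_good = jnp.where(downhill, n_good + 1, 0)
    n_bad = jnp.where(downhill, n_bad, n_bad + 1)
    # Downhill for long enough: grow dt (capped) and shrink alpha.
    grow = downhill & (n_good > p.N_min)
    dt = jnp.where(grow, jnp.minimum(dt * p.f_inc, p.dt_max), dt)
    alpha = jnp.where(grow, alpha * p.f_alpha, alpha)
    # Uphill: shrink dt (floored) and reset alpha.
    dt = jnp.where(downhill, dt, jnp.maximum(dt * p.f_dec, p.dt_min))
    alpha = jnp.where(downhill, alpha, p.alpha_init)
    stop = n_bad > p.N_bad_max
    return dt, alpha, n_good, n_bad, stop
```

Using `jnp.where` instead of Python branches keeps every path traceable, so the rule can run inside `jax.jit` or a `lax` loop.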

static step_before_force(state: State, system: System) → Tuple['State', 'System']

FIRE update and first half of the velocity-Verlet-like step.

static step_after_force(state: State, system: System) → Tuple['State', 'System']

Second half of the velocity-Verlet-like step using adaptive dt.
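The sketch below splits a velocity-Verlet step around the force evaluation, with FIRE velocity mixing in the first half. It rests on several assumptions: plain arrays for one set of linear DOFs, a scalar mass, and the mixing formula v ← (1 − α)v + α|v|F̂ from the reference paper. The function names echo the methods above, but they are hypothetical free functions rather than jaxdem's implementation, and the dt/alpha adaptation is left out.

```python
import jax.numpy as jnp


def step_before_force(x, v, f, dt, alpha, mass=1.0):
    """FIRE mixing with the current force, then half kick and full drift."""
    v_norm = jnp.linalg.norm(v)
    f_norm = jnp.linalg.norm(f)
    # Blend the velocity toward the force direction, scaled by |v|.
    v = (1.0 - alpha) * v + alpha * v_norm * f / jnp.maximum(f_norm, 1e-12)
    v = v + 0.5 * dt * f / mass  # first half kick with the old force
    x = x + dt * v               # drift to the new positions
    return x, v


def step_after_force(v, f_new, dt, mass=1.0):
    """Second half kick using the force evaluated at the new positions."""
    return v + 0.5 * dt * f_new / mass


# Harmonic spring F = -x with alpha = 0 reduces to plain velocity Verlet.
x1, v_half = step_before_force(jnp.array([1.0]), jnp.array([0.0]),
                               jnp.array([-1.0]), 0.1, 0.0)
v1 = step_after_force(v_half, -x1, 0.1)
```

The force evaluation happens between the two calls. That split lets the caller compute forces once per step and reuse them for both the second kick and the next FIRE mixing.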

static initialize(state: State, system: System) → Tuple['State', 'System']

Initialize FIRE state from the System and current forces.
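One plausible reading of initialize is sketched below under assumptions. The adaptive dt starts at System.dt, and its bounds follow from dt_max_scale and dt_min_scale as documented in Create. Alpha starts at alpha_init and the counters start at zero. Velocities are zeroed, which is a common FIRE starting condition but is assumed here rather than confirmed for jaxdem. `FireAdaptive` and `initialize_fire` are hypothetical names.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FireAdaptive(NamedTuple):
    # Hypothetical subset of the adaptive fields listed above.
    dt: jax.Array
    dt_min: jax.Array
    dt_max: jax.Array
    alpha: jax.Array
    N_good: jax.Array
    N_bad: jax.Array


def initialize_fire(system_dt, forces, alpha_init=0.1,
                    dt_max_scale=10.0, dt_min_scale=1e-3):
    base = jnp.asarray(system_dt, dtype=jnp.float32)
    adaptive = FireAdaptive(
        dt=base,                     # start from the system time step
        dt_min=dt_min_scale * base,  # lower bound on the adaptive dt
        dt_max=dt_max_scale * base,  # upper bound on the adaptive dt
        alpha=jnp.asarray(alpha_init, dtype=jnp.float32),
        N_good=jnp.asarray(0, dtype=jnp.int32),
        N_bad=jnp.asarray(0, dtype=jnp.int32),
    )
    # Assumed: the quench starts from rest, shaped like the force array.
    velocities = jnp.zeros_like(forces)
    return adaptive, velocities


fire0, vel0 = initialize_fire(0.01, jnp.ones((4, 3)))
```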

class jaxdem.minimizers.fire.RotationFIRE(alpha_init: Array, f_inc: Array, f_dec: Array, f_alpha: Array, N_min: Array, N_bad_max: Array, dt_max_scale: Array, dt_min_scale: Array, dt: Array, dt_min: Array, dt_max: Array, alpha: Array, N_good: Array, N_bad: Array, attempt_couple: Array, coupled: Array, is_master: Array, dt_reverse: Array, velocity_scale: Array)

Bases: RotationMinimizer

FIRE energy minimizer for rotation DOFs.

Notes

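  • Adaptive FIRE state (dt, alpha, step counters, etc.) lives on this integrator dataclass and is updated functionally. The step methods return a new System that carries the updated values instead of mutating anything in place.

  • No FIRE-specific fields are stored on System or State.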
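For rotation DOFs the FIRE quantities carry over with angular velocity in place of velocity and torque in place of force. The power becomes P = Σ τ·ω, and the mixing step becomes ω ← (1 − α)ω + α|ω|τ̂. The sketch below assumes (N, 3) arrays of lab-frame angular velocities and torques, with global norms taken over all particles. `rotational_fire_mix` is a hypothetical helper, not part of jaxdem.

```python
import jax.numpy as jnp


def rotational_fire_mix(ang_vel, torque, alpha):
    """Return the rotational power and the FIRE-mixed angular velocities."""
    # Power summed over every particle and component.
    power = jnp.sum(torque * ang_vel)
    # For a 2-D array, jnp.linalg.norm defaults to the Frobenius norm,
    # i.e. a single global norm over all rotational DOFs.
    w_norm = jnp.linalg.norm(ang_vel)
    t_norm = jnp.linalg.norm(torque)
    mixed = (1.0 - alpha) * ang_vel \
        + alpha * w_norm * torque / jnp.maximum(t_norm, 1e-12)
    return power, mixed


P, w_new = rotational_fire_mix(jnp.array([[1.0, 0.0, 0.0]]),
                               jnp.array([[0.0, 3.0, 0.0]]), 0.5)
```

With ω perpendicular to τ the power is zero, and mixing with α = 0.5 rotates ω halfway toward the torque direction while the |ω| scaling keeps the overall magnitude comparable.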

alpha_init: jax.Array
f_inc: jax.Array
f_dec: jax.Array
f_alpha: jax.Array
N_min: jax.Array
N_bad_max: jax.Array
dt_max_scale: jax.Array
dt_min_scale: jax.Array
dt: jax.Array
dt_min: jax.Array
dt_max: jax.Array
alpha: jax.Array
N_good: jax.Array
N_bad: jax.Array
attempt_couple: jax.Array
coupled: jax.Array
is_master: jax.Array
dt_reverse: jax.Array
velocity_scale: jax.Array
classmethod Create(alpha_init: float = 0.1, f_inc: float = 1.1, f_dec: float = 0.5, f_alpha: float = 0.99, N_min: int = 5, N_bad_max: int = 10, dt_max_scale: float = 10.0, dt_min_scale: float = 0.001, attempt_couple: bool = True) → RotationFIRE

Create a RotationFIRE minimizer with JAX array parameters.

Parameters:
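  • alpha_init (float, optional) – Initial value of the velocity-mixing factor alpha. Default is 0.1.

  • f_inc (float, optional) – Factor by which dt is multiplied when it is increased. Default is 1.1.

  • f_dec (float, optional) – Factor by which dt is multiplied when it is decreased. Default is 0.5.

  • f_alpha (float, optional) – Factor by which the mixing factor alpha is multiplied when it is decreased. Default is 0.99.

  • N_min (int, optional) – Minimum number of downhill steps before dt is increased. Default is 5.

  • N_bad_max (int, optional) – Maximum number of uphill steps allowed before the minimization stops. Default is 10.

  • dt_max_scale (float, optional) – Upper bound on the adaptive dt, expressed as a multiple of System.dt. Default is 10.0.

  • dt_min_scale (float, optional) – Lower bound on the adaptive dt, expressed as a multiple of System.dt. Default is 1e-3.

  • attempt_couple (bool, optional) – Whether to attempt coupling with another FIRE minimizer (tracked by the coupled and is_master fields). Default is True.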

Returns:

A new minimizer instance with JAX array parameters.

Return type:

RotationFIRE

static step_before_force(state: State, system: System) → Tuple['State', 'System']

FIRE update and first half of the velocity-Verlet-like step.
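For rotation DOFs, the first half of the step must advance orientations using the current angular velocity. A common way to do this with unit quaternions is the exact update for constant ω: q ← exp(½ Δt ω) ⊗ q, which is a rotation by |ω|Δt about ω̂. This approach is assumed here and not taken from jaxdem. The sketch also assumes (w, x, y, z) ordering, a lab-frame ω, and a single particle. `quat_mul`, `drift_orientation`, and `rotate` are hypothetical helpers.

```python
import jax.numpy as jnp


def quat_mul(a, b):
    """Hamilton product of quaternions stored as (w, x, y, z)."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return jnp.array([
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ])


def drift_orientation(q, omega, dt):
    """Rotate q by angle |omega| * dt about the lab-frame axis of omega."""
    w = jnp.linalg.norm(omega)
    axis = omega / jnp.maximum(w, 1e-12)
    half = 0.5 * w * dt
    dq = jnp.concatenate([jnp.cos(half)[None], jnp.sin(half) * axis])
    q_new = quat_mul(dq, q)  # left-multiply: increment applied in lab frame
    return q_new / jnp.linalg.norm(q_new)  # guard against round-off drift


def rotate(q, v):
    """Rotate vector v by unit quaternion q via v' = q v q*."""
    qv = jnp.concatenate([jnp.zeros(1), v])
    q_conj = q * jnp.array([1.0, -1.0, -1.0, -1.0])
    return quat_mul(quat_mul(q, qv), q_conj)[1:]


# A quarter turn about z maps the x axis onto the y axis.
q1 = drift_orientation(jnp.array([1.0, 0.0, 0.0, 0.0]),
                       jnp.array([0.0, 0.0, jnp.pi / 2]), 1.0)
```

Using the exponential map instead of adding dt·dq/dt keeps q on the unit sphere up to round-off, so the final normalization only corrects tiny numerical drift.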

static step_after_force(state: State, system: System) → Tuple['State', 'System']

Second half of the velocity-Verlet-like step using adaptive dt.

static initialize(state: State, system: System) → Tuple['State', 'System']

Initialize FIRE state from the System and current forces.