jaxdem.minimizers

Energy-minimizer interfaces and implementations.

Classes

LinearMinimizer()

Namespace for translation/linear-state minimizers.

Minimizer()

Abstract base class for energy minimizers.

RotationMinimizer()

Namespace for rotational-state minimizers.

class jaxdem.minimizers.LinearFIRE(alpha_init: Array, f_inc: Array, f_dec: Array, f_alpha: Array, N_min: Array, N_bad_max: Array, dt_max_scale: Array, dt_min_scale: Array, dt: Array, dt_min: Array, dt_max: Array, alpha: Array, N_good: Array, N_bad: Array, attempt_couple: Array, coupled: Array, is_master: Array, dt_reverse: Array, velocity_scale: Array)

Bases: LinearMinimizer

FIRE energy minimizer for linear DOFs.

Notes

  • Adaptive FIRE state (dt, alpha, counters, etc.) lives on this integrator dataclass and is updated functionally via System.

  • No FIRE-specific fields are stored on System or State.

alpha_init: Array
f_inc: Array
f_dec: Array
f_alpha: Array
N_min: Array
N_bad_max: Array
dt_max_scale: Array
dt_min_scale: Array
dt: Array
dt_min: Array
dt_max: Array
alpha: Array
N_good: Array
N_bad: Array
attempt_couple: Array
coupled: Array
is_master: Array
dt_reverse: Array
velocity_scale: Array
classmethod Create(alpha_init: float = 0.1, f_inc: float = 1.1, f_dec: float = 0.5, f_alpha: float = 0.99, N_min: int = 5, N_bad_max: int = 10, dt_max_scale: float = 10.0, dt_min_scale: float = 0.001, attempt_couple: bool = True) → LinearFIRE

Create a LinearFIRE minimizer with JAX array parameters.

Parameters:
  • alpha_init (float, optional) – Initial mixing factor. Default is 0.1.

  • f_inc (float, optional) – Time step increase factor. Default is 1.1.

  • f_dec (float, optional) – Time step decrease factor. Default is 0.5.

  • f_alpha (float, optional) – Mixing factor decrease factor. Default is 0.99.

  • N_min (int, optional) – Minimum number of downhill steps before increasing dt. Default is 5.

  • N_bad_max (int, optional) – Maximum number of uphill steps before stopping. Default is 10.

  • dt_max_scale (float, optional) – Maximum dt scale relative to System.dt. Default is 10.0.

  • dt_min_scale (float, optional) – Minimum dt scale relative to System.dt. Default is 1e-3.

Returns:

A new minimizer instance with JAX array parameters.

Return type:

LinearFIRE

static step_before_force(state: State, system: System) → tuple[State, System]

FIRE update and first half of the velocity-Verlet-like step.

static step_after_force(state: State, system: System) → tuple[State, System]

Second half of the velocity-Verlet-like step using adaptive dt.

static initialize(state: State, system: System) → tuple[State, System]

Initialize FIRE state from the System and current forces.
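To illustrate what the parameters above control, here is a minimal, library-independent sketch of the original FIRE update rules in plain Python with unit masses. It reuses the parameter names from Create (alpha_init, f_inc, f_dec, f_alpha, N_min, dt_max_scale, dt_min_scale), but it is not the jaxdem implementation: it integrates with semi-implicit Euler instead of the velocity-Verlet-like split, and it does not model N_bad, N_bad_max, dt_reverse, velocity_scale, or the linear/rotation coupling. The function `fire_minimize` and its `force_fn` callback are hypothetical names used only for illustration.

```python
import math
from typing import Callable, List

Vector = List[float]


def fire_minimize(
    x: Vector,
    force_fn: Callable[[Vector], Vector],
    dt: float,
    n_steps: int = 1000,
    alpha_init: float = 0.1,
    f_inc: float = 1.1,
    f_dec: float = 0.5,
    f_alpha: float = 0.99,
    N_min: int = 5,
    dt_max_scale: float = 10.0,
    dt_min_scale: float = 1e-3,
) -> Vector:
    """Hypothetical sketch of the original FIRE loop (unit masses)."""
    dt_max = dt_max_scale * dt  # bounds are relative to the base time step
    dt_min = dt_min_scale * dt
    alpha = alpha_init
    n_good = 0  # downhill steps since the last reset
    v = [0.0] * len(x)
    f = force_fn(x)
    for _ in range(n_steps):
        # MD part: semi-implicit Euler (update velocity first, then position).
        v = [vi + dt * fi for vi, fi in zip(v, f)]
        x = [xi + dt * vi for xi, vi in zip(x, v)]
        f = force_fn(x)
        # Power P = F . v is positive while the motion is still downhill.
        power = sum(fi * vi for fi, vi in zip(f, v))
        # Steer the velocity toward the force: v <- (1 - alpha) v + alpha |v| F / |F|.
        v_norm = math.sqrt(sum(vi * vi for vi in v))
        f_norm = math.sqrt(sum(fi * fi for fi in f))
        if f_norm > 0.0:
            v = [(1.0 - alpha) * vi + alpha * v_norm * fi / f_norm
                 for vi, fi in zip(v, f)]
        if power > 0.0:
            n_good += 1
            if n_good > N_min:  # only speed up after N_min downhill steps
                dt = min(dt * f_inc, dt_max)
                alpha *= f_alpha
        else:
            # Uphill: shrink dt, stop the system, and restore the mixing factor.
            n_good = 0
            dt = max(dt * f_dec, dt_min)
            alpha = alpha_init
            v = [0.0] * len(v)
    return x


# Anisotropic harmonic well U = 0.5 * (x0**2 + 4 * x1**2), minimum at the origin.
x_min = fire_minimize([1.0, -1.0], lambda x: [-x[0], -4.0 * x[1]],
                      dt=0.05, n_steps=5000)
```

The velocity reset on uphill steps is what distinguishes FIRE from plain damped dynamics: the system is quenched whenever P ≤ 0, while dt and alpha let it accelerate along consistently downhill directions.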

class jaxdem.minimizers.LinearGradientDescent(learning_rate: Array)

Bases: LinearMinimizer

learning_rate: Array
classmethod Create(learning_rate: float = 0.001) → LinearGradientDescent

Create a LinearGradientDescent minimizer with JAX array parameters.

Parameters:

learning_rate (float, optional) – Learning rate for gradient descent updates. Default is 1e-3.

Returns:

A new minimizer instance with JAX array parameters.

Return type:

LinearGradientDescent

static step_after_force(state: State, system: System) → tuple[State, System]

Gradient-descent update using the integrator’s learning rate.

The learning rate is stored on the LinearGradientDescent dataclass attached to system.linear_integrator, so no mutable state is kept outside the System PyTree.

The update equation is:

\[r_{t+1} = r_t + \gamma F_t\]

where \(\gamma\) is the learning rate and \(F_t\) is the force at step \(t\).
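A minimal plain-Python sketch of this update rule on a harmonic well, assuming a generic `force_fn` callback (a hypothetical name, not part of jaxdem):

```python
from typing import Callable, List

Vector = List[float]


def gradient_descent(
    r: Vector,
    force_fn: Callable[[Vector], Vector],
    learning_rate: float = 1e-3,
    n_steps: int = 100,
) -> Vector:
    """Apply r <- r + learning_rate * F(r) for a fixed number of steps."""
    for _ in range(n_steps):
        f = force_fn(r)  # F = -dU/dr, so stepping along F lowers U
        r = [ri + learning_rate * fi for ri, fi in zip(r, f)]
    return r


center = [1.0, 2.0]


def spring(r: Vector) -> Vector:
    """Force of the harmonic well U = 0.5 * |r - center|^2 (stiffness k = 1)."""
    return [-(ri - ci) for ri, ci in zip(r, center)]


r_final = gradient_descent([0.0, 0.0], spring, learning_rate=0.1, n_steps=200)
```

For a quadratic well of stiffness k, each step multiplies the displacement from the minimum by (1 - γk), so the iteration contracts only when 0 < γk < 2. Stiff interactions therefore need a correspondingly small learning rate.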

class jaxdem.minimizers.LinearMinimizer

Bases: Minimizer

Namespace for translation/linear-state minimizers.

Concrete minimizers (e.g., LinearGradientDescent) should subclass this to signal that they operate on linear kinematics.

class jaxdem.minimizers.Minimizer

Bases: Integrator, ABC

Abstract base class for energy minimizers.

Notes

  • Minimizer subclasses the generic Integrator interface, so it can be plugged in anywhere an Integrator is expected (e.g., as System.linear_integrator).

  • The default implementations of step_before_force, step_after_force, and initialize are inherited from Integrator and act as no-ops.

  • Concrete minimizers should typically override step_after_force to update the state based on the current forces in an energy-decreasing way.
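The hook structure described in these notes (no-op defaults inherited from the base class, a concrete minimizer overriding only step_after_force, and functional updates that return new objects) can be mimicked with small frozen dataclasses. `ToyState`, `ToySystem`, `ToyIntegrator`, and `ToyGradientDescent` below are hypothetical stand-ins written for illustration; they are not jaxdem's State, System, or Integrator, whose real fields differ.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for a state: positions and current forces."""
    pos: Tuple[float, ...]
    force: Tuple[float, ...]


@dataclass(frozen=True)
class ToyIntegrator:
    """Hypothetical base class whose hooks are no-ops by default."""

    @staticmethod
    def step_before_force(state, system):
        return state, system

    @staticmethod
    def step_after_force(state, system):
        return state, system

    @staticmethod
    def initialize(state, system):
        return state, system


@dataclass(frozen=True)
class ToySystem:
    """Hypothetical container holding the active linear integrator."""
    linear_integrator: ToyIntegrator


@dataclass(frozen=True)
class ToyGradientDescent(ToyIntegrator):
    """Concrete minimizer: overrides only step_after_force."""
    learning_rate: float = 1e-3

    @staticmethod
    def step_after_force(state, system):
        # The hyperparameter is read from the integrator stored on the system,
        # so the static method needs no access to `self`.
        lr = system.linear_integrator.learning_rate
        new_pos = tuple(p + lr * f for p, f in zip(state.pos, state.force))
        # Return a new state instead of mutating the old one.
        return replace(state, pos=new_pos), system


state = ToyState(pos=(0.0, 1.0), force=(2.0, -2.0))
system = ToySystem(linear_integrator=ToyGradientDescent(learning_rate=0.5))
new_state, system = system.linear_integrator.step_after_force(state, system)
```

Keeping hyperparameters on the integrator object that lives inside the system, rather than on `self`, is what lets the hooks stay static and the whole system remain a single immutable tree.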

class jaxdem.minimizers.OptaxOptimizer(opt_state: Any, optimizer: Any)

Bases: LinearMinimizer

Optax-based energy minimizer for coupled position and orientation updates.

This minimizer uses optax to update positions and orientation quaternions simultaneously from the gradients of the potential energy.

opt_state: Any
optimizer: Any
classmethod Create(optimizer: Any, state: State) → OptaxOptimizer

Create an OptaxOptimizer with a given optax optimizer.

Parameters:
  • optimizer (optax.GradientTransformation) – The optax optimizer to use (e.g. optax.adam(1e-3)).

  • state (State) – The initial state, used to initialize the optimizer state.

Returns:

A new minimizer instance.

Return type:

OptaxOptimizer

static step_after_force(state: State, system: System) → tuple[State, System]
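An optax GradientTransformation is driven through an init/update pair, followed by optax.apply_updates. Because these transformations consume gradients, a minimizer feeding one must pass dU/dx, which for positions is the negated force. To keep the sketch runnable without optax installed, `ToyTransformation`, `toy_sgd`, and `apply_updates` below are hypothetical pure-Python stand-ins that mirror that protocol (`toy_sgd` behaves like SGD with heavy-ball momentum); they are neither optax itself nor jaxdem's step_after_force.

```python
from typing import Callable, List, NamedTuple, Optional, Tuple

Vector = List[float]


class ToyTransformation(NamedTuple):
    """Hypothetical stand-in for an optax-style (init, update) pair."""
    init: Callable[[Vector], Vector]
    update: Callable[..., Tuple[Vector, Vector]]


def toy_sgd(learning_rate: float, momentum: float = 0.9) -> ToyTransformation:
    """SGD with heavy-ball momentum, written in the init/update style."""

    def init(params: Vector) -> Vector:
        return [0.0] * len(params)  # momentum buffer starts at zero

    def update(grads: Vector, opt_state: Vector, params: Optional[Vector] = None):
        new_state = [momentum * m + g for m, g in zip(opt_state, grads)]
        updates = [-learning_rate * m for m in new_state]  # descent direction
        return updates, new_state

    return ToyTransformation(init, update)


def apply_updates(params: Vector, updates: Vector) -> Vector:
    """Mirror of optax.apply_updates for flat lists: params + updates."""
    return [p + u for p, u in zip(params, updates)]


def force_fn(x: Vector) -> Vector:
    """Harmonic well with its minimum at x = 3 in every coordinate."""
    return [-(xi - 3.0) for xi in x]


opt = toy_sgd(learning_rate=0.1, momentum=0.5)
params = [0.0, 6.0]
opt_state = opt.init(params)
for _ in range(200):
    grads = [-fi for fi in force_fn(params)]  # dU/dx = -F
    updates, opt_state = opt.update(grads, opt_state, params)
    params = apply_updates(params, updates)
```

With real optax, the same loop would call `optimizer.init`, `optimizer.update`, and `optax.apply_updates` on the parameter PyTree, with the optimizer state carried along as `opt_state`.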

class jaxdem.minimizers.OptaxRotationNoOp

Bases: RotationMinimizer

Dummy rotation minimizer.

Since OptaxOptimizer handles both positions and orientations, this rotation minimizer acts as a no-op to prevent double updates.

classmethod Create() → OptaxRotationNoOp
static step_after_force(state: State, system: System) → tuple[State, System]

class jaxdem.minimizers.RotationFIRE(alpha_init: Array, f_inc: Array, f_dec: Array, f_alpha: Array, N_min: Array, N_bad_max: Array, dt_max_scale: Array, dt_min_scale: Array, dt: Array, dt_min: Array, dt_max: Array, alpha: Array, N_good: Array, N_bad: Array, attempt_couple: Array, coupled: Array, is_master: Array, dt_reverse: Array, velocity_scale: Array)

Bases: RotationMinimizer

FIRE energy minimizer for rotation DOFs.

Notes

  • Adaptive FIRE state (dt, alpha, counters, etc.) lives on this integrator dataclass and is updated functionally via System.

  • No FIRE-specific fields are stored on System or State.

alpha_init: Array
f_inc: Array
f_dec: Array
f_alpha: Array
N_min: Array
N_bad_max: Array
dt_max_scale: Array
dt_min_scale: Array
dt: Array
dt_min: Array
dt_max: Array
alpha: Array
N_good: Array
N_bad: Array
attempt_couple: Array
coupled: Array
is_master: Array
dt_reverse: Array
velocity_scale: Array
classmethod Create(alpha_init: float = 0.1, f_inc: float = 1.1, f_dec: float = 0.5, f_alpha: float = 0.99, N_min: int = 5, N_bad_max: int = 10, dt_max_scale: float = 10.0, dt_min_scale: float = 0.001, attempt_couple: bool = True) → RotationFIRE

Create a RotationFIRE minimizer with JAX array parameters.

Parameters:
  • alpha_init (float, optional) – Initial mixing factor. Default is 0.1.

  • f_inc (float, optional) – Time step increase factor. Default is 1.1.

  • f_dec (float, optional) – Time step decrease factor. Default is 0.5.

  • f_alpha (float, optional) – Mixing factor decrease factor. Default is 0.99.

  • N_min (int, optional) – Minimum number of downhill steps before increasing dt. Default is 5.

  • N_bad_max (int, optional) – Maximum number of uphill steps before stopping. Default is 10.

  • dt_max_scale (float, optional) – Maximum dt scale relative to System.dt. Default is 10.0.

  • dt_min_scale (float, optional) – Minimum dt scale relative to System.dt. Default is 1e-3.

Returns:

A new minimizer instance with JAX array parameters.

Return type:

RotationFIRE

static step_before_force(state: State, system: System) → tuple[State, System]

FIRE update and first half of the velocity-Verlet-like step.

static step_after_force(state: State, system: System) → tuple[State, System]

Second half of the velocity-Verlet-like step using adaptive dt.

static initialize(state: State, system: System) → tuple[State, System]

Initialize FIRE state from the System and current forces.

class jaxdem.minimizers.RotationGradientDescent(learning_rate: Array)

Bases: RotationMinimizer

learning_rate: Array
classmethod Create(learning_rate: float = 0.001) → RotationGradientDescent

Create a RotationGradientDescent minimizer with JAX array parameters.

Parameters:

learning_rate (float, optional) – Learning rate for gradient descent updates. Default is 1e-3.

Returns:

A new minimizer instance with JAX array parameters.

Return type:

RotationGradientDescent

static step_after_force(state: State, system: System) → tuple[State, System]

Gradient-descent update using the integrator’s learning rate.

The learning rate is stored on the RotationGradientDescent dataclass attached to system.rotation_integrator, so no mutable state is kept outside the System PyTree.

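The update equation is:

\[q_{t+1} = q_t \cdot \exp\left(\gamma \tau_t I^{-1}\right)\]

where \(\gamma\) is the learning rate, \(\tau_t\) is the torque at step \(t\), and \(I\) is the moment of inertia. The argument of the exponential is treated as a purely imaginary quaternion: its scalar part is zero and its vector part is the 3-vector \(\gamma \tau_t I^{-1}\). The exponential map of a purely imaginary quaternion \(u\) is

\[e^u = \cos(|u|) + \frac{\vec{u}}{|u|}\sin(|u|)\]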
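A plain-Python sketch of this update, assuming scalar-first quaternions (w, x, y, z), a diagonal inertia tensor given as principal moments, and a torque expressed in the same principal frame. The helper names `quat_mul`, `quat_exp_pure`, and `rotation_gd_step` are hypothetical and not part of jaxdem.

```python
import math
from typing import Sequence, Tuple

Quat = Tuple[float, float, float, float]  # (w, x, y, z), scalar part first


def quat_mul(a: Quat, b: Quat) -> Quat:
    """Hamilton product a * b."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    )


def quat_exp_pure(u: Sequence[float]) -> Quat:
    """exp of the purely imaginary quaternion (0, u): cos|u| + (u/|u|) sin|u|."""
    norm = math.sqrt(sum(c * c for c in u))
    if norm == 0.0:
        return (1.0, 0.0, 0.0, 0.0)  # limit as |u| -> 0 is the identity
    s = math.sin(norm) / norm
    return (math.cos(norm), u[0] * s, u[1] * s, u[2] * s)


def rotation_gd_step(q: Quat, torque: Sequence[float],
                     inertia: Sequence[float], learning_rate: float) -> Quat:
    """q_{t+1} = q_t * exp(gamma * tau_t * I^{-1}), assuming diagonal inertia."""
    u = [learning_rate * t / i for t, i in zip(torque, inertia)]
    return quat_mul(q, quat_exp_pure(u))


q_new = rotation_gd_step((1.0, 0.0, 0.0, 0.0), torque=(0.0, 0.0, 1.0),
                         inertia=(1.0, 1.0, 1.0), learning_rate=math.pi / 4)
```

Because exp of a pure quaternion is always a unit quaternion, the product stays normalized up to round-off. Note that a unit quaternion cos(θ/2) + n sin(θ/2) represents a rotation by θ about n, so in this example the body turns by π/2 about z.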

class jaxdem.minimizers.RotationMinimizer

Bases: Minimizer

Namespace for rotational-state minimizers.

Concrete minimizers that relax orientations / angular DOFs should subclass this.

jaxdem.minimizers.minimize(state: State, system: System, max_steps: int = 10000, pe_tol: float = 1e-16, pe_diff_tol: float = 1e-16, initialize: bool = True) → tuple[State, System, int, float]

Minimize the energy of the system until either of the following conditions is met:

  1. step_count >= max_steps (the step budget is exhausted).

  2. PE <= pe_tol (the energy is low enough) and |PE / prev_PE - 1| < pe_diff_tol (the energy has stopped changing).

Parameters:
  • state (State) – The state of the system.

  • system (System) – The system to minimize.

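  • max_steps (int, optional) – The maximum number of steps to take. Default is 10000.

  • pe_tol (float, optional) – Tolerance on the potential energy itself (condition PE <= pe_tol). Default is 1e-16.

  • pe_diff_tol (float, optional) – Tolerance on the relative change in potential energy between consecutive steps (condition |PE / prev_PE - 1| < pe_diff_tol). Default is 1e-16.

  • initialize (bool, optional) – Whether to initialize the integrator before minimizing. Default is True.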

Returns:

The final state, system, number of steps, and potential energy.

Return type:

tuple[State, System, int, float]

Notes

  • The potential energy is computed using the compute_potential_energy method of the collider object.

  • The step method of the system object is used to take a single step in the minimization.

  • The jax.lax.while_loop function is used to take steps until the conditions are met.

  • The jax.jit function is used to compile the minimization routine.
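The stopping rule can be sketched as a cond/body pair in the style of jax.lax.while_loop. To keep the example runnable without JAX, `while_loop` below is a pure-Python stand-in with the same documented semantics (`val = init_val; while cond_fun(val): val = body_fun(val)`). `minimize_sketch`, its `step_fn`/`energy_fn` callbacks, starting `prev_pe` at infinity, and the guard for `prev_pe == 0` are all assumptions of this sketch, not documented behavior of `minimize`.

```python
import math
from typing import Any, Callable, Tuple


def while_loop(cond_fun, body_fun, init_val):
    """Pure-Python stand-in with the documented semantics of jax.lax.while_loop."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def minimize_sketch(
    x: Any,
    step_fn: Callable[[Any], Any],
    energy_fn: Callable[[Any], float],
    max_steps: int = 10_000,
    pe_tol: float = 1e-16,
    pe_diff_tol: float = 1e-16,
) -> Tuple[Any, int, float]:
    """Hypothetical driver mirroring the stopping rule described for minimize()."""
    # carry = (x, step_count, pe, prev_pe); prev_pe = inf forces at least one step.
    init = (x, 0, energy_fn(x), math.inf)

    def cond_fun(carry):
        _, step, pe, prev_pe = carry
        if prev_pe == 0.0:  # avoid 0 / 0; a choice made for this sketch
            rel_change = 0.0 if pe == 0.0 else math.inf
        else:
            rel_change = abs(pe / prev_pe - 1.0)
        converged = pe <= pe_tol and rel_change < pe_diff_tol
        return step < max_steps and not converged

    def body_fun(carry):
        x, step, pe, _ = carry
        x_new = step_fn(x)
        return (x_new, step + 1, energy_fn(x_new), pe)

    x_final, steps, pe_final, _ = while_loop(cond_fun, body_fun, init)
    return x_final, steps, pe_final


# Gradient descent on U = 0.5 * x**2 with learning rate 0.1: x shrinks by 0.9 per step.
x_f, n_steps, pe_f = minimize_sketch(1.0, lambda x: 0.9 * x, lambda x: 0.5 * x * x,
                                     pe_tol=1e-3, pe_diff_tol=0.5)
```

In a JAX version, cond_fun must return a scalar boolean array, so the tests would be combined with `jnp.logical_and` rather than Python's `and` and `if`, which cannot branch on traced values. Also note that for geometric convergence the relative change |PE / prev_PE - 1| stays roughly constant, so with very small pe_diff_tol the loop may run until max_steps.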

Modules

fire

FIRE energy minimizer.

gradient_descent

Basic gradient-descent energy minimizer.

optax_optimizer

Optax-based energy minimizer.

routines

Minimization routines and drivers.