jaxdem.minimizers

Energy-minimizer interfaces and implementations.

Classes

LinearMinimizer()

Namespace for translation/linear-state minimizers.

Minimizer()

Abstract base class for energy minimizers.

RotationMinimizer()

Namespace for rotational-state minimizers.

class jaxdem.minimizers.Minimizer

Bases: Integrator, ABC

Abstract base class for energy minimizers.

Notes

  • Minimizer subclasses the generic Integrator interface, so it can be plugged in anywhere an Integrator is expected (e.g., as System.linear_integrator).

  • The default implementations of step_before_force, step_after_force, and initialize are inherited from Integrator and act as no-ops.

  • Concrete minimizers should typically override step_after_force to update the state based on the current forces in an energy-decreasing way.

class jaxdem.minimizers.LinearMinimizer

Bases: Minimizer

Namespace for translation/linear-state minimizers.

Concrete minimizers (e.g., LinearGradientDescent) should subclass this to signal that they operate on linear kinematics.

class jaxdem.minimizers.RotationMinimizer

Bases: Minimizer

Namespace for rotational-state minimizers.

Concrete minimizers that relax orientations / angular DOFs should subclass this.

class jaxdem.minimizers.LinearGradientDescent(learning_rate: 'jax.Array')

Bases: LinearMinimizer

learning_rate: jax.Array

classmethod Create(learning_rate: float = 0.001) → LinearGradientDescent

Create a LinearGradientDescent minimizer with JAX array parameters.

Parameters:

learning_rate (float, optional) – Learning rate for gradient descent updates. Default is 1e-3.

Returns:

A new minimizer instance with JAX array parameters.

Return type:

LinearGradientDescent

static step_after_force(state: State, system: System) → Tuple['State', 'System']

Gradient-descent update using the integrator’s learning rate.

The learning rate is stored on the LinearGradientDescent dataclass attached to system.linear_integrator, so no mutable state is kept outside the System PyTree.

The update equation is simply:

\[r_{t+1} = r_{t} + \gamma F_{t}\]
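
The update rule above can be sketched in a few lines of JAX. This is a minimal illustration, not jaxdem's implementation: `potential_energy` is a hypothetical toy potential, and `gradient_descent_step` is a name introduced here only for the example. The force is taken as the negative gradient of the energy, so adding `learning_rate * force` moves the positions downhill.

```python
import jax
import jax.numpy as jnp


def potential_energy(pos):
    # Hypothetical toy potential: a harmonic well centred at the origin.
    return 0.5 * jnp.sum(pos**2)


def gradient_descent_step(pos, learning_rate):
    # The force is the negative gradient of the potential energy.
    force = -jax.grad(potential_energy)(pos)
    # r_{t+1} = r_t + gamma * F_t
    return pos + learning_rate * force


pos = jnp.array([[1.0, -2.0], [0.5, 3.0]])
new_pos = gradient_descent_step(pos, learning_rate=1e-3)
```

For a sufficiently small learning rate, each such step lowers the potential energy of the toy system.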
class jaxdem.minimizers.RotationGradientDescent(learning_rate: 'jax.Array')

Bases: RotationMinimizer

learning_rate: jax.Array

classmethod Create(learning_rate: float = 0.001) → RotationGradientDescent

Create a RotationGradientDescent minimizer with JAX array parameters.

Parameters:

learning_rate (float, optional) – Learning rate for gradient descent updates. Default is 1e-3.

Returns:

A new minimizer instance with JAX array parameters.

Return type:

RotationGradientDescent

static step_after_force(state: State, system: System) → Tuple['State', 'System']

Gradient-descent update using the integrator’s learning rate.

The learning rate is stored on the RotationGradientDescent dataclass attached to system.rotation_integrator, so no mutable state is kept outside the System PyTree.

The update equation is:

\[q_{t+1} = q_{t} \cdot e^{\left(\gamma \tau_t I^{-1}\right)}\]

Where the torque term is a purely imaginary quaternion (scalar part is zero and the vector part is equal to the vector). The exponential map of a purely imaginary quaternion is

\[e^u = \cos(|u|) + \frac{\vec{u}}{|u|}\sin(|u|)\]
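
The quaternion update and the exponential map can be sketched as follows. This is an illustration under stated assumptions rather than jaxdem's code: quaternions are stored in (w, x, y, z) order, the inverse inertia is given as a per-axis vector, and the function names `quat_exp_pure`, `quat_mul`, and `rotation_descent_step` are hypothetical.

```python
import jax.numpy as jnp


def quat_exp_pure(u):
    # Exponential of a purely imaginary quaternion with vector part u, shape (3,).
    # Returns (w, x, y, z) = (cos|u|, u / |u| * sin|u|).
    norm = jnp.linalg.norm(u)
    # Guard the |u| -> 0 limit, where sin|u| / |u| -> 1.
    safe_norm = jnp.where(norm > 0.0, norm, 1.0)
    scale = jnp.where(norm > 0.0, jnp.sin(safe_norm) / safe_norm, 1.0)
    return jnp.concatenate([jnp.cos(norm)[None], scale * u])


def quat_mul(a, b):
    # Hamilton product of two quaternions in (w, x, y, z) order.
    aw, av = a[0], a[1:]
    bw, bv = b[0], b[1:]
    w = aw * bw - jnp.dot(av, bv)
    v = aw * bv + bw * av + jnp.cross(av, bv)
    return jnp.concatenate([w[None], v])


def rotation_descent_step(q, torque, inv_inertia, learning_rate):
    # q_{t+1} = q_t * exp(gamma * tau_t * I^{-1})
    u = learning_rate * torque * inv_inertia
    return quat_mul(q, quat_exp_pure(u))


q0 = jnp.array([1.0, 0.0, 0.0, 0.0])  # identity orientation
q1 = rotation_descent_step(q0, jnp.array([0.0, 0.0, 2.0]), jnp.ones(3), 0.1)
```

Because the exponential of a purely imaginary quaternion has unit norm, multiplying by it keeps a unit quaternion normalized up to floating-point error.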

class jaxdem.minimizers.LinearFIRE(alpha_init: Array, f_inc: Array, f_dec: Array, f_alpha: Array, N_min: Array, N_bad_max: Array, dt_max_scale: Array, dt_min_scale: Array, dt: Array, dt_min: Array, dt_max: Array, alpha: Array, N_good: Array, N_bad: Array, attempt_couple: Array, coupled: Array, is_master: Array, dt_reverse: Array, velocity_scale: Array)

Bases: LinearMinimizer

FIRE energy minimizer for linear DOFs.

Notes

  • Adaptive FIRE state (dt, alpha, counters, etc.) lives on this integrator dataclass and is updated functionally via System.

  • No FIRE-specific fields are stored on System or State.
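
The pattern described in these notes, where adaptive parameters live on an immutable container and each update returns a new one, can be illustrated with a small sketch. `ToyMinimizerState` and `accept_step` are hypothetical names, and a `NamedTuple` is used here only because JAX treats it as a PyTree out of the box; jaxdem's own dataclasses may be defined differently.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyMinimizerState(NamedTuple):
    # Hypothetical stand-in for an integrator dataclass; not jaxdem's class.
    dt: jax.Array
    alpha: jax.Array
    N_good: jax.Array


@jax.jit
def accept_step(m: ToyMinimizerState) -> ToyMinimizerState:
    # Return a new PyTree instead of mutating fields in place.
    return m._replace(dt=m.dt * 1.1, N_good=m.N_good + 1)


old = ToyMinimizerState(
    dt=jnp.asarray(0.01), alpha=jnp.asarray(0.1), N_good=jnp.asarray(0)
)
new = accept_step(old)
```

The original object is left untouched, which is what allows such a container to be carried through `jax.jit` and `jax.lax.while_loop`.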

alpha_init: jax.Array
f_inc: jax.Array
f_dec: jax.Array
f_alpha: jax.Array
N_min: jax.Array
N_bad_max: jax.Array
dt_max_scale: jax.Array
dt_min_scale: jax.Array
dt: jax.Array
dt_min: jax.Array
dt_max: jax.Array
alpha: jax.Array
N_good: jax.Array
N_bad: jax.Array
attempt_couple: jax.Array
coupled: jax.Array
is_master: jax.Array
dt_reverse: jax.Array
velocity_scale: jax.Array

classmethod Create(alpha_init: float = 0.1, f_inc: float = 1.1, f_dec: float = 0.5, f_alpha: float = 0.99, N_min: int = 5, N_bad_max: int = 10, dt_max_scale: float = 10.0, dt_min_scale: float = 0.001, attempt_couple: bool = True) → LinearFIRE

Create a LinearFIRE minimizer with JAX array parameters.

Parameters:
  • alpha_init (float, optional) – Initial mixing factor. Default is 0.1.

  • f_inc (float, optional) – Time step increase factor. Default is 1.1.

  • f_dec (float, optional) – Time step decrease factor. Default is 0.5.

  • f_alpha (float, optional) – Mixing factor decrease factor. Default is 0.99.

  • N_min (int, optional) – Minimum number of downhill steps before increasing dt. Default is 5.

  • N_bad_max (int, optional) – Maximum number of uphill steps before stopping. Default is 10.

  • dt_max_scale (float, optional) – Maximum dt scale relative to System.dt. Default is 10.0.

  • dt_min_scale (float, optional) – Minimum dt scale relative to System.dt. Default is 1e-3.

Returns:

A new minimizer instance with JAX array parameters.

Return type:

LinearFIRE

static initialize(state: State, system: System) → Tuple['State', 'System']

Initialize FIRE state from the System and current forces.

static step_after_force(state: State, system: System) → Tuple['State', 'System']

Second half of the velocity-Verlet-like step using adaptive dt.

static step_before_force(state: State, system: System) → Tuple['State', 'System']

FIRE update and first half of the velocity-Verlet-like step.
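
To show how the parameters of Create interact, the following sketch implements one adaptive update in a common textbook form of FIRE. It is not taken from jaxdem's source: `fire_update` is a hypothetical function, the `dt_max` and `dt_min` defaults are arbitrary, and details such as the N_bad counter, coupling, and the exact ordering of the mixing step are omitted.

```python
import jax.numpy as jnp


def fire_update(vel, force, dt, alpha, n_good,
                alpha_init=0.1, f_inc=1.1, f_dec=0.5, f_alpha=0.99,
                n_min=5, dt_max=1e-1, dt_min=1e-5):
    # Power P = F . v; positive means the system is moving downhill.
    power = jnp.sum(force * vel)
    downhill = power > 0.0

    # Mix the velocity toward the force direction.
    f_norm = jnp.linalg.norm(force)
    v_norm = jnp.linalg.norm(vel)
    mixed = (1.0 - alpha) * vel + alpha * v_norm * force / jnp.maximum(f_norm, 1e-30)

    # Downhill: after more than n_min good steps, grow dt and shrink alpha.
    grow = downhill & (n_good + 1 > n_min)
    dt_down = jnp.where(grow, jnp.minimum(dt * f_inc, dt_max), dt)
    alpha_down = jnp.where(grow, alpha * f_alpha, alpha)

    # Uphill: zero the velocity, shrink dt, reset alpha and the counter.
    new_vel = jnp.where(downhill, mixed, jnp.zeros_like(vel))
    new_dt = jnp.where(downhill, dt_down, jnp.maximum(dt * f_dec, dt_min))
    new_alpha = jnp.where(downhill, alpha_down, alpha_init)
    new_n_good = jnp.where(downhill, n_good + 1, 0)
    return new_vel, new_dt, new_alpha, new_n_good


# Uphill example: the velocity opposes the force, so the step is rejected.
v_up, dt_up, alpha_up, n_up = fire_update(
    jnp.array([1.0, 0.0]), jnp.array([-1.0, 0.0]), 0.01, 0.05, 3
)
```

In this form, f_inc and f_dec scale the time step, f_alpha decays the mixing factor, and N_min delays acceleration until several consecutive downhill steps have been taken.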

class jaxdem.minimizers.RotationFIRE(alpha_init: Array, f_inc: Array, f_dec: Array, f_alpha: Array, N_min: Array, N_bad_max: Array, dt_max_scale: Array, dt_min_scale: Array, dt: Array, dt_min: Array, dt_max: Array, alpha: Array, N_good: Array, N_bad: Array, attempt_couple: Array, coupled: Array, is_master: Array, dt_reverse: Array, velocity_scale: Array)

Bases: RotationMinimizer

FIRE energy minimizer for rotation DOFs.

Notes

  • Adaptive FIRE state (dt, alpha, counters, etc.) lives on this integrator dataclass and is updated functionally via System.

  • No FIRE-specific fields are stored on System or State.

alpha_init: jax.Array
f_inc: jax.Array
f_dec: jax.Array
f_alpha: jax.Array
N_min: jax.Array
N_bad_max: jax.Array
dt_max_scale: jax.Array
dt_min_scale: jax.Array
dt: jax.Array
dt_min: jax.Array
dt_max: jax.Array
alpha: jax.Array
N_good: jax.Array
N_bad: jax.Array
attempt_couple: jax.Array
coupled: jax.Array
is_master: jax.Array
dt_reverse: jax.Array
velocity_scale: jax.Array

classmethod Create(alpha_init: float = 0.1, f_inc: float = 1.1, f_dec: float = 0.5, f_alpha: float = 0.99, N_min: int = 5, N_bad_max: int = 10, dt_max_scale: float = 10.0, dt_min_scale: float = 0.001, attempt_couple: bool = True) → RotationFIRE

Create a RotationFIRE minimizer with JAX array parameters.

Parameters:
  • alpha_init (float, optional) – Initial mixing factor. Default is 0.1.

  • f_inc (float, optional) – Time step increase factor. Default is 1.1.

  • f_dec (float, optional) – Time step decrease factor. Default is 0.5.

  • f_alpha (float, optional) – Mixing factor decrease factor. Default is 0.99.

  • N_min (int, optional) – Minimum number of downhill steps before increasing dt. Default is 5.

  • N_bad_max (int, optional) – Maximum number of uphill steps before stopping. Default is 10.

  • dt_max_scale (float, optional) – Maximum dt scale relative to System.dt. Default is 10.0.

  • dt_min_scale (float, optional) – Minimum dt scale relative to System.dt. Default is 1e-3.

Returns:

A new minimizer instance with JAX array parameters.

Return type:

RotationFIRE

static initialize(state: State, system: System) → Tuple['State', 'System']

Initialize FIRE state from the System and current forces.

static step_after_force(state: State, system: System) → Tuple['State', 'System']

Second half of the velocity-Verlet-like step using adaptive dt.

static step_before_force(state: State, system: System) → Tuple['State', 'System']

FIRE update and first half of the velocity-Verlet-like step.

jaxdem.minimizers.minimize(state: State, system: System, max_steps: int = 10000, pe_tol: float = 1e-16, pe_diff_tol: float = 1e-16, initialize: bool = True) → Tuple[State, System, int, float]

Minimize the energy of the system until either of the following conditions is met:

  1. step_count >= max_steps

  2. PE <= pe_tol (energy is low enough) and |PE / prev_PE - 1| < pe_diff_tol (energy stopped changing)

Parameters:
  • state (State) – The state of the system.

  • system (System) – The system to minimize.

  • max_steps (int, optional) – The maximum number of steps to take. Default is 10000.

  • pe_tol (float, optional) – The tolerance for the potential energy. Default is 1e-16.

  • pe_diff_tol (float, optional) – The tolerance for the relative change in potential energy between consecutive steps. Default is 1e-16.

  • initialize (bool, optional) – Whether to initialize the integrator before minimizing. Default is True.

Returns:

The final state, system, number of steps, and potential energy.

Return type:

Tuple[State, System, int, float]

Notes

  • The potential energy is computed using the compute_potential_energy method of the collider object.

  • The step method of the system object is used to take a single step in the minimization.

  • The jax.lax.while_loop function is used to take steps until the conditions are met.

  • The jax.jit function is used to compile the minimization routine.
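
The loop structure described in these notes can be sketched on a toy problem. This is an illustration only, not the library's routine: `potential_energy`, `step`, and `toy_minimize` are hypothetical stand-ins for the collider energy, the system step, and the driver, and the stopping test follows the two conditions as they are stated in the description above.

```python
import jax
import jax.numpy as jnp


def potential_energy(pos):
    # Hypothetical toy potential standing in for the collider's energy.
    return 0.5 * jnp.sum(pos**2)


def step(pos, learning_rate=0.1):
    # Stand-in for one system step: plain gradient descent.
    return pos - learning_rate * jax.grad(potential_energy)(pos)


@jax.jit
def toy_minimize(pos, max_steps=10_000, pe_tol=1e-16, pe_diff_tol=1e-16):
    def keep_going(carry):
        _, n, pe, prev_pe = carry
        converged = (pe <= pe_tol) & (jnp.abs(pe / prev_pe - 1.0) < pe_diff_tol)
        return (n < max_steps) & ~converged

    def body(carry):
        pos, n, pe, _ = carry
        new_pos = step(pos)
        return new_pos, n + 1, potential_energy(new_pos), pe

    pe0 = potential_energy(pos)
    # prev_pe starts at infinity so the first relative-change test cannot pass.
    init = (pos, jnp.zeros((), dtype=jnp.int32), pe0,
            jnp.full((), jnp.inf, dtype=pe0.dtype))
    pos, n, pe, _ = jax.lax.while_loop(keep_going, body, init)
    return pos, n, pe


start = jnp.array([1.0, -2.0])
final_pos, n_steps, final_pe = toy_minimize(start, max_steps=100)
```

With the very tight default tolerances, this toy run stops on the step limit; loosening `pe_tol` and `pe_diff_tol` lets the convergence test end the loop earlier.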

Modules

fire

FIRE energy minimizer.

gradient_descent

Basic gradient-descent energy minimizer.

routines

Minimization routines and drivers.