jaxdem.minimizers
Energy-minimizer interfaces and implementations.
Classes
- LinearMinimizer – Namespace for translation/linear-state minimizers.
- Minimizer – Abstract base class for energy minimizers.
- RotationMinimizer – Namespace for rotational-state minimizers.
- class jaxdem.minimizers.Minimizer
Bases: Integrator, ABC
Abstract base class for energy minimizers.
Notes
Minimizer subclasses the generic Integrator interface, so it can be plugged in anywhere an Integrator is expected (e.g., as System.linear_integrator).
The default implementations of step_before_force, step_after_force, and initialize are inherited from Integrator and act as no-ops.
Concrete minimizers should typically override step_after_force to update the state based on the current forces in an energy-decreasing way.
- class jaxdem.minimizers.LinearMinimizer
Bases: Minimizer
Namespace for translation/linear-state minimizers.
Concrete minimizers (e.g., GradientDescent) should subclass this to signal that they operate on linear kinematics.
- class jaxdem.minimizers.RotationMinimizer
Bases: Minimizer
Namespace for rotational-state minimizers.
Concrete minimizers that relax orientations / angular DOFs should subclass this.
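The hook structure above can be illustrated with a minimal, self-contained sketch. Integrator, State, System, ToyDescent and step below are hypothetical stand-ins written for this example, not jaxdem's classes. In particular, the order used in step (before-force hook, force evaluation, after-force hook) is an assumption about how a driver would call the hooks.

```python
from abc import ABC
from dataclasses import dataclass, replace
from typing import Callable, Tuple


# Hypothetical stand-ins for jaxdem's State and System (one particle, 1-D).
@dataclass(frozen=True)
class State:
    pos: float
    force: float


@dataclass(frozen=True)
class System:
    linear_integrator: "Integrator"


class Integrator:
    """Generic integrator: every hook is a no-op by default."""

    @staticmethod
    def initialize(state: State, system: System) -> Tuple[State, System]:
        return state, system

    @staticmethod
    def step_before_force(state: State, system: System) -> Tuple[State, System]:
        return state, system

    @staticmethod
    def step_after_force(state: State, system: System) -> Tuple[State, System]:
        return state, system


class Minimizer(Integrator, ABC):
    """Inherits the no-op hooks; usable anywhere an Integrator is expected."""


class LinearMinimizer(Minimizer):
    """Marker: acts on translational DOFs."""


class RotationMinimizer(Minimizer):
    """Marker: acts on rotational DOFs."""


class ToyDescent(LinearMinimizer):
    """Hypothetical concrete minimizer: fixed step of 0.1 along the force."""

    @staticmethod
    def step_after_force(state: State, system: System) -> Tuple[State, System]:
        return replace(state, pos=state.pos + 0.1 * state.force), system


def step(state: State, system: System, force_fn: Callable[[float], float]):
    """Assumed hook order: before-force, force evaluation, after-force."""
    integ = system.linear_integrator
    state, system = integ.step_before_force(state, system)
    state = replace(state, force=force_fn(state.pos))
    return integ.step_after_force(state, system)


# U(x) = x**2 / 2, so F = -x and each step maps x -> 0.9 * x.
system = System(linear_integrator=ToyDescent())
state = State(pos=1.0, force=0.0)
for _ in range(100):
    state, system = step(state, system, lambda x: -x)
```

Because Minimizer only inherits no-op hooks, calling Minimizer.step_after_force directly returns the state unchanged. The marker subclasses let code tell linear and rotational minimizers apart with checks such as isinstance(integ, LinearMinimizer).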
- class jaxdem.minimizers.LinearGradientDescent(learning_rate: jax.Array)
Bases: LinearMinimizer
- learning_rate: jax.Array
- classmethod Create(learning_rate: float = 0.001) → LinearGradientDescent
Create a LinearGradientDescent minimizer with JAX array parameters.
- Parameters:
learning_rate (float, optional) – Learning rate for gradient descent updates. Default is 1e-3.
- Returns:
A new minimizer instance with JAX array parameters.
- Return type:
LinearGradientDescent
- static step_after_force(state: State, system: System) → Tuple[State, System]
Gradient-descent update using the integrator’s learning rate.
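The learning rate is stored on the LinearGradientDescent dataclass attached to system.linear_integrator, so no mutable state is kept outside the System PyTree. The update equation is simply:
\[r_{t+1} = r_{t} + \gamma F_{t}\]
where \(\gamma\) is learning_rate and \(F_t\) is the force on each particle at step \(t\). For forces derived from a potential \(U\), \(F_t = -\nabla U(r_t)\), so each step moves the positions along the negative energy gradient.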
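As a sketch of the update \(r_{t+1} = r_t + \gamma F_t\), the NumPy example below keeps \(\gamma\) on a frozen dataclass and returns new positions instead of mutating anything. GradientDescentSketch, its Create factory and the harmonic force are hypothetical. Unlike jaxdem's static step_after_force, which reads the learning rate from system.linear_integrator, this method takes positions and forces directly.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class GradientDescentSketch:
    """Hypothetical stand-in for LinearGradientDescent: holds only gamma."""
    learning_rate: np.ndarray

    @classmethod
    def Create(cls, learning_rate: float = 1e-3) -> "GradientDescentSketch":
        # Store the hyperparameter as a 0-d array, mirroring "array parameters".
        return cls(learning_rate=np.asarray(learning_rate, dtype=float))

    def step_after_force(self, pos: np.ndarray, force: np.ndarray) -> np.ndarray:
        # r_{t+1} = r_t + gamma * F_t; returns new positions, mutates nothing.
        return pos + self.learning_rate * force


# Toy harmonic well U(r) = k/2 * |r - r0|^2, hence F(r) = -k (r - r0).
k = 2.0
r0 = np.array([1.0, -0.5])


def force(r: np.ndarray) -> np.ndarray:
    return -k * (r - r0)


gd = GradientDescentSketch.Create(learning_rate=0.1)
r = np.zeros(2)
for _ in range(200):
    r = gd.step_after_force(r, force(r))
```

On this quadratic the error \(r - r_0\) is multiplied by \(1 - \gamma k\) at every step, so the iteration converges only when \(0 < \gamma k < 2\).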
- class jaxdem.minimizers.RotationGradientDescent(learning_rate: jax.Array)
Bases: RotationMinimizer
- learning_rate: jax.Array
- classmethod Create(learning_rate: float = 0.001) → RotationGradientDescent
Create a RotationGradientDescent minimizer with JAX array parameters.
- Parameters:
learning_rate (float, optional) – Learning rate for gradient descent updates. Default is 1e-3.
- Returns:
A new minimizer instance with JAX array parameters.
- Return type:
RotationGradientDescent
- static step_after_force(state: State, system: System) → Tuple[State, System]
Gradient-descent update using the integrator’s learning rate.
The learning rate is stored on the RotationGradientDescent dataclass attached to system.rotation_integrator, so no mutable state is kept outside the System PyTree. The update equation is:
\[q_{t+1} = q_{t} \cdot \exp\left(\gamma\, \tau_{t} I^{-1}\right)\]
where \(\gamma\) is learning_rate, \(\tau_t\) is the torque at step \(t\), and \(I\) is the moment of inertia. The exponent is treated as a purely imaginary quaternion: its scalar part is zero and its vector part is the three-vector \(\gamma\, \tau_t I^{-1}\). The exponential map of a purely imaginary quaternion \(u\) is
\[e^{u} = \cos(|u|) + \frac{u}{|u|}\sin(|u|), \qquad e^{0} = 1.\]
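The rotational update can be sketched in NumPy with quaternions stored as (w, x, y, z). The helper names are hypothetical, and the inertia tensor is assumed diagonal (principal axes), so \(\tau_t I^{-1}\) reduces to an elementwise division. The exponential map uses \(\exp(0, u) = \cos|u| + (u/|u|)\sin|u|\), written with np.sinc so it stays finite at \(u = 0\).

```python
import numpy as np


def quat_mul(p: np.ndarray, q: np.ndarray) -> np.ndarray:
    """Hamilton product of quaternions stored as (w, x, y, z)."""
    pw, pv = p[0], p[1:]
    qw, qv = q[0], q[1:]
    w = pw * qw - np.dot(pv, qv)
    v = pw * qv + qw * pv + np.cross(pv, qv)
    return np.concatenate(([w], v))


def quat_exp_imag(u: np.ndarray) -> np.ndarray:
    """exp of the purely imaginary quaternion (0, u): cos|u| + (u/|u|) sin|u|."""
    theta = np.linalg.norm(u)
    # sin(theta)/theta via np.sinc(x) = sin(pi x)/(pi x), smooth at theta = 0.
    return np.concatenate(([np.cos(theta)], np.sinc(theta / np.pi) * u))


def rotation_gd_step(q, torque, inertia_diag, gamma):
    """q_{t+1} = q_t * exp(gamma * tau_t I^{-1}), assuming a diagonal I."""
    u = gamma * torque / inertia_diag
    return quat_mul(q, quat_exp_imag(u))


q0 = np.array([1.0, 0.0, 0.0, 0.0])  # identity orientation
q1 = rotation_gd_step(q0, np.array([0.0, 0.0, 1.0]), np.array([1.0, 1.0, 2.0]), 0.1)
```

The exponential of a purely imaginary quaternion has unit norm, so each update keeps q on the unit sphere up to rounding. Under the usual convention that \(\cos(\theta/2) + n\sin(\theta/2)\) encodes a rotation by \(\theta\) about \(n\), a step with exponent \(u\) rotates the body by \(2|u|\).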
- class jaxdem.minimizers.LinearFIRE(alpha_init: Array, f_inc: Array, f_dec: Array, f_alpha: Array, N_min: Array, N_bad_max: Array, dt_max_scale: Array, dt_min_scale: Array, dt: Array, dt_min: Array, dt_max: Array, alpha: Array, N_good: Array, N_bad: Array, attempt_couple: Array, coupled: Array, is_master: Array, dt_reverse: Array, velocity_scale: Array)
Bases: LinearMinimizer
FIRE energy minimizer for linear DOFs.
Notes
Adaptive FIRE state (dt, alpha, counters, etc.) lives on this integrator dataclass and is updated functionally via System. No FIRE-specific fields are stored on System or State.
- alpha_init: jax.Array
- f_inc: jax.Array
- f_dec: jax.Array
- f_alpha: jax.Array
- N_min: jax.Array
- N_bad_max: jax.Array
- dt_max_scale: jax.Array
- dt_min_scale: jax.Array
- dt: jax.Array
- dt_min: jax.Array
- dt_max: jax.Array
- alpha: jax.Array
- N_good: jax.Array
- N_bad: jax.Array
- attempt_couple: jax.Array
- coupled: jax.Array
- is_master: jax.Array
- dt_reverse: jax.Array
- velocity_scale: jax.Array
- classmethod Create(alpha_init: float = 0.1, f_inc: float = 1.1, f_dec: float = 0.5, f_alpha: float = 0.99, N_min: int = 5, N_bad_max: int = 10, dt_max_scale: float = 10.0, dt_min_scale: float = 0.001, attempt_couple: bool = True) → LinearFIRE
Create a LinearFIRE minimizer with JAX array parameters.
- Parameters:
alpha_init (float, optional) – Initial mixing factor. Default is 0.1.
f_inc (float, optional) – Time step increase factor. Default is 1.1.
f_dec (float, optional) – Time step decrease factor. Default is 0.5.
f_alpha (float, optional) – Mixing factor decrease factor. Default is 0.99.
N_min (int, optional) – Minimum number of downhill steps before increasing dt. Default is 5.
N_bad_max (int, optional) – Maximum number of uphill steps before stopping. Default is 10.
dt_max_scale (float, optional) – Maximum dt scale relative to System.dt. Default is 10.0.
dt_min_scale (float, optional) – Minimum dt scale relative to System.dt. Default is 1e-3.
- Returns:
A new minimizer instance with JAX array parameters.
- Return type:
LinearFIRE
- static initialize(state: State, system: System) → Tuple[State, System]
Initialize FIRE state from the System and current forces.
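One iteration of the textbook FIRE scheme (Bitzek et al., 2006), using the Create() parameters above, can be sketched as follows. FireSketch and fire_step are hypothetical. The unit masses, the semi-implicit Euler MD step, and deriving dt, dt_min and dt_max from a system time step in initialize are assumptions. jaxdem's LinearFIRE also tracks fields that this sketch omits (N_bad_max stopping, dt_reverse, coupling flags). As in the Notes above, the adaptive state lives on a frozen dataclass and is replaced, never mutated.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class FireSketch:
    """Hypothetical FIRE container: hyperparameters plus adaptive state."""
    alpha_init: float = 0.1
    f_inc: float = 1.1
    f_dec: float = 0.5
    f_alpha: float = 0.99
    N_min: int = 5
    dt_max_scale: float = 10.0
    dt_min_scale: float = 1e-3
    # Adaptive state: set by initialize() and replaced on every step.
    dt: float = 0.0
    dt_min: float = 0.0
    dt_max: float = 0.0
    alpha: float = 0.1
    N_good: int = 0
    N_bad: int = 0

    def initialize(self, system_dt: float) -> "FireSketch":
        # Assumption: the dt bounds are the scale factors times the system dt.
        return replace(self, dt=system_dt,
                       dt_min=self.dt_min_scale * system_dt,
                       dt_max=self.dt_max_scale * system_dt,
                       alpha=self.alpha_init, N_good=0, N_bad=0)


def fire_step(fire, x, v, force_fn):
    """One textbook FIRE iteration with unit masses; returns (fire, x, v)."""
    F = force_fn(x)
    P = float(np.vdot(F, v))  # power; P > 0 means the motion is downhill
    if P > 0:
        # Steer the velocity toward the force direction.
        F_hat = F / (np.linalg.norm(F) + 1e-30)
        v = (1.0 - fire.alpha) * v + fire.alpha * np.linalg.norm(v) * F_hat
        fire = replace(fire, N_good=fire.N_good + 1, N_bad=0)
        if fire.N_good > fire.N_min:
            fire = replace(fire, dt=min(fire.dt * fire.f_inc, fire.dt_max),
                           alpha=fire.alpha * fire.f_alpha)
    else:
        # Uphill: shrink dt, reset the mixing factor, and stop all motion.
        fire = replace(fire, dt=max(fire.dt * fire.f_dec, fire.dt_min),
                       alpha=fire.alpha_init, N_good=0, N_bad=fire.N_bad + 1)
        v = np.zeros_like(v)
    # Semi-implicit Euler MD step using the force at the current positions.
    v = v + fire.dt * F
    x = x + fire.dt * v
    return fire, x, v


# Anisotropic quadratic well U = (x^2 + 4 y^2) / 2, so F = -(x, 4 y).
stiffness = np.array([1.0, 4.0])
fire = FireSketch().initialize(system_dt=0.05)
x, v = np.array([1.0, 1.0]), np.zeros(2)
for _ in range(3000):
    fire, x, v = fire_step(fire, x, v, lambda r: -stiffness * r)
```

Whenever the power \(F \cdot v\) becomes non-positive, the sketch zeroes the velocity and multiplies dt by f_dec. This is how FIRE backs off after overshooting a minimum before dt grows again over a run of downhill steps.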
- class jaxdem.minimizers.RotationFIRE(alpha_init: Array, f_inc: Array, f_dec: Array, f_alpha: Array, N_min: Array, N_bad_max: Array, dt_max_scale: Array, dt_min_scale: Array, dt: Array, dt_min: Array, dt_max: Array, alpha: Array, N_good: Array, N_bad: Array, attempt_couple: Array, coupled: Array, is_master: Array, dt_reverse: Array, velocity_scale: Array)
Bases: RotationMinimizer
FIRE energy minimizer for rotation DOFs.
Notes
Adaptive FIRE state (dt, alpha, counters, etc.) lives on this integrator dataclass and is updated functionally via System. No FIRE-specific fields are stored on System or State.
- alpha_init: jax.Array
- f_inc: jax.Array
- f_dec: jax.Array
- f_alpha: jax.Array
- N_min: jax.Array
- N_bad_max: jax.Array
- dt_max_scale: jax.Array
- dt_min_scale: jax.Array
- dt: jax.Array
- dt_min: jax.Array
- dt_max: jax.Array
- alpha: jax.Array
- N_good: jax.Array
- N_bad: jax.Array
- attempt_couple: jax.Array
- coupled: jax.Array
- is_master: jax.Array
- dt_reverse: jax.Array
- velocity_scale: jax.Array
- classmethod Create(alpha_init: float = 0.1, f_inc: float = 1.1, f_dec: float = 0.5, f_alpha: float = 0.99, N_min: int = 5, N_bad_max: int = 10, dt_max_scale: float = 10.0, dt_min_scale: float = 0.001, attempt_couple: bool = True) → RotationFIRE
Create a RotationFIRE minimizer with JAX array parameters.
- Parameters:
alpha_init (float, optional) – Initial mixing factor. Default is 0.1.
f_inc (float, optional) – Time step increase factor. Default is 1.1.
f_dec (float, optional) – Time step decrease factor. Default is 0.5.
f_alpha (float, optional) – Mixing factor decrease factor. Default is 0.99.
N_min (int, optional) – Minimum number of downhill steps before increasing dt. Default is 5.
N_bad_max (int, optional) – Maximum number of uphill steps before stopping. Default is 10.
dt_max_scale (float, optional) – Maximum dt scale relative to System.dt. Default is 10.0.
dt_min_scale (float, optional) – Minimum dt scale relative to System.dt. Default is 1e-3.
- Returns:
A new minimizer instance with JAX array parameters.
- Return type:
RotationFIRE
- static initialize(state: State, system: System) → Tuple[State, System]
Initialize FIRE state from the System and current forces.
- jaxdem.minimizers.minimize(state: State, system: System, max_steps: int = 10000, pe_tol: float = 1e-16, pe_diff_tol: float = 1e-16, initialize: bool = True) → Tuple[State, System, int, float]
Minimize the energy of the system until either of the following conditions is met:
1. step_count >= max_steps (the step budget is exhausted);
2. PE <= pe_tol (the energy is low enough) and |PE / prev_PE - 1| < pe_diff_tol (the energy has stopped changing).
- Parameters:
state (State) – The state of the system.
system (System) – The system to minimize.
max_steps (int, optional) – The maximum number of steps to take. Default is 10000.
pe_tol (float, optional) – The tolerance for the potential energy. Default is 1e-16.
pe_diff_tol (float, optional) – The tolerance for the relative change in potential energy between consecutive steps. Default is 1e-16.
initialize (bool, optional) – Whether to initialize the integrator before minimizing. Default is True.
- Returns:
The final state, system, number of steps, and potential energy.
- Return type:
Tuple[State, System, int, float]
Notes
The potential energy is computed using the compute_potential_energy method of the collider object.
The step method of the system object is used to take a single step in the minimization.
The jax.lax.while_loop function is used to take steps until the conditions are met.
The jax.jit function is used to compile the minimization routine.
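The stopping rule and the while-loop structure described above can be sketched with a pure-Python while_loop that follows the reference semantics given in the jax.lax.while_loop documentation. minimize_sketch, energy and force are hypothetical. Plain gradient descent stands in for whatever integrator System.step would apply, and the sketch does not guard against prev_PE = 0. Under jax.jit the carry holds traced arrays, so the Python `and`/`not` in cond_fun would have to become jnp.logical_and and jnp.logical_not.

```python
import numpy as np


def while_loop(cond_fun, body_fun, init_val):
    """Python reference semantics of jax.lax.while_loop."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def minimize_sketch(x, energy_fn, force_fn, learning_rate=1e-3,
                    max_steps=10_000, pe_tol=1e-16, pe_diff_tol=1e-16):
    """Hypothetical driver using the stopping rule documented for minimize()."""
    # Carry: (positions, step count, current PE, previous PE).
    init = (x, 0, energy_fn(x), np.inf)

    def cond_fun(carry):
        _, step, pe, prev_pe = carry
        low_enough = pe <= pe_tol
        stalled = abs(pe / prev_pe - 1.0) < pe_diff_tol  # no prev_pe == 0 guard
        return step < max_steps and not (low_enough and stalled)

    def body_fun(carry):
        x, step, pe, _ = carry
        x = x + learning_rate * force_fn(x)  # one gradient-descent step
        return (x, step + 1, energy_fn(x), pe)

    x, steps, pe, _ = while_loop(cond_fun, body_fun, init)
    return x, steps, pe


def energy(r):
    return 0.5 * float(np.sum(r ** 2))  # U = |r|^2 / 2


def force(r):
    return -r  # F = -grad U


# Loose tolerances, so the energy criterion (not max_steps) ends the loop.
x, steps, pe = minimize_sketch(np.array([1.0]), energy, force,
                               pe_tol=1e-3, pe_diff_tol=1e-2)
# Default tolerances with a small step budget: the loop stops at max_steps.
_, capped_steps, capped_pe = minimize_sketch(np.array([1.0]), energy, force,
                                             max_steps=100)
```

Condition 2 requires both a low energy and a small relative change. With pe_diff_tol = 1e-2 the per-step change of about 0.2% always counts as stalled, so the first run stops as soon as PE ≤ 1e-3. The second run stops only because it reaches max_steps = 100.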
Modules
- FIRE energy minimizer.
- Basic gradient-descent energy minimizer.
- Minimization routines and drivers.