jaxdem.minimizers.optimizers
Optax custom optimizers for energy minimization.
Functions
| damped_newtonian | Damped Newtonian dynamics custom optax optimizer. |
| fire | Fast Inertial Relaxation Engine (FIRE) custom optax optimizer. |
Classes
| CustomGradientTransformation | Custom optax gradient transformation wrapper for DEM energy minimization. |
| DampedNewtonianState | Internal state for the damped Newtonian dynamics optimizer. |
| FIREState | Internal state for the Fast Inertial Relaxation Engine (FIRE) optimizer. |
- class jaxdem.minimizers.optimizers.CustomGradientTransformation(init_fn: Any, update_fn: Any, _constructor: Any, kw: dict[str, Any], type_name: str = '')
Bases: GradientTransformationExtraArgs
Custom optax gradient transformation wrapper for DEM energy minimization.
This class extends optax.GradientTransformationExtraArgs to support serialization and custom equality/hashing for user-defined minimization routines.
- type_name: str
- kw: dict[str, Any]
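The equality and hashing behaviour described above can be illustrated with a plain-Python sketch. It assumes, hypothetically, that two transformations compare equal when their type_name and kw match (regardless of which closure objects init_fn and update_fn are), and that calling _constructor(**kw) rebuilds an equivalent transformation. TransformSketch and make_scale are illustrative names, not part of jaxdem or optax, and the real class may key on different fields.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True, eq=False)
class TransformSketch:
    """Hypothetical stand-in for CustomGradientTransformation."""

    init_fn: Callable[..., Any]
    update_fn: Callable[..., Any]
    _constructor: Callable[..., "TransformSketch"]
    kw: dict
    type_name: str = ""

    def _key(self) -> tuple:
        # Identity comes from the constructor name and its keyword arguments,
        # not from the freshly created closures. Assumes kw values are hashable.
        return (self.type_name, tuple(sorted(self.kw.items())))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TransformSketch):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self) -> int:
        return hash(self._key())

    def rebuild(self) -> "TransformSketch":
        # Only type_name and kw need to be stored; the functions are recreated.
        return self._constructor(**self.kw)


def make_scale(rate: float) -> TransformSketch:
    """Hypothetical constructor: scales every update by `rate`."""

    def init_fn(params):
        return ()

    def update_fn(updates, state, params=None):
        return [rate * u for u in updates], state

    return TransformSketch(init_fn, update_fn, make_scale, {"rate": rate}, "scale")


a, b, c = make_scale(0.5), make_scale(0.5), make_scale(0.1)
assert a.update_fn is not b.update_fn  # different closure objects...
assert a == b and hash(a) == hash(b)   # ...but equal, hashable transformations
assert a != c
```

Keying on constructor arguments rather than function identity is what lets two independently built optimizers with the same settings be treated as the same object, for example as a dictionary key.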
- class jaxdem.minimizers.optimizers.FIREState(vel: jax.Array, dt: jax.Array, alpha: jax.Array, N_good: jax.Array, N_bad: jax.Array)
Bases: NamedTuple
Internal state for the Fast Inertial Relaxation Engine (FIRE) optimizer.
- vel
The current velocities, of shape (N, d).
- Type:
jax.Array
- dt
The current time step.
- Type:
jax.Array
- alpha
The current mixing coefficient \(\alpha\).
- Type:
jax.Array
- N_good
Number of consecutive steps with positive power (\(P > 0\)).
- Type:
jax.Array
- N_bad
Number of consecutive steps with non-positive power (\(P \le 0\)).
- Type:
jax.Array
- vel: Array
Alias for field number 0
- dt: Array
Alias for field number 1
- alpha: Array
Alias for field number 2
- N_good: Array
Alias for field number 3
- N_bad: Array
Alias for field number 4
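Because FIREState is a NamedTuple, each field is reachable both by name and by position (the "field number" listed above), and updating a field produces a new tuple instead of mutating the old one, as a functional optax state requires. The plain-Python sketch below mirrors the field layout; the initial values (particles at rest, the base time step, alpha_init, zeroed counters) are an assumption about how an init function might populate the state, and FIREStateSketch and init_state_sketch are hypothetical names.

```python
from typing import NamedTuple


class FIREStateSketch(NamedTuple):
    # Same field order as FIREState; plain Python types instead of jax.Array.
    vel: list
    dt: float
    alpha: float
    N_good: int
    N_bad: int


def init_state_sketch(n: int, d: int, dt: float, alpha_init: float = 0.1) -> FIREStateSketch:
    # Assumed initial values: N particles in d dimensions at rest, counters cleared.
    vel = [[0.0] * d for _ in range(n)]
    return FIREStateSketch(vel=vel, dt=dt, alpha=alpha_init, N_good=0, N_bad=0)


state = init_state_sketch(n=3, d=2, dt=0.01)
assert state[1] == state.dt                          # field number 1 is dt
new_state = state._replace(N_good=state.N_good + 1)  # returns a new tuple
assert state.N_good == 0 and new_state.N_good == 1   # the original is untouched
```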
- jaxdem.minimizers.optimizers.fire(dt: float, alpha_init: float = 0.1, f_inc: float = 1.1, f_dec: float = 0.5, f_alpha: float = 0.99, N_min: int = 5, N_bad_max: int = 10, dt_max_scale: float = 10.0, dt_min_scale: float = 0.001) -> Any
Fast Inertial Relaxation Engine (FIRE) custom optax optimizer.
The FIRE algorithm accelerates or decelerates the dynamics depending on the power \(P = F \cdot v\) between the force and the velocity: motion along the force (downhill) is accelerated, while motion against it (uphill) is stopped. It is widely used for energy minimization of granular particle systems.
Mathematical Formulation
At each step:
Update the velocities and positions:
\[\begin{split}v_{old} &= v(t) + F(t) \cdot \frac{dt}{2} \\ P &= F(t) \cdot v_{old}\end{split}\]
Update the algorithm parameters depending on the power \(P\):
Downhill step (\(P > 0\)):
\[\begin{split}N_{good} &\to N_{good} + 1 \\ N_{bad} &\to 0 \\ dt &\to \begin{cases} \min(dt \cdot f_{inc}, dt_{max}) & \text{if } N_{good} > N_{min} \\ dt & \text{otherwise} \end{cases} \\ \alpha &\to \begin{cases} \alpha \cdot f_{\alpha} & \text{if } N_{good} > N_{min} \\ \alpha & \text{otherwise} \end{cases}\end{split}\]
Uphill step (\(P \le 0\)):
\[\begin{split}N_{good} &\to 0 \\ N_{bad} &\to N_{bad} + 1 \\ dt &\to \max(dt \cdot f_{dec}, dt_{min}) \\ \alpha &\to \alpha_{init} \\ v_{old} &\to 0\end{split}\]
Perform velocity mixing:
\[\begin{split}v_{half} &= v_{old} \cdot (1 - \alpha) + \hat{F}(t) \cdot |v_{old}| \cdot \alpha \\ v(t + dt) &= v_{half} + F(t) \cdot \frac{dt}{2}\end{split}\]
Here \(F(t)\) is the force (the negative gradient of the energy) and \(\hat{F}(t) = F(t) / |F(t)|\) is its unit vector.
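The three stages above can be traced in plain Python on flattened vectors. This is a sketch of the equations as written, not the jaxdem implementation. It assumes that the condition \(N_{good} > N_{min}\) is tested after the counter is incremented, that the updated dt is used in the final half kick, and that \(\hat{F}\) is taken as zero when \(|F| = 0\); it omits the position update and the N_bad_max reset logic, which the formulation does not spell out. FIREStateSketch and fire_velocity_step are hypothetical names.

```python
import math
from typing import NamedTuple


class FIREStateSketch(NamedTuple):
    vel: list   # flattened velocities
    dt: float
    alpha: float
    N_good: int
    N_bad: int


def fire_velocity_step(state, force, *, alpha_init=0.1, f_inc=1.1, f_dec=0.5,
                       f_alpha=0.99, N_min=5, dt_max=1.0, dt_min=1e-4):
    """One FIRE velocity update following the formulation above (sketch)."""
    dt, alpha = state.dt, state.alpha

    # Step 1: half kick with the current dt, then the power P = F . v_old.
    v_old = [v + f * dt / 2 for v, f in zip(state.vel, force)]
    P = sum(f * v for f, v in zip(force, v_old))

    # Step 2: adapt dt, alpha and the counters.
    if P > 0:  # downhill
        N_good, N_bad = state.N_good + 1, 0
        if N_good > N_min:
            dt = min(dt * f_inc, dt_max)
            alpha = alpha * f_alpha
    else:      # uphill: shrink dt, reset alpha, stop the motion
        N_good, N_bad = 0, state.N_bad + 1
        dt = max(dt * f_dec, dt_min)
        alpha = alpha_init
        v_old = [0.0] * len(v_old)

    # Step 3: mix v_old toward the force direction, then the second half kick.
    v_norm = math.sqrt(sum(v * v for v in v_old))
    f_norm = math.sqrt(sum(f * f for f in force))
    f_hat = [f / f_norm for f in force] if f_norm > 0 else [0.0] * len(force)
    v_half = [v * (1 - alpha) + fh * v_norm * alpha for v, fh in zip(v_old, f_hat)]
    v_new = [v + f * dt / 2 for v, f in zip(v_half, force)]
    return FIREStateSketch(v_new, dt, alpha, N_good, N_bad)


# Downhill from rest: v_old = 0.05 along F, so P > 0 and N_good becomes 1.
s0 = FIREStateSketch(vel=[0.0, 0.0], dt=0.1, alpha=0.1, N_good=0, N_bad=0)
s1 = fire_velocity_step(s0, force=[1.0, 0.0])
```

Starting from rest, the first step is always downhill, so the velocity simply picks up two half kicks along the force; an uphill step instead zeroes the mixed velocity and halves dt before the final half kick.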
- param dt:
The base time step.
- type dt:
float
- param alpha_init:
The initial mixing coefficient.
- type alpha_init:
float, default 0.1
- param f_inc:
The factor by which the time step increases on downhill steps.
- type f_inc:
float, default 1.1
- param f_dec:
The factor by which the time step decreases on uphill steps.
- type f_dec:
float, default 0.5
- param f_alpha:
The decay factor for the mixing coefficient.
- type f_alpha:
float, default 0.99
- param N_min:
The number of consecutive downhill steps that must be exceeded before the time step is increased and \(\alpha\) is decayed.
- type N_min:
int, default 5
- param N_bad_max:
The maximum number of uphill steps before performing resets.
- type N_bad_max:
int, default 10
- param dt_max_scale:
The maximum time step scale limit: \(dt_{max} = dt \cdot dt_{max\_scale}\).
- type dt_max_scale:
float, default 10.0
- param dt_min_scale:
The minimum time step scale limit: \(dt_{min} = dt \cdot dt_{min\_scale}\).
- type dt_min_scale:
float, default 1e-3
- returns:
CustomGradientTransformation – An optax gradient transformation for the FIRE algorithm.
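The interplay between the scale limits and the growth and shrink factors can be checked numerically. The sketch below uses a hypothetical helper, steps_to_limit, to count how many consecutive dt increases (once the N_min threshold has been passed) or decreases occur before dt is clipped at \(dt_{max}\) or \(dt_{min}\), assuming the min/max clipping rules of the formulation above.

```python
def steps_to_limit(dt: float, factor: float, limit: float) -> int:
    """Count multiplications by `factor` until dt is clipped at `limit`."""
    n = 0
    if factor > 1:
        while dt < limit:
            dt = min(dt * factor, limit)
            n += 1
    else:
        while dt > limit:
            dt = max(dt * factor, limit)
            n += 1
    return n


dt = 0.01
dt_max = dt * 10.0    # dt_max_scale default
dt_min = dt * 1e-3    # dt_min_scale default
n_inc = steps_to_limit(dt, 1.1, dt_max)  # f_inc default
n_dec = steps_to_limit(dt, 0.5, dt_min)  # f_dec default
print(n_inc, n_dec)  # 25 10
```

With the defaults, dt saturates at \(dt_{max}\) after 25 increases beyond the N_min threshold, and reaches \(dt_{min}\) after 10 consecutive uphill steps.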
Reference
Bitzek et al., "Structural Relaxation Made Simple," Phys. Rev. Lett. 97, 170201 (2006).
- class jaxdem.minimizers.optimizers.DampedNewtonianState(vel: jax.Array, dt: jax.Array)
Bases: NamedTuple
Internal state for the damped Newtonian dynamics optimizer.
- vel
The current velocities, of shape (N, d).
- Type:
jax.Array
- dt
The current time step.
- Type:
jax.Array
- vel: Array
Alias for field number 0
- dt: Array
Alias for field number 1
- jaxdem.minimizers.optimizers.damped_newtonian(dt: float, gamma: float = 0.5) -> Any
Damped Newtonian dynamics custom optax optimizer.
This optimizer implements a velocity-Verlet-like scheme with a linear velocity-damping term (coefficient \(\gamma\)) that drives the system toward a local energy minimum.
Mathematical Formulation
At each step \(k\), the parameters are advanced using:
\[\begin{split}v_{k} &= \frac{v_{half} + F(t) \cdot \frac{dt}{2}}{1 + \gamma \cdot \frac{dt}{2}} \\ v(t+dt) &= v_{k} \cdot \left(1 - \gamma \cdot \frac{dt}{2}\right) + F(t) \cdot \frac{dt}{2} \\ x(t+dt) &= x(t) + v(t+dt) \cdot dt\end{split}\]
- param dt:
The time step.
- type dt:
float
- param gamma:
The damping coefficient.
- type gamma:
float, default 0.5
- returns:
An optax gradient transformation for the damped Newtonian algorithm.
- rtype:
CustomGradientTransformation
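The damped Newtonian update can be checked on a one-dimensional harmonic well \(U = x^2/2\), where \(F = -x\). The plain-Python sketch below applies the three equations literally and feeds the returned velocity back in as \(v_{half}\) on the next step; that bookkeeping and the name damped_newtonian_step are assumptions for illustration, not the jaxdem code.

```python
def damped_newtonian_step(x, v_half, force, dt, gamma=0.5):
    """Apply the damped Newtonian equations once (plain-Python sketch)."""
    c = gamma * dt / 2
    # Implicitly damped half kick: v_k = (v_half + F dt/2) / (1 + gamma dt/2)
    v_k = [(vh + f * dt / 2) / (1 + c) for vh, f in zip(v_half, force)]
    # Explicitly damped half kick: v(t+dt) = v_k (1 - gamma dt/2) + F dt/2
    v_new = [vk * (1 - c) + f * dt / 2 for vk, f in zip(v_k, force)]
    # Drift: x(t+dt) = x(t) + v(t+dt) dt
    x_new = [xi + vn * dt for xi, vn in zip(x, v_new)]
    return x_new, v_new


# Harmonic well U = x^2 / 2, so F = -x; start displaced and at rest.
x, v = [1.0], [0.0]
for _ in range(1000):
    force = [-xi for xi in x]
    x, v = damped_newtonian_step(x, v, force, dt=0.1, gamma=0.5)
print(abs(x[0]) < 1e-3)  # True: the damping drains the energy
```

Combining the two half kicks gives \(v(t+dt) = v_{half} \frac{1 - \gamma dt/2}{1 + \gamma dt/2} + F(t) \frac{dt}{1 + \gamma dt/2}\), so each step multiplies the carried-over velocity by a factor smaller than one whenever \(\gamma > 0\).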