jaxdem.minimizers.optimizers

Optax custom optimizers for energy minimization.

Functions

damped_newtonian(dt[, gamma]) – Damped Newtonian dynamics custom optax optimizer.

fire(dt[, alpha_init, f_inc, f_dec, ...]) – Fast Inertial Relaxation Engine (FIRE) custom optax optimizer.

Classes

CustomGradientTransformation(init_fn, ...[, ...]) – Custom optax gradient transformation wrapper for DEM energy minimization.

DampedNewtonianState(vel, dt) – Internal state for the damped Newtonian dynamics optimizer.

FIREState(vel, dt, alpha, N_good, N_bad) – Internal state for the Fast Inertial Relaxation Engine (FIRE) optimizer.

class jaxdem.minimizers.optimizers.CustomGradientTransformation(init_fn: Any, update_fn: Any, _constructor: Any, kw: dict[str, Any], type_name: str = '')

Bases: GradientTransformationExtraArgs

Custom optax gradient transformation wrapper for DEM energy minimization.

This class extends optax.GradientTransformationExtraArgs to support serialization and custom equality/hashing for user-defined minimization routines.

type_name: str
kw: dict[str, Any]
property metadata: dict[str, Any]
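
To illustrate why a wrapper might carry its constructor and keyword arguments, here is a minimal pure-Python sketch. It is not the jaxdem class: `TransformationSketch`, `make_scaled_descent`, and the contents of `metadata` shown here are hypothetical, chosen only to show how value-based equality, hashing, and a round trip through metadata could fit together.

```python
from typing import Any, Callable


class TransformationSketch:
    """Hypothetical stand-in for a wrapper with value-based equality and hashing."""

    def __init__(self, init_fn: Callable, update_fn: Callable,
                 constructor: Callable, kw: dict[str, Any], type_name: str = ""):
        self.init = init_fn
        self.update = update_fn
        self._constructor = constructor
        self.kw = dict(kw)
        self.type_name = type_name

    @property
    def metadata(self) -> dict[str, Any]:
        # Assumed layout: enough information to rebuild the object later.
        return {"type_name": self.type_name, "kw": dict(self.kw)}

    def _key(self) -> tuple:
        # Compare by type name and keyword arguments, not by function identity.
        return (self.type_name, tuple(sorted(self.kw.items())))

    def __eq__(self, other: object) -> bool:
        return isinstance(other, TransformationSketch) and self._key() == other._key()

    def __hash__(self) -> int:
        return hash(self._key())


def make_scaled_descent(step: float) -> TransformationSketch:
    """Hypothetical factory: plain gradient descent with a fixed step."""
    def init_fn(params):
        return ()  # stateless

    def update_fn(grads, state, params=None):
        return [-step * g for g in grads], state

    return TransformationSketch(init_fn, update_fn, make_scaled_descent,
                                {"step": step}, type_name="scaled_descent")


opt_a = make_scaled_descent(0.1)
# Rebuild an equivalent object from the stored constructor and metadata.
opt_b = opt_a._constructor(**opt_a.metadata["kw"])
updates, _ = opt_a.update([2.0, -4.0], opt_a.init(None))
```

In this sketch `opt_a` and `opt_b` compare equal and hash identically even though their closures are distinct objects, which is the property that makes such wrappers usable as dictionary keys or cache keys.
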

class jaxdem.minimizers.optimizers.FIREState(vel: jax.Array, dt: jax.Array, alpha: jax.Array, N_good: jax.Array, N_bad: jax.Array)

Bases: NamedTuple

Internal state for the Fast Inertial Relaxation Engine (FIRE) optimizer.

vel

The current velocity parameters of shape (N, d).

Type: jax.Array

dt

The current step size.

Type: jax.Array

alpha

The current mixing parameter.

Type: jax.Array

N_good

Number of consecutive steps with positive power (\(P > 0\)).

Type: jax.Array

N_bad

Number of consecutive steps with non-positive power (\(P \le 0\)).

Type: jax.Array

jaxdem.minimizers.optimizers.fire(dt: float, alpha_init: float = 0.1, f_inc: float = 1.1, f_dec: float = 0.5, f_alpha: float = 0.99, N_min: int = 5, N_bad_max: int = 10, dt_max_scale: float = 10.0, dt_min_scale: float = 0.001) → Any

Fast Inertial Relaxation Engine (FIRE) custom optax optimizer.

FIRE accelerates or decelerates the dynamics depending on the sign of the power \(P = F \cdot v\), the dot product of the force and the velocity. It is widely used for energy minimization of granular particles.

Mathematical Formulation

At each step:

  1. Advance the velocity by half a step and compute the power \(P\):

    \[\begin{split}v_{old} &= v(t) + F(t) \cdot \frac{dt}{2} \\ P &= F(t) \cdot v_{old}\end{split}\]
  2. Update the algorithm parameters depending on the power \(P\):

    • Downhill Step (\(P > 0\)):

      \[\begin{split}N_{good} &\to N_{good} + 1 \\ N_{bad} &\to 0 \\ dt &\to \begin{cases} \min(dt \cdot f_{inc}, dt_{max}) & \text{if } N_{good} > N_{min} \\ dt & \text{otherwise} \end{cases} \\ \alpha &\to \begin{cases} \alpha \cdot f_{\alpha} & \text{if } N_{good} > N_{min} \\ \alpha & \text{otherwise} \end{cases}\end{split}\]
    • Uphill Step (\(P \le 0\)):

      \[\begin{split}N_{good} &\to 0 \\ N_{bad} &\to N_{bad} + 1 \\ dt &\to \max(dt \cdot f_{dec}, dt_{min}) \\ \alpha &\to \alpha_{init} \\ v_{old} &\to 0\end{split}\]
  3. Perform velocity mixing:

    \[\begin{split}v_{half} &= v_{old} \cdot (1 - \alpha) + \hat{F}(t) \cdot |v_{old}| \cdot \alpha \\ v(t + dt) &= v_{half} + F(t) \cdot \frac{dt}{2}\end{split}\]
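
The three steps above can be sketched in JAX as follows. This is a minimal sketch under stated assumptions, not the library's implementation: `FireSketchState` and `fire_sketch_step` are hypothetical names, the power and norms are taken over all particles at once, the \(N_{good} > N_{min}\) test uses the already incremented counter, step 3 uses the updated \(dt\), and positions are advanced as \(x \to x + v(t + dt) \cdot dt\) (the formulation above does not state the position update). The `N_bad_max` reset is not modeled.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FireSketchState(NamedTuple):  # hypothetical; mirrors the FIREState fields
    vel: jax.Array
    dt: jax.Array
    alpha: jax.Array
    N_good: jax.Array
    N_bad: jax.Array


def fire_sketch_step(pos, force, state, dt0, alpha_init=0.1, f_inc=1.1,
                     f_dec=0.5, f_alpha=0.99, N_min=5,
                     dt_max_scale=10.0, dt_min_scale=1e-3):
    dt_max, dt_min = dt0 * dt_max_scale, dt0 * dt_min_scale

    # Step 1: half-step velocity and power P = F . v_old (summed over particles).
    v_old = state.vel + force * state.dt / 2
    downhill = jnp.sum(force * v_old) > 0

    # Step 2: update counters, dt and alpha depending on the sign of P.
    n_good = jnp.where(downhill, state.N_good + 1, 0)
    n_bad = jnp.where(downhill, 0, state.N_bad + 1)
    grow = downhill & (n_good > N_min)
    dt = jnp.where(grow, jnp.minimum(state.dt * f_inc, dt_max), state.dt)
    dt = jnp.where(downhill, dt, jnp.maximum(state.dt * f_dec, dt_min))
    alpha = jnp.where(grow, state.alpha * f_alpha, state.alpha)
    alpha = jnp.where(downhill, alpha, alpha_init)
    v_old = jnp.where(downhill, v_old, 0.0)  # uphill: velocity is zeroed

    # Step 3: mix the velocity toward the force direction, then second half-step.
    f_hat = force / jnp.maximum(jnp.linalg.norm(force), 1e-30)
    v_half = v_old * (1 - alpha) + f_hat * jnp.linalg.norm(v_old) * alpha
    vel = v_half + force * dt / 2

    new_pos = pos + vel * dt  # assumed position update
    return new_pos, FireSketchState(vel, dt, alpha, n_good, n_bad)


# Usage on a toy quadratic energy E(x) = 0.5 * sum(x^2), with F = -grad E.
def energy(x):
    return 0.5 * jnp.sum(x ** 2)


pos = jnp.array([[1.0, -2.0], [0.5, 3.0]])
e0 = float(energy(pos))
dt0 = 0.05
state = FireSketchState(jnp.zeros_like(pos), jnp.asarray(dt0),
                        jnp.asarray(0.1), jnp.asarray(0), jnp.asarray(0))
for _ in range(300):
    pos, state = fire_sketch_step(pos, -jax.grad(energy)(pos), state, dt0)
```

With `dt0 = 0.05` and the default scales, the adaptive step stays within \(dt_{min} = 5 \times 10^{-5}\) and \(dt_{max} = 0.5\) throughout the run.
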
Parameters:

  • dt (float) – The base time step.

  • alpha_init (float, default 0.1) – The initial mixing coefficient.

  • f_inc (float, default 1.1) – The factor by which the time step increases on downhill steps.

  • f_dec (float, default 0.5) – The factor by which the time step decreases on uphill steps.

  • f_alpha (float, default 0.99) – The decay factor for the mixing coefficient.

  • N_min (int, default 5) – The number of consecutive downhill steps required to increase the time step.

  • N_bad_max (int, default 10) – The maximum number of uphill steps before performing resets.

  • dt_max_scale (float, default 10.0) – The maximum time step scale limit: \(dt_{max} = dt \cdot dt_{max\_scale}\).

  • dt_min_scale (float, default 1e-3) – The minimum time step scale limit: \(dt_{min} = dt \cdot dt_{min\_scale}\).

Returns:

An optax gradient transformation for the FIRE algorithm.

Return type:

CustomGradientTransformation

Reference

Bitzek et al., Structural Relaxation Made Simple, Phys. Rev. Lett. 97, 170201 (2006)

class jaxdem.minimizers.optimizers.DampedNewtonianState(vel: jax.Array, dt: jax.Array)

Bases: NamedTuple

Internal state for the damped Newtonian dynamics optimizer.

vel

The current velocity parameters of shape (N, d).

Type: jax.Array

dt

The current step size.

Type: jax.Array

jaxdem.minimizers.optimizers.damped_newtonian(dt: float, gamma: float = 0.5) → Any

Damped Newtonian dynamics custom optax optimizer.

This optimizer implements a velocity-Verlet-like scheme with a linear velocity damping term that drives the system toward an energy minimum.

Mathematical Formulation

At each step \(k\), the parameters are advanced using:

\[\begin{split}v_{k} &= \frac{v_{half} + F(t) \cdot \frac{dt}{2}}{1 + \gamma \cdot \frac{dt}{2}} \\ v(t+dt) &= v_{k} \cdot \left(1 - \gamma \cdot \frac{dt}{2}\right) + F(t) \cdot \frac{dt}{2} \\ x(t+dt) &= x(t) + v(t+dt) \cdot dt\end{split}\]
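
The update above can be sketched in a few lines of JAX. This is a minimal sketch, not the library's implementation: it assumes that \(v_{half}\) is the velocity carried over from the previous step and that \(F = -\nabla E\); `damped_newtonian_sketch_step` and `quad_energy` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp


def damped_newtonian_sketch_step(x, force, v_half, dt, gamma=0.5):
    # First half-kick, with the damping term treated implicitly.
    v_k = (v_half + force * dt / 2) / (1 + gamma * dt / 2)
    # Explicit damping followed by the second half-kick with the same force.
    v_new = v_k * (1 - gamma * dt / 2) + force * dt / 2
    # Drift with the updated velocity.
    x_new = x + v_new * dt
    return x_new, v_new


# Usage on a toy quadratic energy E(x) = 0.5 * sum(x^2).
def quad_energy(x):
    return 0.5 * jnp.sum(x ** 2)


x = jnp.array([[1.0, -2.0], [0.5, 3.0]])
quad_e0 = float(quad_energy(x))
v = jnp.zeros_like(x)
for _ in range(500):
    force = -jax.grad(quad_energy)(x)
    x, v = damped_newtonian_sketch_step(x, force, v, dt=0.1, gamma=0.5)
```

For this linear force the per-step map on \((x, v)\) has determinant \((1 - \gamma \, dt/2) / (1 + \gamma \, dt/2) < 1\), so the oscillation amplitude shrinks geometrically; larger \(\gamma\) damps faster until the motion becomes overdamped.
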
Parameters:

  • dt (float) – The time step.

  • gamma (float, default 0.5) – The damping coefficient.

Returns:

An optax gradient transformation for the damped Newtonian algorithm.

Return type:

CustomGradientTransformation