jaxdem.minimizers.routines

Minimization routines and drivers.

Functions

minimize(state, system[, max_steps, pe_tol, ...])

Minimize the energy of the system using the configured optax optimizer.

jaxdem.minimizers.routines.minimize(state: State, system: System, max_steps: int = 10000, pe_tol: float = 1e-16, pe_diff_tol: float = 1e-16, initialize: bool = True) → tuple[State, System, int, float | jax.Array]

Minimize the energy of the system using the configured optax optimizer.

This function runs a JAX-compatible optimization loop using the optax optimizer specified in system.minimizer. The particle positions and orientations are packed into a flat parameter array, optimized, and then unpacked back into the returned State.
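The pack/optimize/unpack round trip can be illustrated with a minimal sketch in plain Python. It assumes that each particle carries a 3D position and a 4-component orientation quaternion, and that positions are laid out before orientations; the helpers `pack` and `unpack` are hypothetical and are not part of jaxdem, whose routine operates on JAX arrays.

```python
def pack(positions, orientations):
    """Hypothetical helper: flatten per-particle data into one parameter list.

    positions is a list of N 3-vectors, orientations a list of N 4-vectors
    (quaternions). All positions come first, followed by all orientations.
    """
    flat = [x for p in positions for x in p]
    flat += [q for o in orientations for q in o]
    return flat


def unpack(flat, n_particles, pos_dim=3, ori_dim=4):
    """Hypothetical helper: invert pack() by splitting the flat list into rows."""
    split = n_particles * pos_dim
    pos_flat, ori_flat = flat[:split], flat[split:]
    positions = [pos_flat[i:i + pos_dim] for i in range(0, split, pos_dim)]
    orientations = [ori_flat[i:i + ori_dim] for i in range(0, len(ori_flat), ori_dim)]
    return positions, orientations


# Two particles: identity orientation and a rotation quaternion.
positions = [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
orientations = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
flat = pack(positions, orientations)            # 2*3 + 2*4 = 14 parameters
restored = unpack(flat, n_particles=2)          # recovers both lists unchanged
```

In JAX code, `jax.flatten_util.ravel_pytree` performs this kind of flattening for arbitrary pytrees and returns the matching unravel function, which avoids hand-written index bookkeeping.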

The optimization loop terminates when any of the following conditions are met:

  1. The number of steps reaches max_steps.

  2. The potential energy per particle drops below pe_tol (or the overall energy does, if system.target_fn is defined).

  3. The relative change in potential energy between successive steps drops below pe_diff_tol:

    \[\left|\frac{E_k}{E_{k-1}} - 1\right| < \text{pe\_diff\_tol}\]
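The three termination conditions can be illustrated with a toy loop. This is a sketch under stated assumptions: plain gradient descent on a one-dimensional quadratic energy stands in for the optax optimizer in system.minimizer, there is a single degree of freedom (so "per particle" and total energy coincide), and `minimize_sketch` is a hypothetical name, not part of jaxdem. Like the documented routine, it returns the number of steps taken and the final energy alongside the optimized parameters.

```python
def energy(x):
    # Toy quadratic energy with its minimum E = 0 at x = 0.
    return 0.5 * x * x


def grad_energy(x):
    # Analytic gradient of the toy energy.
    return x


def minimize_sketch(x, max_steps=10000, pe_tol=1e-16, pe_diff_tol=1e-16, lr=0.5):
    """Hypothetical gradient-descent loop with the three stopping checks."""
    prev_pe = energy(x)
    steps = 0
    while steps < max_steps:                      # condition 1: step budget
        x = x - lr * grad_energy(x)
        steps += 1
        pe = energy(x)
        if pe < pe_tol:                           # condition 2: absolute tolerance
            break
        # Condition 3: relative change |E_k / E_{k-1} - 1|; skip if E_{k-1} == 0.
        if prev_pe != 0 and abs(pe / prev_pe - 1) < pe_diff_tol:
            break
        prev_pe = pe
    return x, steps, energy(x)


# Converges via the absolute tolerance: each step halves x, quartering E.
x, steps, pe = minimize_sketch(1.0)
# Stops on the step budget.
x5, steps5, pe5 = minimize_sketch(1.0, max_steps=5)
# Stops on stagnation: a tiny step changes E by only about 0.2 %.
xs, steps_s, pe_s = minimize_sketch(1.0, pe_diff_tol=1e-2, lr=1e-3)
```

With the defaults, the energy after k steps is 2^(-2k-1), so the absolute tolerance of 1e-16 is first met at step 27; the relative-change test never fires there because each step reduces the energy by a factor of four.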

Parameters:
  • state (State) – The state of the system.

  • system (System) – The system to minimize.

  • max_steps (int, default 10000) – The maximum number of optimization steps to take.

  • pe_tol (float, default 1e-16) – The absolute tolerance on the potential energy per particle (or on the overall energy if system.target_fn is defined).

  • pe_diff_tol (float, default 1e-16) – The relative potential energy difference tolerance for convergence.

  • initialize (bool, default True) – Currently unused; retained for backward compatibility of the API.

Returns:

A tuple containing:

  • The energy-minimized State.

  • The updated System.

  • The number of steps actually taken.

  • The final potential energy.

Return type:

Tuple[State, System, int, float | jax.Array]