jaxdem.minimizers.routines
Minimization routines and drivers.
Functions
| minimize(state, system[, max_steps, pe_tol, pe_diff_tol, initialize]) | Minimize the energy of the system using the configured optax optimizer. |
- jaxdem.minimizers.routines.minimize(state: State, system: System, max_steps: int = 10000, pe_tol: float = 1e-16, pe_diff_tol: float = 1e-16, initialize: bool = True) → tuple[State, System, int, float | jax.Array]
Minimize the energy of the system using the configured optax optimizer.
This function runs a JAX-compatible optimization loop using the minimizer specified in system.minimizer. The positions and orientations are packed into a flat parameter array, optimized, and then unpacked back to the returned State.
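As a rough illustration of the pack / optimize / unpack pattern described above, the sketch below flattens a pytree of positions and orientation angles into one 1D array with jax.flatten_util.ravel_pytree and restores it afterwards. The dictionary layout (keys pos and angle) is a hypothetical stand-in for per-particle data; this is not necessarily how jaxdem packs its State.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical per-particle data (3 particles in 2D): positions and
# orientation angles. The key names are illustrative, not jaxdem's.
params = {
    "pos": jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]),
    "angle": jnp.array([0.0, 0.5, 1.0]),
}

# Pack: concatenate every leaf into a single 1D array that an optax
# optimizer can update as one parameter vector.
flat, unravel = ravel_pytree(params)

# Stand-in for the optimization step: any update that keeps the shape.
flat = flat + 0.1

# Unpack: restore the original shapes so the values can be written back.
restored = unravel(flat)
```

Working on a single flat vector keeps the optimizer agnostic to how many fields a particle carries; the unravel function remembers the original shapes.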
The optimization loop terminates when any of the following conditions are met:
- The number of steps reaches max_steps.
- The potential energy per particle drops below pe_tol (or, if system.target_fn is defined, the overall energy does).
- The relative change in potential energy between successive steps drops below pe_diff_tol:
\[\left|\frac{E_k}{E_{k-1}} - 1\right| < \text{pe\_diff\_tol}\]
- Parameters:
state (State) – The state of the system.
system (System) – The system to minimize.
max_steps (int, default 10000) – The maximum number of optimization steps to take.
pe_tol (float, default 1e-16) – The absolute tolerance on the potential energy per particle; the loop stops once the energy falls below it.
pe_diff_tol (float, default 1e-16) – The relative potential energy difference tolerance for convergence.
initialize (bool, default True) – Currently unused; kept only for backward compatibility of the API.
- Returns:
A tuple containing:
  - The energy-minimized State.
  - The updated System.
  - The number of steps actually taken.
  - The final potential energy.
- Return type: