jaxdem.minimizers.optax_optimizer

Optax-based energy minimizer.

Classes

OptaxOptimizer(opt_state, optimizer)

Optax-based energy minimizer for coupled position and orientation updates.

OptaxRotationNoOp()

Dummy rotation minimizer.

class jaxdem.minimizers.optax_optimizer.OptaxOptimizer(opt_state: Any, optimizer: Any)

Bases: LinearMinimizer

Optax-based energy minimizer for coupled position and orientation updates.

This minimizer uses optax to update positions and orientation quaternions together in a single step, driven by energy gradients.

opt_state: Any

optimizer: Any

classmethod Create(optimizer: Any, state: State) → OptaxOptimizer

Create an OptaxOptimizer with a given optax optimizer.

Parameters:
  • optimizer (optax.GradientTransformation) – The optax optimizer to use (e.g. optax.adam(1e-3)).

  • state (State) – The initial state, used to initialize the optimizer state.

Returns:

A new minimizer instance.

Return type:

OptaxOptimizer

static step_after_force(state: State, system: System) → tuple[State, System]

class jaxdem.minimizers.optax_optimizer.OptaxRotationNoOp

Bases: RotationMinimizer

Dummy rotation minimizer.

Since OptaxOptimizer handles both positions and orientations, this rotation minimizer acts as a no-op to prevent double updates.

classmethod Create() → OptaxRotationNoOp

static step_after_force(state: State, system: System) → tuple[State, System]