jaxdem.minimizers.gradient_descent
Basic gradient-descent energy minimizer.
Classes
- class jaxdem.minimizers.gradient_descent.LinearGradientDescent(learning_rate: 'jax.Array')
Bases: LinearMinimizer
- learning_rate: jax.Array
- classmethod Create(learning_rate: float = 0.001) → LinearGradientDescent
Create a LinearGradientDescent minimizer with JAX array parameters.
- Parameters:
learning_rate (float, optional) – Learning rate for gradient descent updates. Default is 1e-3.
- Returns:
A new minimizer instance with JAX array parameters.
- Return type:
LinearGradientDescent
- static step_after_force(state: State, system: System) → Tuple['State', 'System']
Gradient-descent update using the integrator’s learning rate.
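The learning rate is stored on the LinearGradientDescent dataclass attached to
system.linear_integrator, so no mutable state is kept outside the System PyTree. The update equation is simply:
\[r_{t+1} = r_{t} + \gamma F_{t}\]
where \(\gamma\) is learning_rate and \(F_t\) is the force on each particle at step \(t\). For conservative forces, \(F_t = -\nabla U(r_t)\), so this is a plain gradient-descent step on the potential energy \(U\).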
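The linear update can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not jaxdem's implementation: ToyLinearGD, energy, and step are hypothetical names, the harmonic energy is a stand-in for a real DEM potential, and the force is obtained as the negative energy gradient with jax.grad. Storing learning_rate as a JAX array leaf of a registered PyTree mirrors the design described above: the parameter travels through jax.jit with the rest of the data instead of living in mutable Python state.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyLinearGD:
    """Hypothetical stand-in for a gradient-descent minimizer.

    The learning rate is a JAX array leaf, so it is carried through jit
    as part of the PyTree rather than as external mutable state.
    """

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    @classmethod
    def Create(cls, learning_rate: float = 1e-3) -> "ToyLinearGD":
        # Wrap the Python float in a JAX array ("JAX array parameters").
        return cls(jnp.asarray(learning_rate, dtype=jnp.float32))

    def tree_flatten(self):
        # learning_rate is the only dynamic leaf; no static aux data.
        return (self.learning_rate,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def energy(r):
    # Assumed toy potential: harmonic wells at the origin, U = 0.5 * sum |r|^2.
    return 0.5 * jnp.sum(r**2)


@jax.jit
def step(r, minimizer):
    force = -jax.grad(energy)(r)                 # F_t = -dU/dr
    return r + minimizer.learning_rate * force   # r_{t+1} = r_t + gamma * F_t


minimizer = ToyLinearGD.Create(learning_rate=0.1)
r = jnp.array([[1.0, -2.0], [0.5, 3.0]])
for _ in range(200):
    r = step(r, minimizer)
```

With \(U = \tfrac12 |r|^2\) the force is \(-r\), so each step scales the positions by \(1 - \gamma = 0.9\) and the particles relax toward the energy minimum at the origin.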
- class jaxdem.minimizers.gradient_descent.RotationGradientDescent(learning_rate: 'jax.Array')
Bases: RotationMinimizer
- learning_rate: jax.Array
- classmethod Create(learning_rate: float = 0.001) → RotationGradientDescent
Create a RotationGradientDescent minimizer with JAX array parameters.
- Parameters:
learning_rate (float, optional) – Learning rate for gradient descent updates. Default is 1e-3.
- Returns:
A new minimizer instance with JAX array parameters.
- Return type:
RotationGradientDescent
- static step_after_force(state: State, system: System) → Tuple['State', 'System']
Gradient-descent update using the integrator’s learning rate.
The learning rate is stored on the RotationGradientDescent dataclass attached to
system.rotation_integrator, so no mutable state is kept outside the System PyTree. The update equation is:
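\[q_{t+1} = q_{t} \cdot \exp\left(\gamma \, \tau_t I^{-1}\right)\]
where \(\gamma\) is learning_rate, \(\tau_t\) is the torque at step \(t\), \(I^{-1}\) is the inverse moment of inertia, and \(\cdot\) denotes the quaternion product. The argument \(\gamma \tau_t I^{-1}\) is treated as a purely imaginary quaternion: its scalar part is zero and its vector part is the vector \(\gamma \tau_t I^{-1}\). The exponential map of a purely imaginary quaternion \(u\) is
\[e^u = \cos(|u|) + \frac{u}{|u|}\sin(|u|)\]
which tends to the identity quaternion as \(|u| \to 0\).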
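The rotation update and the quaternion exponential can be sketched the same way. Again, this is a hypothetical, self-contained illustration rather than jaxdem's code: quaternions are assumed to be stored as (w, x, y, z), the inverse inertia is assumed diagonal in the body frame so that \(\tau_t I^{-1}\) reduces to an elementwise product, and quat_mul, exp_pure_imag, and rotation_step are names introduced here. The factor \(\sin|u|/|u|\) is computed with jnp.sinc so the map stays finite at \(|u| = 0\).

```python
import jax.numpy as jnp


def quat_mul(p, q):
    # Hamilton product of quaternions stored as (w, x, y, z).
    pw, pv = p[..., :1], p[..., 1:]
    qw, qv = q[..., :1], q[..., 1:]
    w = pw * qw - jnp.sum(pv * qv, axis=-1, keepdims=True)
    v = pw * qv + qw * pv + jnp.cross(pv, qv)
    return jnp.concatenate([w, v], axis=-1)


def exp_pure_imag(u):
    # e^u = cos|u| + (u/|u|) sin|u| for the pure quaternion (0, u).
    # jnp.sinc(x) = sin(pi x)/(pi x), so sinc(|u|/pi) = sin|u|/|u|, which is 1 at |u| = 0.
    norm = jnp.linalg.norm(u, axis=-1, keepdims=True)
    scale = jnp.sinc(norm / jnp.pi)
    return jnp.concatenate([jnp.cos(norm), scale * u], axis=-1)


def rotation_step(q, torque, inv_inertia, learning_rate):
    # Argument of the exponential: gamma * tau_t * I^{-1}, assuming a diagonal
    # (principal-axis) inverse inertia applied elementwise.
    u = learning_rate * torque * inv_inertia
    q_new = quat_mul(q, exp_pure_imag(u))   # q_{t+1} = q_t * exp(u)
    # Extra safeguard, not part of the equation: renormalize against round-off.
    return q_new / jnp.linalg.norm(q_new, axis=-1, keepdims=True)


q0 = jnp.array([1.0, 0.0, 0.0, 0.0])        # identity orientation
q1 = rotation_step(q0, jnp.array([0.0, 0.0, 2.0]), jnp.array([1.0, 1.0, 0.5]), 0.1)
```

The final renormalization is a design choice of the sketch: in exact arithmetic the product of two unit quaternions is already a unit quaternion, so it only counters floating-point drift over many steps.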