jaxdem.rl.trainers#
Interface for defining reinforcement learning model trainers.
Classes
Trainer – Base class for reinforcement learning trainers.
TrajectoryData – Container for rollout data (single step or stacked across time).
- class jaxdem.rl.trainers.PPOTrainer(env: Environment, graphdef: nnx.GraphDef, graphstate: nnx.GraphState, key: ArrayLike, advantage_gamma: jax.Array, advantage_lambda: jax.Array, advantage_rho_clip: jax.Array, advantage_c_clip: jax.Array, ppo_clip_eps: jax.Array, ppo_value_coeff: jax.Array, ppo_entropy_coeff: jax.Array, importance_sampling_alpha: jax.Array, importance_sampling_beta: jax.Array, anneal_importance_sampling_beta: jax.Array, num_envs: int = 4096, num_epochs: int = 2000, num_steps_epoch: int = 128, num_minibatches: int = 8, minibatch_size: int = 512, num_segments: int = 4096)[source]#
Bases:
Trainer
Proximal Policy Optimization (PPO) trainer in PufferLib style.
This trainer implements the PPO algorithm with clipped surrogate objectives, value-function loss, entropy regularization, and importance-sampling reweighting.
Loss function
Given a trajectory batch with actions \(a_t\), states \(s_t\), rewards \(r_t\), advantages \(A_t\), and old log-probabilities \(\log \pi_{\theta_\text{old}}(a_t \mid s_t)\), we define:
Probability ratio:
\[r_t(\theta) = \exp\big( \log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_\text{old}}(a_t \mid s_t) \big)\]
Clipped policy loss:
\[L^{\text{policy}}(\theta) = - \mathbb{E}_t \Big[ \min\big( r_t(\theta) A_t,\; \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \big) \Big]\]
where \(\epsilon\) is the PPO clipping parameter.
Value-function loss (with clipping):
\[L^{\text{value}}(\theta) = \tfrac{1}{2} \mathbb{E}_t \Big[ \max\big( (V_\theta(s_t) - R_t)^2,\; (\text{clip}(V_\theta(s_t), V_{\theta_\text{old}}(s_t) - \epsilon, V_{\theta_\text{old}}(s_t) + \epsilon) - R_t)^2 \big) \Big]\]
where \(R_t = A_t + V_{\theta_\text{old}}(s_t)\) are the return targets.
Entropy bonus:
\[L^{\text{entropy}}(\theta) = \mathbb{E}_t \big[ \mathcal{H}[\pi_\theta(\cdot \mid s_t)] \big]\]
which encourages exploration.
Total loss:
\[L(\theta) = L^{\text{policy}}(\theta) + c_v L^{\text{value}}(\theta) - c_e L^{\text{entropy}}(\theta)\]
where \(c_v\) and \(c_e\) are coefficients for the value and entropy terms.
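Putting these pieces together, the objective can be written as in the following minimal JAX sketch. The argument names (new_log_prob, old_log_prob, value, old_value, returns, advantages, entropy) and default coefficients are assumptions chosen for illustration and do not mirror loss_fn exactly.

import jax.numpy as jnp

def ppo_loss(new_log_prob, old_log_prob, value, old_value, returns,
             advantages, entropy, clip_eps=0.2, c_v=0.5, c_e=0.01):
    """Illustrative PPO objective; names and defaults are assumptions, not loss_fn."""
    # Probability ratio r_t(theta)
    ratio = jnp.exp(new_log_prob - old_log_prob)

    # Clipped surrogate policy loss
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    policy_loss = -jnp.mean(jnp.minimum(unclipped, clipped))

    # Clipped value loss against the return targets R_t
    value_clipped = old_value + jnp.clip(value - old_value, -clip_eps, clip_eps)
    value_loss = 0.5 * jnp.mean(
        jnp.maximum((value - returns) ** 2, (value_clipped - returns) ** 2)
    )

    # Entropy bonus encourages exploration
    entropy_bonus = jnp.mean(entropy)

    return policy_loss + c_v * value_loss - c_e * entropy_bonus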
Prioritized minibatch sampling and importance weighting
This trainer uses a prioritized categorical distribution over environments to form minibatches. For each environment index \(i \in \{1,\dots,N\}\), we define a priority from the trajectory advantages:
\[\tilde{p}_i \;=\; \Big\| A_{\cdot,i} \Big\|_1^{\,\alpha} \quad\text{with}\quad \Big\| A_{\cdot,i} \Big\|_1 \;=\; \sum_{t=1}^{T} \big|A_{t,i}\big|,\]
where \(\alpha \ge 0\) (importance_sampling_alpha) controls the strength of prioritization. We then form a categorical sampling distribution
\[P(i) \;=\; \frac{\tilde{p}_i}{\sum_{k=1}^{N} \tilde{p}_k},\]
and sample indices \(\{i\}\) to create each minibatch (jax.random.choice() with probabilities \(P(i)\)). This mirrors Prioritized Experience Replay (PER), where \(\tilde{p}\) is derived from TD-error magnitude; here we use the per-environment advantage magnitude as a proxy for learning progress. This design is also inspired by recent large-scale self-play systems for autonomous driving.
To correct sampling bias we apply PER-style importance weights (importance_sampling_beta, with optional linear annealing):
\[w_i(\beta_t) \;=\; \Big(N \, P(i)\Big)^{-\beta_t}, \qquad \beta_t \in [0,1].\]
In classical PER, \(w_i\) is often normalized by \(\max_j w_j\) to keep the scale bounded; in this implementation we omit that normalization and use \(w_i\) directly. The minibatch advantages are standardized and reweighted with these IS weights before the PPO loss:
\[\hat{A}_{t,i} \;=\; w_i(\beta_t)\; \frac{A_{t,i} - \mu_{\text{mb}}(A)}{\sigma_{\text{mb}}(A)+\varepsilon}.\]
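A minimal sketch of this sampling scheme, assuming advantages stacked as a [T, N] array; the helper name and its arguments are illustrative, not the trainer's internal code.

import jax
import jax.numpy as jnp

def sample_prioritized_minibatch(key, advantages, minibatch_size, alpha=0.4, beta=0.1):
    """Illustrative PER-style environment sampling from [T, N] advantages."""
    T, N = advantages.shape

    # Priorities: per-environment L1 norm of the advantages, raised to alpha.
    p_tilde = jnp.sum(jnp.abs(advantages), axis=0) ** alpha
    probs = p_tilde / jnp.sum(p_tilde)

    # Sample environment indices for one minibatch.
    idx = jax.random.choice(key, N, shape=(minibatch_size,), p=probs)

    # PER importance weights w_i = (N * P(i))^(-beta), used here without
    # the max-normalization of classical PER.
    weights = (N * probs[idx]) ** (-beta)

    # Standardize the selected advantages and reweight them.
    mb_adv = advantages[:, idx]
    mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)
    mb_adv = weights * mb_adv  # broadcasts over the time axis
    return idx, weights, mb_adv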
Off-policy correction of advantages (V-trace)
After sampling, we recompute log-probabilities under the current policy (td.new_log_prob = pi.log_prob(td.action)) and compute targets/advantages with a V-trace–style off-policy correction in compute_advantages().
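A minimal sketch of such a correction, assuming [T, B]-shaped inputs and a GAE-style backward scan with clipped importance ratios; the actual compute_advantages() may differ in detail (for example in how the final value is bootstrapped).

import jax
import jax.numpy as jnp

def vtrace_style_advantages(values, rewards, dones, new_log_prob, old_log_prob,
                            gamma=0.99, lam=0.95, rho_clip=0.5, c_clip=0.5):
    """Sketch only: clipped importance ratios weight the TD errors, then a
    backward scan accumulates lambda-discounted advantages. All inputs [T, B];
    the last value is reused as its own bootstrap, a simplification here."""
    ratio = jnp.exp(new_log_prob - old_log_prob)
    rho = jnp.minimum(ratio, rho_clip)   # clipped ratio on the TD error
    c = jnp.minimum(ratio, c_clip)       # clipped "trace" ratio in the recursion

    not_done = 1.0 - dones.astype(values.dtype)
    next_values = jnp.concatenate([values[1:], values[-1:]], axis=0)
    deltas = rho * (rewards + gamma * next_values * not_done - values)

    def backward_step(adv, inputs):
        delta_t, c_t, nd_t = inputs
        adv = delta_t + gamma * lam * c_t * nd_t * adv
        return adv, adv

    _, advantages = jax.lax.scan(
        backward_step, jnp.zeros_like(values[0]), (deltas, c, not_done), reverse=True
    )
    returns = advantages + values        # value targets R_t
    return advantages, returns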
References
Schulman et al., Proximal Policy Optimization Algorithms, 2017.
Espeholt et al., IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, ICML 2018.
Schulman et al., High-Dimensional Continuous Control Using Generalized Advantage Estimation, 2015/2016.
Schaul et al., Prioritized Experience Replay, ICLR 2016.
Cusumano-Towner et al., Robust Autonomy Emerges from Self-Play, ICML 2025.
- ppo_clip_eps: jax.Array#
PPO clipping parameter \(\epsilon\) used for both the policy-ratio clip and the value-function clip.
- ppo_value_coeff: jax.Array#
Coefficient \(c_v\) scaling the value-function loss term in the total loss.
- ppo_entropy_coeff: jax.Array#
Coefficient \(c_e\) scaling the entropy bonus (encourages exploration).
- importance_sampling_alpha: jax.Array#
Prioritization strength \(\alpha \ge 0\) for minibatch sampling; higher values put more probability mass on envs with larger advantages.
- importance_sampling_beta: jax.Array#
Initial PER importance-weight exponent \(\beta \in [0,1]\) used in \(w_i(\beta) = (N P(i))^{-\beta}\); compensates sampling bias.
- anneal_importance_sampling_beta: jax.Array#
If nonzero/True, linearly anneals \(\beta\) toward 1 across training (more correction later in training); a schedule sketch follows the attribute list below.
- num_envs: int#
Number of vectorized environments \(N\) running in parallel.
- num_epochs: int#
Number of PPO training epochs (outer loop count).
- num_steps_epoch: int#
Rollout horizon \(T\) per epoch; total collected steps = \(N \times T\).
- num_minibatches: int#
Number of minibatches per epoch used for PPO updates.
- minibatch_size: int#
Minibatch size (number of env indices sampled per update); typically \(N / \text{num_minibatches}\).
- num_segments: int#
Number of vectorized environments times max number of agents.
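The exact annealing schedule for anneal_importance_sampling_beta is not spelled out above; a common choice, assumed in this sketch, is linear interpolation in the epoch fraction.

import jax.numpy as jnp

def annealed_beta(beta0, epoch, num_epochs):
    """Assumed linear schedule: beta grows from beta0 toward 1 over training."""
    frac = jnp.clip(epoch / num_epochs, 0.0, 1.0)
    return beta0 + (1.0 - beta0) * frac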
- classmethod Create(env: Environment, model: Model, key: ArrayLike = Array((), dtype=key<fry>) overlaying: [0 1], learning_rate: float = 0.02, max_grad_norm: float = 0.4, ppo_clip_eps: float = 0.2, ppo_value_coeff: float = 0.5, ppo_entropy_coeff: float = 0.03, importance_sampling_alpha: float = 0.4, importance_sampling_beta: float = 0.1, advantage_gamma: float = 0.99, advantage_lambda: float = 0.95, advantage_rho_clip: float = 0.5, advantage_c_clip: float = 0.5, num_envs: int = 2048, num_epochs: int = 1000, num_steps_epoch: int = 64, num_minibatches: int = 10, minibatch_size: Optional[int] = None, accumulate_n_gradients: int = 1, clip_actions: bool = False, clip_range: Tuple[float, float] = (-0.2, 0.2), anneal_learning_rate: bool = True, learning_rate_decay_exponent: float = 2.0, learning_rate_decay_min_fraction: float = 0.01, anneal_importance_sampling_beta: bool = True, optimizer=<function muon>) Self [source][source]#
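A minimal usage sketch. Constructing the Environment and Model depends on the rest of jaxdem's API, so those objects are left as placeholders here; the keyword values simply echo the defaults above.

import jax
from jaxdem.rl.trainers import PPOTrainer

# Placeholders: how the environment and model are built is not shown here.
env = ...    # an Environment instance
model = ...  # a Model instance (e.g. an actor-critic nnx module)

tr = PPOTrainer.Create(
    env,
    model,
    key=jax.random.key(0),
    learning_rate=0.02,
    num_envs=2048,
    num_epochs=1000,
    num_steps_epoch=64,
)
PPOTrainer.train(tr, verbose=True)  # run the full training loop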
- static epoch(tr: PPOTrainer, epoch: Array | ndarray | bool | number | bool | int | float | complex)[source][source]#
Run one PPO training epoch.
- Returns:
Updated trainer state and the most recent time-stacked rollout.
- Return type:
Tuple[PPOTrainer, TrajectoryData]
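For finer-grained control than train(), epochs can also be stepped manually. The rollout field accessed below (a reward attribute on TrajectoryData) is an assumption used only for illustration.

import jax.numpy as jnp
from jaxdem.rl.trainers import PPOTrainer

# `tr` is an existing PPOTrainer (see Create above).
for e in range(tr.num_epochs):
    tr, rollout = PPOTrainer.epoch(tr, jnp.asarray(e))
    # `rollout` is the time-stacked TrajectoryData collected this epoch;
    # the `reward` field name is an assumption for this sketch.
    print(e, float(rollout.reward.mean()))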
- static loss_fn(model: Model, td: TrajectoryData, seg_w: jax.Array, advantage_rho_clip: jax.Array, advantage_c_clip: jax.Array, advantage_gamma: jax.Array, advantage_lambda: jax.Array, ppo_clip_eps: jax.Array, ppo_value_coeff: jax.Array, ppo_entropy_coeff: jax.Array, entropy_key: jax.Array)[source][source]#
Compute the PPO minibatch loss.
- Parameters:
model (Model) – Live model rebuilt by NNX for this step.
td (TrajectoryData) – Time-stacked trajectory minibatch (e.g., shape [T, B, ...]).
seg_w (jax.Array) – Weights for advantage normalization.
- Returns:
Scalar loss to be minimized.
- Return type:
jax.Array
See also
Main
- static reset_model(tr: Trainer, shape: Sequence[int] | None = None, mask: Array | None = None) Trainer [source][source]#
Reset a model’s persistent recurrent state (e.g., LSTM carry) for all environments/agents and persist the mutation back into the trainer.
- Parameters:
tr (Trainer) – Trainer carrying the environment and NNX graph state. The target carry shape is inferred as (tr.num_envs, tr.env.max_num_agents) if not specified.
mask (jax.Array, optional) – Boolean mask selecting which (env, agent) entries to reset. A value of True resets that entry. The mask may be of shape (num_envs, num_agents) or any shape broadcastable to it. If None, all entries are reset.
- Returns:
A new trainer with the updated graphstate.
- Return type:
Trainer
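A usage sketch for clearing recurrent state, for example between evaluation rollouts or when episodes terminate; the done mask below is a hypothetical placeholder whose origin depends on the environment.

import jax.numpy as jnp
from jaxdem.rl.trainers import PPOTrainer

# Reset every (env, agent) carry, e.g. before a fresh rollout.
tr = PPOTrainer.reset_model(tr)

# Or reset only the entries whose episodes just ended. `done` is a
# hypothetical boolean array of shape (num_envs, num_agents).
done = jnp.zeros((tr.num_envs, tr.env.max_num_agents), dtype=bool)
tr = PPOTrainer.reset_model(tr, mask=done)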
- static train(tr: PPOTrainer, verbose=True)[source][source]#
Run the full PPO training loop over num_epochs epochs.
Modules
Implementation of the PPO algorithm.