jaxdem.rl.trainers.PPOtrainer#

Implementation of the PPO algorithm.

Classes

PPOTrainer(env, graphdef, graphstate, key, ...)

Proximal Policy Optimization (PPO) trainer in PufferLib style.

class jaxdem.rl.trainers.PPOtrainer.PPOTrainer(env: Environment, graphdef: nnx.GraphDef, graphstate: nnx.GraphState, key: ArrayLike, advantage_gamma: jax.Array, advantage_lambda: jax.Array, advantage_rho_clip: jax.Array, advantage_c_clip: jax.Array, ppo_clip_eps: jax.Array, ppo_value_coeff: jax.Array, ppo_entropy_coeff: jax.Array, importance_sampling_alpha: jax.Array, importance_sampling_beta: jax.Array, anneal_importance_sampling_beta: jax.Array, num_epochs: int, stop_at_epoch: int, num_steps_epoch: int, num_minibatches: int, minibatch_size: int)[source]#

Bases: Trainer

Proximal Policy Optimization (PPO) trainer in PufferLib style.

This trainer implements the PPO algorithm with clipped surrogate objectives, value-function loss, entropy regularization, and prioritized experience replay (PER).

Loss function

Given a trajectory batch with actions \(a_t\), states \(s_t\), rewards \(r_t\), advantages \(A_t\), and old log-probabilities \(\log \pi_{\theta_\text{old}}(a_t \mid s_t)\), we define:

  • Probability ratio:

    \[\rho_t(\theta) = \exp\big( \log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_\text{old}}(a_t \mid s_t) \big)\]
  • Clipped policy loss:

    \[L^{\text{policy}}(\theta) = - \mathbb{E}_t \Big[ \min\big( \rho_t(\theta) A_t,\; \text{clip}(\rho_t(\theta), 1-\epsilon, 1+\epsilon) A_t \big) \Big]\]

    where \(\epsilon\) is the PPO clipping parameter.

  • Value-function loss (with clipping):

    \[L^{\text{value}}(\theta) = \tfrac{1}{2} \mathbb{E}_t \Big[ \max\big( (V_\theta(s_t) - R_t)^2,\; (\text{clip}(V_\theta(s_t), V_{\theta_\text{old}}(s_t) - \epsilon, V_{\theta_\text{old}}(s_t) + \epsilon) - R_t)^2 \big) \Big]\]

    where \(R_t = A_t + V_{\theta_\text{old}}(s_t)\) are return targets.

  • Entropy bonus:

    \[L^{\text{entropy}}(\theta) = \mathbb{E}_t \big[ \mathcal{H}[\pi_\theta(\cdot \mid s_t)] \big]\]

    which encourages exploration.

  • Total loss:

    \[L(\theta) = L^{\text{policy}}(\theta) + c_v L^{\text{value}}(\theta) - c_e L^{\text{entropy}}(\theta)\]

    where \(c_v\) and \(c_e\) are coefficients for the value and entropy terms.
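
Putting the terms together, the following is a minimal JAX sketch of the total loss. It is illustrative only; the argument names (including the precomputed per-step entropy) are assumptions rather than the trainer's actual loss_fn signature, and the default coefficients simply mirror those of Create().

import jax.numpy as jnp

def ppo_total_loss(log_prob, old_log_prob, value, old_value, returns, advantage,
                   entropy, clip_eps=0.2, value_coeff=2.0, entropy_coeff=0.001):
    # Probability ratio rho_t(theta) = exp(log pi_theta - log pi_theta_old)
    ratio = jnp.exp(log_prob - old_log_prob)

    # Clipped surrogate policy loss
    policy_loss = -jnp.mean(
        jnp.minimum(ratio * advantage,
                    jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantage)
    )

    # Value-function loss, clipped around the old value predictions
    value_clipped = jnp.clip(value, old_value - clip_eps, old_value + clip_eps)
    value_loss = 0.5 * jnp.mean(
        jnp.maximum((value - returns) ** 2, (value_clipped - returns) ** 2)
    )

    # Entropy bonus (maximized, hence subtracted from the total loss)
    entropy_bonus = jnp.mean(entropy)

    return policy_loss + value_coeff * value_loss - entropy_coeff * entropy_bonus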

Prioritized Experience Replay (PER)

This trainer uses a prioritized categorical distribution over segments (environments x agents) to form minibatches. For each segment index \(i \in \{1,\dots,N\}\), we define a priority from the trajectory advantages:

\[\tilde{p}_i \;=\; \Big\| A_{\cdot,i} \Big\|_1^{\,\alpha} \quad\text{with}\quad \Big\| A_{\cdot,i} \Big\|_1 \;=\; \sum_{t=1}^{T} \big|A_{t,i}\big|,\]

where \(\alpha \ge 0\) (importance_sampling_alpha) controls the strength of prioritization. We then form a categorical sampling distribution

\[P(i) \;=\; \frac{\tilde{p}_i}{\sum_{k=1}^{N} \tilde{p}_k},\]

and sample indices \(\{i\}\) to create each minibatch (via jax.random.choice() with probabilities \(P(i)\)). This mirrors Prioritized Experience Replay (PER), where \(\tilde{p}\) is typically derived from the TD-error magnitude; here the per-trajectory advantage magnitude serves as a proxy for learning progress. The design is also inspired by recent large-scale self-play systems for autonomous driving. Taking the absolute value of the advantage means both the best and the worst segments are sampled preferentially; learning from mistakes is also a great way to learn!
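
For illustration, the priority computation and minibatch sampling described above could be sketched as follows; the function and argument names are assumptions, and the trainer's internal data layout may differ.

import jax
import jax.numpy as jnp

def sample_segment_indices(key, advantages, alpha, minibatch_size):
    # advantages: (T, N) per-step advantages for N segments (environments x agents)
    priorities = jnp.sum(jnp.abs(advantages), axis=0) ** alpha  # p~_i = ||A_.,i||_1 ** alpha
    probs = priorities / jnp.sum(priorities)                    # P(i)
    # Sample minibatch_size segment indices according to P(i)
    idx = jax.random.choice(key, advantages.shape[1], shape=(minibatch_size,), p=probs)
    return idx, probs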

To correct sampling bias we apply PER-style importance weights (importance_sampling_beta with optional linear annealing):

\[w_i(\beta_t) \;=\; \Big(N \, P(i)\Big)^{-\beta_t}, \qquad \beta_t \in [0,1].\]

In classical PER, \(w_i\) is often normalized by \(\max_j w_j\) to keep the scale bounded; in this implementation we omit that normalization and use \(w_i\) directly. The minibatch advantages are standardized and reweighted with these IS weights before the PPO loss:

\[\hat{A}_{t,i} \;=\; w_i(\beta_t)\; \frac{A_{t,i} - \mu_{\text{mb}}(A)}{\sigma_{\text{mb}}(A)+\varepsilon}.\]

If importance_sampling_alpha = 0, sampling is uniform; if importance_sampling_beta = 1, the full PER bias correction is applied.
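
A matching sketch of the importance weighting and standardization, under the same assumed names and shapes:

import jax.numpy as jnp

def reweight_minibatch_advantages(adv_mb, probs, idx, beta, eps=1e-8):
    # adv_mb: (T, B) advantages of the sampled segments; probs: (N,) sampling distribution
    n_segments = probs.shape[0]
    weights = (n_segments * probs[idx]) ** (-beta)  # w_i = (N P(i))^(-beta), unnormalized
    standardized = (adv_mb - adv_mb.mean()) / (adv_mb.std() + eps)
    return weights * standardized                   # broadcasts w_i over the time axis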

Off-policy correction of advantages (V-trace)

Advantages are recomputed on every minibatch iteration, updating both the value estimates and the ratio of policy probabilities. This way, if a sample ends up being reused, V-trace off-policy correction is applied when computing its advantages (Trainer.compute_advantages()). This matters because the policy keeps evolving across minibatch updates, making the rollout progressively off-policy and the stored values stale.
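
For reference, a self-contained V-trace-style advantage sketch for a single segment is given below; it is illustrative only, since the actual computation lives in Trainer.compute_advantages() and may differ in its clipping and bootstrapping details.

import jax
import jax.numpy as jnp

def vtrace_advantages(rewards, values, bootstrap_value, log_ratios,
                      gamma=0.99, lam=0.95, rho_clip=1.0, c_clip=1.0):
    # rewards, values, log_ratios: (T,) arrays for one segment.
    # log_ratios = log pi_current(a_t | s_t) - log pi_behaviour(a_t | s_t)
    # bootstrap_value: scalar (0-d array) value estimate for the state after the last step.
    ratios = jnp.exp(log_ratios)
    rhos = jnp.minimum(rho_clip, ratios)      # clipped importance weights rho_t
    cs = lam * jnp.minimum(c_clip, ratios)    # clipped trace coefficients c_t

    next_values = jnp.concatenate([values[1:], bootstrap_value[None]])
    deltas = rhos * (rewards + gamma * next_values - values)  # corrected TD errors

    def backward(carry, x):
        delta, c = x
        carry = delta + gamma * c * carry
        return carry, carry

    _, vs_minus_v = jax.lax.scan(backward, jnp.zeros(()), (deltas, cs), reverse=True)
    vs = values + vs_minus_v                                   # V-trace value targets
    next_vs = jnp.concatenate([vs[1:], bootstrap_value[None]])
    return rhos * (rewards + gamma * next_vs - values)         # advantages A_t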

References

  • Schulman et al., Proximal Policy Optimization Algorithms, 2017.

  • Espeholt et al., IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, ICML 2018.

  • Schulman et al., High-Dimensional Continuous Control Using Generalized Advantage Estimation, 2015/2016.

  • Schaul et al., Prioritized Experience Replay, ICLR 2016.

  • Cusumano-Towner et al., Robust Autonomy Emerges from Self-Play, ICML 2025.

ppo_clip_eps: jax.Array#

PPO clipping parameter \(\epsilon\) used for both the policy-ratio clip and the value-function clip.

ppo_value_coeff: jax.Array#

Coefficient \(c_v\) scaling the value-function loss term in the total loss.

ppo_entropy_coeff: jax.Array#

Coefficient \(c_e\) scaling the entropy bonus (encourages exploration).

importance_sampling_alpha: jax.Array#

Prioritization strength \(\alpha \ge 0\) for minibatch sampling; higher values put more probability mass on segments with larger advantage magnitude.

importance_sampling_beta: jax.Array#

Initial PER importance-weight exponent \(\beta \in [0,1]\) used in \(w_i(\beta) = (N P(i))^{-\beta}\); compensates sampling bias.

anneal_importance_sampling_beta: jax.Array#

If nonzero/True, linearly anneals \(\beta\) toward 1 across training (more correction later in training).
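
A plausible linear schedule (the exact schedule used by the trainer is an implementation detail; this form is an assumption):

\[\beta_t \;=\; \beta_0 + (1 - \beta_0)\,\frac{t}{T_{\text{epochs}} - 1}, \qquad t = 0, \dots, T_{\text{epochs}} - 1,\]

so that \(\beta_t\) starts at the configured importance_sampling_beta and reaches 1 at the final epoch.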

num_epochs: int#

Number of PPO training epochs (outer loop count).

stop_at_epoch: int#

Stop after this epoch. Must satisfy 1 ≤ stop_at_epoch ≤ num_epochs.

num_steps_epoch: int#

Rollout horizon \(T\) per epoch; total collected steps = \(N \times T\).
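
For example, with the Create() defaults num_envs = 1024 and num_steps_epoch = 64, each epoch collects \(1024 \times 64 = 65{,}536\) environment steps (per agent, if the environment has multiple agents).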

num_minibatches: int#

Number of minibatches per epoch used for PPO updates.

minibatch_size: int#

Minibatch size (number of env indices sampled per update); typically \(N / \text{num_minibatches}\).
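
For example, with the Create() defaults num_envs = 1024 and num_minibatches = 4, leaving minibatch_size unset would correspond to \(1024 / 4 = 256\) segments per minibatch.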

classmethod Create(env: Environment, model: Model, seed: Optional[int] = None, key: ArrayLike = Array((), dtype=key<fry>) overlaying: [0 1], optimizer=<function muon>, learning_rate: float = 0.01, anneal_learning_rate: bool = True, max_grad_norm: float = 1.5, accumulate_n_gradients: int = 1, ppo_clip_eps: float = 0.2, ppo_value_coeff: float = 2.0, ppo_entropy_coeff: float = 0.001, advantage_gamma: float = 0.99, advantage_lambda: float = 0.95, advantage_rho_clip: float = 1.0, advantage_c_clip: float = 1.0, importance_sampling_alpha: float = 0.8, importance_sampling_beta: float = 0.2, anneal_importance_sampling_beta: bool = True, num_envs: int = 1024, num_steps_epoch: int = 64, num_minibatches: int = 4, minibatch_size: Optional[int] = None, num_epochs: int = 1000, total_timesteps: Optional[int] = None, stop_at_epoch: Optional[int] = None, clip_actions: bool = False, clip_range: Tuple[float, float] = (-0.2, 0.2)) → Self[source]#
static one_epoch(tr: PPOTrainer, epoch)[source]#
static train(tr: PPOTrainer, verbose: bool = True, log: bool = True, directory: Path | str = 'runs', save_every: int = 2, start_epoch: int = 0)[source]#

Training loop

Subclasses must implement this with their algorithm-specific logic.
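
A hypothetical call sequence is sketched below; constructing the Environment and Model is library-specific and omitted here, so treat this as an illustration rather than a verified recipe.

from jaxdem.rl.trainers.PPOtrainer import PPOTrainer

# Assumes `env` is an existing jaxdem.rl Environment and `model` an existing Model.
trainer = PPOTrainer.Create(env, model, seed=0, num_envs=1024,
                            num_steps_epoch=64, num_epochs=1000)

# train() is a static method: the trainer instance is passed explicitly.
PPOTrainer.train(trainer, verbose=True, log=True, directory="runs", save_every=2)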

static loss_fn(model: Model, td: TrajectoryData, returns: jax.Array, advantage: jax.Array, ppo_clip_eps: jax.Array, ppo_value_coeff: jax.Array, ppo_entropy_coeff: jax.Array)[source]#
static epoch(tr: PPOTrainer, epoch: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray)[source]#