jaxdem.rl.trainers.PPOtrainer#

Implementation of the PPO algorithm.

Classes

PPOTrainer(env, graphdef, graphstate, key, ...)

Proximal Policy Optimization (PPO) trainer in PufferLib style.

class jaxdem.rl.trainers.PPOtrainer.PPOTrainer(env: Environment, graphdef: nnx.GraphDef[Any], graphstate: nnx.GraphState, key: ArrayLike, advantage_gamma: jax.Array, advantage_lambda: jax.Array, advantage_rho_clip: jax.Array, advantage_c_clip: jax.Array, drip_decay: jax.Array, ppo_clip_eps: jax.Array, ppo_value_coeff: jax.Array, ppo_entropy_coeff: jax.Array, importance_sampling_alpha: jax.Array, importance_sampling_beta: jax.Array, anneal_importance_sampling_beta: jax.Array, num_epochs: int, stop_at_epoch: int, num_steps_epoch: int, num_minibatches: int, minibatch_size: int, skip_frames: int)#

Bases: Trainer

Proximal Policy Optimization (PPO) trainer in PufferLib style.

This trainer implements the PPO algorithm with clipped surrogate objectives, value-function loss, entropy regularization, and prioritized experience replay (PER).

Loss function

Given a trajectory batch with actions \(a_t\), states \(s_t\), rewards \(r_t\), advantages \(A_t\), and old log-probabilities \(\log \pi_{\theta_\text{old}}(a_t \mid s_t)\), we define:

  • Probability ratio:

    \[\rho_t(\theta) = \exp\big( \log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_\text{old}}(a_t \mid s_t) \big)\]
  • Clipped policy loss:

    \[L^{\text{policy}}(\theta) = - \mathbb{E}_t \Big[ \min\big( \rho_t(\theta) A_t,\; \text{clip}(\rho_t(\theta), 1-\epsilon, 1+\epsilon) A_t \big) \Big]\]

    where \(\epsilon\) is the PPO clipping parameter.

  • Value-function loss (with clipping):

    \[L^{\text{value}}(\theta) = \tfrac{1}{2} \mathbb{E}_t \Big[ \max\big( (V_\theta(s_t) - R_t)^2,\; (\text{clip}(V_\theta(s_t), V_{\theta_\text{old}}(s_t) - \epsilon, V_{\theta_\text{old}}(s_t) + \epsilon) - R_t)^2 \big) \Big]\]

    where \(R_t = A_t + V_{\theta_\text{old}}(s_t)\) are return targets.

  • Entropy bonus:

    \[L^{\text{entropy}}(\theta) = \mathbb{E}_t \big[ \mathcal{H}[\pi_\theta(\cdot \mid s_t)] \big]\]

    which encourages exploration.

  • Total loss:

    \[L(\theta) = L^{\text{policy}}(\theta) + c_v L^{\text{value}}(\theta) - c_e L^{\text{entropy}}(\theta)\]

    where \(c_v\) and \(c_e\) are coefficients for the value and entropy terms.
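For concreteness, the following is a minimal JAX sketch of these three terms and their combination. The function name and arguments (new_logprob, old_logprob, values, old_values, entropy, ...) are illustrative placeholders and do not mirror the trainer's actual loss_fn() signature documented below.

```python
import jax.numpy as jnp

def ppo_loss_sketch(new_logprob, old_logprob, values, old_values,
                    returns, advantage, entropy,
                    clip_eps, value_coeff, entropy_coeff):
    # Probability ratio rho_t = exp(log pi_theta - log pi_theta_old).
    ratio = jnp.exp(new_logprob - old_logprob)

    # Clipped surrogate policy loss.
    unclipped = ratio * advantage
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantage
    policy_loss = -jnp.mean(jnp.minimum(unclipped, clipped))

    # Clipped value loss against return targets R_t = A_t + V_old(s_t).
    v_clipped = old_values + jnp.clip(values - old_values, -clip_eps, clip_eps)
    value_loss = 0.5 * jnp.mean(
        jnp.maximum((values - returns) ** 2, (v_clipped - returns) ** 2)
    )

    # Entropy bonus encourages exploration.
    entropy_bonus = jnp.mean(entropy)

    return policy_loss + value_coeff * value_loss - entropy_coeff * entropy_bonus
```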

Prioritized Experience Replay (PER)

This trainer uses a prioritized categorical distribution over segments (environments x agents) to form minibatches. For each segment index \(i \in \{1,\dots,N\}\), we define a priority from the trajectory advantages:

\[\tilde{p}_i \;=\; \Big\| A_{\cdot,i} \Big\|_1^{\,\alpha} \quad\text{with}\quad \Big\| A_{\cdot,i} \Big\|_1 \;=\; \sum_{t=1}^{T} \big|A_{t,i}\big|,\]

where \(\alpha \ge 0\) (importance_sampling_alpha) controls the strength of prioritization. We then form a categorical sampling distribution

\[P(i) \;=\; \frac{\tilde{p}_i}{\sum_{k=1}^{N} \tilde{p}_k},\]

and sample indices \(\{i\}\) to create each minibatch (jax.random.choice() with probabilities \(P(i)\)). This mirrors Prioritized Experience Replay (PER), where \(\tilde{p}\) is derived from TD-error magnitude; here the per-trajectory advantage magnitude serves as a proxy for learning progress. The design is also inspired by recent large-scale self-play systems for autonomous driving. Using the absolute value of the advantage ensures that both the best and the worst samples are included. Learning from mistakes is also a great way to learn!

To correct sampling bias we apply PER-style importance weights (importance_sampling_beta with optional linear annealing):

\[w_i(\beta_t) \;=\; \Big(N \, P(i)\Big)^{-\beta_t}, \qquad \beta_t \in [0,1].\]

In classical PER, \(w_i\) is often normalized by \(\max_j w_j\) to keep the scale bounded; in this implementation we omit that normalization and use \(w_i\) directly. The minibatch advantages are standardized and reweighted with these IS weights before the PPO loss:

\[\hat{A}_{t,i} \;=\; w_i(\beta_t)\; \frac{A_{t,i} - \mu_{\text{mb}}(A)}{\sigma_{\text{mb}}(A)+\varepsilon}.\]

If importance_sampling_alpha = 0, sampling is uniform. If importance_sampling_beta = 1, the full PER correction is applied.
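The prioritization, sampling, and reweighting steps can be summarized with the following sketch. Variable and function names are illustrative placeholders, not the trainer's internal API:

```python
import jax
import jax.numpy as jnp

def sample_minibatch_indices(key, advantages, alpha, beta, minibatch_size):
    """advantages: [T, N] per-step advantages for N segments (environments x agents)."""
    # Priorities: L1 norm of each segment's advantages, raised to alpha.
    priorities = jnp.sum(jnp.abs(advantages), axis=0) ** alpha
    probs = priorities / jnp.sum(priorities)

    # Categorical sampling of segment indices for the minibatch.
    idx = jax.random.choice(key, advantages.shape[1], shape=(minibatch_size,), p=probs)

    # PER importance weights w_i = (N * P(i))^(-beta), without max-normalization.
    weights = (advantages.shape[1] * probs[idx]) ** (-beta)
    return idx, weights

# The minibatch advantages are then standardized and reweighted:
#   A_hat = weights * (A - A.mean()) / (A.std() + eps)
```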

Off-policy correction of advantages (V-trace)

Advantages are recomputed on every minibatch iteration, using the current value estimates and the current ratio of policy probabilities. This way, whenever a sample is reused, V-trace off-policy correction is applied when computing the advantages (Trainer.compute_advantages()). This matters because the policy keeps evolving across minibatch updates, which makes the rollout increasingly off-policy and the stored values stale.
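For reference, the V-trace quantities of Espeholt et al. (2018) take the form below, with \(\bar{\rho}\) and \(\bar{c}\) corresponding to advantage_rho_clip and advantage_c_clip; the exact recursion in Trainer.compute_advantages() may additionally incorporate advantage_lambda in a GAE-style weighting:

\[\delta_t V \;=\; \rho_t \big( r_t + \gamma V(s_{t+1}) - V(s_t) \big), \qquad \rho_t \;=\; \min\!\Big(\bar{\rho},\; \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}\Big),\]

\[v_t \;=\; V(s_t) + \sum_{k=t}^{T-1} \gamma^{\,k-t} \Big( \prod_{j=t}^{k-1} c_j \Big) \delta_k V, \qquad c_j \;=\; \min\!\Big(\bar{c},\; \frac{\pi_\theta(a_j \mid s_j)}{\pi_{\theta_\text{old}}(a_j \mid s_j)}\Big),\]

and the off-policy-corrected advantages are \(A_t = \rho_t \big( r_t + \gamma v_{t+1} - V(s_t) \big)\).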

Distributed Reward Information Processing (DRIP)

DRIP is a technique for addressing the credit assignment problem in environments with sparse or delayed feedback. It applies a recursive backward pass over the trajectory that distributes terminal or delayed rewards to the past causal states. This is implemented via an exponential smoothing filter strictly bounded by episode terminations to prevent cross-episode bleeding; a minimal sketch of such a backward pass is given after the list below.

  • To activate: set drip_decay to a value in (0.0, 1.0] (e.g., 0.8).

  • To deactivate: set drip_decay to 0.0 (the default behavior).
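The sketch below shows one plausible reading of this backward pass, assuming a simple decayed backward recursion that resets at episode terminations; the actual filter used by the trainer may differ in its exact form or normalization:

```python
import jax
import jax.numpy as jnp

def drip_sketch(rewards, dones, drip_decay):
    """rewards, dones: [T, N]. Backward smoothing pass, reset at terminations."""
    def step(carry, inputs):
        r, done = inputs
        # Mix this step's reward with the decayed smoothed future reward,
        # but never let reward information bleed across episode boundaries.
        carry = r + drip_decay * (1.0 - done) * carry
        return carry, carry

    _, smoothed = jax.lax.scan(
        step, jnp.zeros_like(rewards[-1]), (rewards, dones), reverse=True
    )
    return smoothed
```

With drip_decay = 0.0 the recursion reduces to the raw per-step rewards, matching the documented "disabled" behavior.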

References

  • Schulman et al., Proximal Policy Optimization Algorithms, 2017.

  • Espeholt et al., IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, ICML 2018.

  • Schulman et al., High-Dimensional Continuous Control Using Generalized Advantage Estimation, 2015/2016.

  • Schaul et al., Prioritized Experience Replay, ICLR 2016.

  • Cusumano-Towner et al., Robust Autonomy Emerges from Self-Play, ICML 2025.

drip_decay: jax.Array#

Decay factor \(\lambda_{DRIP}\) for Distributed Reward Information Processing (DRIP). Drips delayed/sparse rewards backward through time to assign credit to past actions. Set to 0.0 to disable (default).

ppo_clip_eps: jax.Array#

PPO clipping parameter \(\epsilon\) used for both the policy-ratio clip and the value-function clip.

ppo_value_coeff: jax.Array#

Coefficient \(c_v\) scaling the value-function loss term in the total loss.

ppo_entropy_coeff: jax.Array#

Coefficient \(c_e\) scaling the entropy bonus (encourages exploration).

importance_sampling_alpha: jax.Array#

Prioritization strength \(\alpha \ge 0\) for minibatch sampling; higher values put more probability mass on segments with larger advantage magnitude.

importance_sampling_beta: jax.Array#

Initial PER importance-weight exponent \(\beta \in [0,1]\) used in \(w_i(\beta) = (N P(i))^{-\beta}\); compensates sampling bias.

anneal_importance_sampling_beta: jax.Array#

If nonzero/True, linearly anneals \(\beta\) toward 1 across training (more correction later in training).

num_epochs: int#

Number of PPO training epochs (outer loop count).

stop_at_epoch: int#

Stop after this epoch. Must satisfy 1 ≤ stop_at_epoch ≤ num_epochs.

num_steps_epoch: int#

Rollout horizon \(T\) per epoch; total collected steps = \(N \times T\).

num_minibatches: int#

Number of minibatches per epoch used for PPO updates.

minibatch_size: int#

Minibatch size (number of env indices sampled per update); typically \(N / \text{num_minibatches}\).

skip_frames: int#

Number of frames to skip (repeat action) for each observation.

classmethod Create(env: Environment, model: Model, seed: int | None = None, key: ArrayLike = jax.random.key(0), optimizer: Any = <function muon>, learning_rate: float = 0.01, anneal_learning_rate: bool = True, max_grad_norm: float = 1.5, accumulate_n_gradients: int = 1, ppo_clip_eps: float = 0.2, ppo_value_coeff: float = 2.0, ppo_entropy_coeff: float = 0.001, advantage_gamma: float = 0.99, advantage_lambda: float = 0.95, advantage_rho_clip: float = 1.0, advantage_c_clip: float = 1.0, drip_decay: float = 0.0, importance_sampling_alpha: float = 0.8, importance_sampling_beta: float = 0.2, anneal_importance_sampling_beta: bool = True, num_envs: int = 1024, num_steps_epoch: int = 64, num_minibatches: int = 4, minibatch_size: int | None = None, skip_frames: int = 0, num_epochs: int = 1000, total_timesteps: int | None = None, stop_at_epoch: int | None = None, clip_actions: bool = False, clip_range: tuple[float, float] = (-0.2, 0.2)) → Self[source]#

Construct a PPO trainer from an environment and a model.

Vectorises the environment, builds the optimizer chain, and initialises the model carry. See the class-level field docstrings for parameter descriptions.

Parameters:
  • env (Environment) – A single (non-vectorised) environment instance.

  • model (Model) – An actor–critic model whose observation_space_size and action_space_size match env.

Returns:

Ready-to-train trainer instance.

Return type:

PPOTrainer
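A minimal usage sketch; my_env and my_model are placeholders for any compatible Environment instance and actor-critic Model constructed elsewhere:

```python
from jaxdem.rl.trainers.PPOtrainer import PPOTrainer

# my_env: a single (non-vectorised) Environment instance.
# my_model: an actor-critic Model whose observation/action space sizes match my_env.
trainer = PPOTrainer.Create(
    env=my_env,
    model=my_model,
    seed=0,
    num_envs=1024,
    num_steps_epoch=64,
    num_minibatches=4,
    num_epochs=1000,
)

# Run the full training loop, writing TensorBoard logs under "runs".
trainer = PPOTrainer.train(trainer, verbose=True, log=True, directory="runs")
```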

static one_epoch(tr: PPOTrainer, epoch: Array) → tuple[PPOTrainer, TrajectoryData, dict[str, Any]][source]#

Run one training epoch (delegates to epoch()).

Parameters:
  • tr (PPOTrainer) – Current trainer state.

  • epoch (jax.Array) – Zero-based epoch index (scalar integer).

Returns:

Updated trainer, trajectory data, and scalar metrics.

Return type:

tuple[PPOTrainer, TrajectoryData, dict[str, Any]]

static train(tr: Trainer, verbose: bool = True, log: bool = True, directory: Path | str = 'runs', save_every: int = 2, start_epoch: int = 0, **kwargs: Any) → PPOTrainer[source]#

Run the full PPO training loop.

Parameters:
  • tr (Trainer) – Trainer instance (will be cast to PPOTrainer).

  • verbose (bool) – If True, display a tqdm progress bar.

  • log (bool) – If True, write TensorBoard scalars to directory.

  • directory (Path | str) – Root directory for TensorBoard logs.

  • save_every (int) – Sync metrics and log every save_every epochs.

  • start_epoch (int) – Resume epoch counter (useful after checkpoint restore).

Returns:

Trainer with updated parameters after training.

Return type:

PPOTrainer

static loss_fn(model: Model, td: TrajectoryData, returns: jax.Array, advantage: jax.Array, ppo_clip_eps: jax.Array, ppo_value_coeff: jax.Array, ppo_entropy_coeff: jax.Array) → tuple[jax.Array, dict[str, jax.Array]][source]#

Compute the clipped PPO loss for a minibatch.

Runs a forward pass through model and returns the composite loss (policy + value + entropy) together with diagnostic scalars.

Parameters:
  • model (Model) – Actor–critic model (called with sequence=True).

  • td (TrajectoryData) – Minibatch trajectory slice [T, M, ...].

  • returns (jax.Array) – Return targets, shape [T, M].

  • advantage (jax.Array) – Normalised, IS-weighted advantages, shape [T, M].

  • ppo_clip_eps (jax.Array) – Clipping parameter \(\epsilon\).

  • ppo_value_coeff (jax.Array) – Value-loss coefficient \(c_v\).

  • ppo_entropy_coeff (jax.Array) – Entropy-bonus coefficient \(c_e\).

Returns:

Scalar total loss and a dictionary of diagnostic metrics.

Return type:

tuple[jax.Array, dict[str, jax.Array]]

static epoch(tr: PPOTrainer, epoch: Array | ndarray | bool_ | number | bool | int | float | complex | TypedNdArray) → tuple[PPOTrainer, TrajectoryData, dict[str, Array]][source]#

Execute one PPO epoch: rollout → advantage → minibatch updates.

Steps:

  0. Reset done environments and split PRNG keys.

  1. Collect a trajectory of length num_steps_epoch.

  2. Flatten the agent axis and apply DRIP if enabled.

  3. Compute PER priorities.

  4. Scan over num_minibatches updates, each recomputing V-trace advantages and applying the clipped PPO loss.

Parameters:
  • tr (PPOTrainer) – Current trainer state.

  • epoch (ArrayLike) – Zero-based epoch index (used for \(\beta\) annealing).

Returns:

Updated trainer, full trajectory data, and epoch-averaged metrics.

Return type:

tuple[PPOTrainer, TrajectoryData, dict[str, jax.Array]]