# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Implementation of the PPO algorithm.
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike
from typing import TYPE_CHECKING, Tuple, Optional, Sequence, cast
try:
# Python 3.11+
from typing import Self
except ImportError:
from typing_extensions import Self
from dataclasses import dataclass, field, replace
from functools import partial
import time
import datetime
from flax import nnx
from flax.metrics import tensorboard
import optax
from tqdm.auto import trange
from . import Trainer, TrajectoryData
from ..envWrappers import clip_action_env, vectorise_env
if TYPE_CHECKING:
from ..environments import Environment
from ..models import Model
@Trainer.register("PPO")
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class PPOTrainer(Trainer):
r"""
Proximal Policy Optimization (PPO) trainer in `PufferLib <https://github.com/PufferAI/PufferLib>`_ style.
This trainer implements the PPO algorithm with
clipped surrogate objectives, value-function loss, entropy regularization,
and importance-sampling reweighting.
**Loss function**
Given a trajectory batch with actions :math:`a_t`, states :math:`s_t`,
rewards :math:`r_t`, advantages :math:`A_t`, and old log-probabilities
:math:`\log \pi_{\theta_\text{old}}(a_t \mid s_t)`, we define:
- **Probability ratio**:
.. math::
r_t(\theta) = \exp\big( \log \pi_\theta(a_t \mid s_t) -
\log \pi_{\theta_\text{old}}(a_t \mid s_t) \big)
- **Clipped policy loss**:
.. math::
L^{\text{policy}}(\theta) =
- \mathbb{E}_t \Big[ \min\big( r_t(\theta) A_t,\;
\text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \big) \Big]
where :math:`\epsilon` is the PPO clipping parameter.
- **Value-function loss (with clipping)**:
.. math::
L^{\text{value}}(\theta) =
\tfrac{1}{2} \mathbb{E}_t \Big[ \max\big( (V_\theta(s_t) - R_t)^2,\;
(\text{clip}(V_\theta(s_t), V_{\theta_\text{old}}(s_t) - \epsilon,
V_{\theta_\text{old}}(s_t) + \epsilon) - R_t)^2 \big) \Big]
where :math:`R_t = A_t + V_{\theta_\text{old}}(s_t)` are the return targets.
- **Entropy bonus**:
.. math::
L^{\text{entropy}}(\theta) = \mathbb{E}_t \big[ \mathcal{H}[\pi_\theta(\cdot \mid s_t)] \big]
which encourages exploration.
- **Total loss**:
.. math::
L(\theta) = L^{\text{policy}}(\theta)
+ c_v L^{\text{value}}(\theta)
- c_e L^{\text{entropy}}(\theta)
where :math:`c_v` and :math:`c_e` are coefficients for the value and entropy terms.
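As a rough sketch only (the actual computation lives in :meth:`loss_fn`), the
total loss maps onto a few lines of :mod:`jax.numpy`; ``adv``, ``new_logp``,
``old_logp``, ``value``, ``old_value``, ``returns``, ``entropy``, ``eps``,
``c_v``, and ``c_e`` are placeholder names for the quantities defined above:
.. code-block:: python

   ratio = jnp.exp(new_logp - old_logp)                       # r_t(theta)
   policy_loss = -jnp.minimum(
       ratio * adv,
       jnp.clip(ratio, 1.0 - eps, 1.0 + eps) * adv,
   ).mean()
   v_clipped = old_value + jnp.clip(value - old_value, -eps, eps)
   value_loss = 0.5 * jnp.maximum(
       jnp.square(value - returns), jnp.square(v_clipped - returns)
   ).mean()
   loss = policy_loss + c_v * value_loss - c_e * entropy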
**Prioritized minibatch sampling and importance weighting**
This trainer uses a prioritized categorical distribution over environments to
form minibatches. For each environment index :math:`i \in \{1,\dots,N\}`,
we define a *priority* from the trajectory advantages:
.. math::
\tilde{p}_i \;=\; \Big\| A_{\cdot,i} \Big\|_1^{\,\alpha}
\quad\text{with}\quad
\Big\| A_{\cdot,i} \Big\|_1 \;=\; \sum_{t=1}^{T} \big|A_{t,i}\big|,
where :math:`\alpha \ge 0` (:attr:`importance_sampling_alpha`) controls the
strength of prioritization. We then form a categorical sampling distribution
.. math::
P(i) \;=\; \frac{\tilde{p}_i}{\sum_{k=1}^{N} \tilde{p}_k},
and sample indices :math:`\{i\}` to create each minibatch
(:func:`jax.random.choice` with probabilities :math:`P(i)`).
This mirrors Prioritized Experience Replay (PER), where :math:`\tilde{p}` is
derived from TD-error magnitude; here we use the per-environment advantage
magnitude as a proxy for learning progress. This design is also inspired by
recent large-scale self-play systems for autonomous driving.
To correct sampling bias we apply PER-style importance weights
(:attr:`importance_sampling_beta` with optional linear annealing):
.. math::
w_i(\beta_t) \;=\; \Big(N \, P(i)\Big)^{-\beta_t},
\qquad \beta_t \in [0,1].
In classical PER, :math:`w_i` is often normalized by :math:`\max_j w_j` to keep
the scale bounded; in this implementation we omit that normalization and use
:math:`w_i` directly. The minibatch advantages are standardized and *reweighted*
with these IS weights before the PPO loss:
.. math::
\hat{A}_{t,i}
\;=\;
w_i(\beta_t)\;
\frac{A_{t,i} - \mu_{\text{mb}}(A)}{\sigma_{\text{mb}}(A)+\varepsilon}.
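A minimal sketch of this sampling scheme, assuming placeholder names ``adv``
(advantages of shape ``[T, N]``), ``alpha``, ``beta``, ``minibatch_size``, and
a PRNG ``key`` (see :meth:`epoch` for the trainer's actual implementation):
.. code-block:: python

   prio = jnp.power(jnp.abs(adv).sum(axis=0), alpha)     # priorities p~_i
   probs = prio / (prio.sum() + 1e-8)                    # P(i)
   idx = jax.random.choice(key, a=adv.shape[1], p=probs,
                           shape=(minibatch_size,))      # sampled indices
   w = jnp.power(adv.shape[1] * probs[idx], -beta)       # IS weights w_i(beta)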
**Off-policy correction of advantages (V-trace)**
After sampling, we recompute log-probabilities under the *current* policy
(:code:`td.new_log_prob = pi.log_prob(td.action)`) and compute
targets/advantages with a V-trace–style off-policy correction in
:meth:`compute_advantages`.
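:meth:`compute_advantages` is not reproduced here; for orientation, a generic
V-trace-flavoured correction (clipped importance ratios :math:`\bar{\rho}` and
:math:`\bar{c}` combined with a :math:`\lambda`-weighted backward recursion)
could be sketched as follows. The names and the bootstrap handling are
illustrative assumptions, not this trainer's exact implementation:
.. code-block:: python

   def vtrace_like_advantages(values, rewards, dones, log_rhos,
                              gamma, lam, rho_clip, c_clip):
       # values, rewards, dones, log_rhos: arrays of shape [T, B]
       rho = jnp.minimum(jnp.exp(log_rhos), rho_clip)
       c = lam * jnp.minimum(jnp.exp(log_rhos), c_clip)
       # bootstrap with the last value estimate (illustrative choice)
       next_values = jnp.concatenate([values[1:], values[-1:]], axis=0)
       deltas = rho * (rewards + gamma * (1.0 - dones) * next_values - values)

       def backward(acc, x):
           delta_t, c_t, done_t = x
           acc = delta_t + gamma * (1.0 - done_t) * c_t * acc
           return acc, acc

       _, adv = jax.lax.scan(backward, jnp.zeros_like(values[-1]),
                             (deltas, c, dones), reverse=True)
       return adv, adv + values   # advantages and return targets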
----
**References**
- Schulman et al., *Proximal Policy Optimization Algorithms*, 2017.
- Espeholt et al., *IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures*, ICML 2018.
- Schulman et al., *High-Dimensional Continuous Control Using Generalized Advantage Estimation*, 2015/2016.
- Schaul et al., *Prioritized Experience Replay*, ICLR 2016.
- Cusumano-Towner et al., *Robust Autonomy Emerges from Self-Play*, ICML 2025.
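**Example**
A minimal usage sketch; ``my_env`` and ``my_model`` are placeholders for a
concrete :class:`Environment` and :class:`Model` instance:
.. code-block:: python

   trainer = PPOTrainer.Create(env=my_env, model=my_model,
                               num_envs=1024, num_epochs=200)
   trainer, _ = PPOTrainer.train(trainer)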
"""
ppo_clip_eps: jax.Array
r"""
PPO clipping parameter :math:`\epsilon` used for both the policy ratio clip
and the value-function clip.
"""
ppo_value_coeff: jax.Array
r"""
Coefficient :math:`c_v` scaling the value-function loss term in the total loss.
"""
ppo_entropy_coeff: jax.Array
r"""
Coefficient :math:`c_e` scaling the entropy bonus (encourages exploration).
"""
importance_sampling_alpha: jax.Array
r"""
Prioritization strength :math:`\alpha \ge 0` for minibatch sampling;
higher values put more probability mass on envs with larger advantages.
"""
importance_sampling_beta: jax.Array
r"""
Initial PER importance-weight exponent :math:`\beta \in [0,1]` used in
:math:`w_i(\beta) = (N P(i))^{-\beta}`; compensates sampling bias.
"""
anneal_importance_sampling_beta: jax.Array
r"""
If nonzero/True, linearly anneals :math:`\beta` toward 1 across training
(more correction later in training).
"""
num_envs: int = field(default=4096, metadata={"static": True})
r"""
Number of vectorized environments :math:`N` running in parallel.
"""
num_epochs: int = field(default=2000, metadata={"static": True})
"""
Number of PPO training epochs (outer loop count).
"""
num_steps_epoch: int = field(default=128, metadata={"static": True})
r"""
Rollout horizon :math:`T` per epoch; total collected steps = :math:`N \times T`.
"""
num_minibatches: int = field(default=8, metadata={"static": True})
"""
Number of minibatches per epoch used for PPO updates.
"""
minibatch_size: int = field(default=512, metadata={"static": True})
r"""
Minibatch size (number of segment indices sampled per update); defaults to
:math:`\text{num_segments} / \text{num_minibatches}` when not specified.
"""
num_segments: int = field(default=4096, metadata={"static": True})
r"""
Number of vectorized environments times max number of agents.
"""
@classmethod
def Create(
cls,
env: "Environment",
model: "Model",
key: ArrayLike = jax.random.key(1),
learning_rate: float = 2e-2,
max_grad_norm: float = 0.4,
ppo_clip_eps: float = 0.2,
ppo_value_coeff: float = 0.5,
ppo_entropy_coeff: float = 0.03,
importance_sampling_alpha: float = 0.4,
importance_sampling_beta: float = 0.1,
advantage_gamma: float = 0.99,
advantage_lambda: float = 0.95,
advantage_rho_clip: float = 0.5,
advantage_c_clip: float = 0.5,
num_envs: int = 2048,
num_epochs: int = 1000,
num_steps_epoch: int = 64,
num_minibatches: int = 10,
minibatch_size: Optional[int] = None,
accumulate_n_gradients: int = 1,  # use only for memory savings; degrades performance
clip_actions: bool = False,
clip_range: Tuple[float, float] = (-0.2, 0.2),
anneal_learning_rate: bool = True,
learning_rate_decay_exponent: float = 2.0,
learning_rate_decay_min_fraction: float = 0.01,
anneal_importance_sampling_beta: bool = True,
optimizer=optax.contrib.muon,
) -> Self:
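"""
Build a fully initialized :class:`PPOTrainer`: assembles the optimizer chain
(global-norm gradient clipping, the chosen ``optimizer`` with an optionally
cosine-annealed learning rate, and optional gradient accumulation), vectorizes
and resets ``env`` over ``num_envs`` copies (optionally wrapping it with action
clipping), resets the model's recurrent state, and packs everything into the
trainer dataclass.
"""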
key, subkeys = jax.random.split(key)
subkeys = jax.random.split(subkeys, num_envs)
num_epochs = int(num_epochs)
if anneal_learning_rate:
schedule = optax.cosine_decay_schedule(
init_value=float(learning_rate),
alpha=float(learning_rate_decay_min_fraction),
decay_steps=int(num_epochs),
exponent=float(learning_rate_decay_exponent),
)
else:
schedule = jnp.asarray(learning_rate, dtype=float)
tx = optax.chain(
optax.clip_by_global_norm(float(max_grad_norm)),
optimizer(schedule, eps=1e-5),
optax.apply_every(int(accumulate_n_gradients)),
)
metrics = nnx.MultiMetric(
score=nnx.metrics.Average(argname="score"),
loss=nnx.metrics.Average(argname="loss"),
actor_loss=nnx.metrics.Average(argname="actor_loss"),
value_loss=nnx.metrics.Average(argname="value_loss"),
entropy=nnx.metrics.Average(argname="entropy"),
approx_KL=nnx.metrics.Average(argname="approx_KL"),
returns=nnx.metrics.Average(argname="returns"),
ratio=nnx.metrics.Average(argname="ratio"),
policy_std=nnx.metrics.Average(argname="policy_std"),
explained_variance=nnx.metrics.Average(argname="explained_variance"),
grad_norm=nnx.metrics.Average(argname="grad_norm"),
)
graphdef, graphstate = nnx.split(
(model, nnx.Optimizer(model, tx, wrt=nnx.Param), metrics)
)
num_envs = int(num_envs)
env = jax.vmap(lambda _: env)(jnp.arange(num_envs))
if clip_actions:
min_val, max_val = clip_range
env = clip_action_env(env, min_val=min_val, max_val=max_val)
env = vectorise_env(env)
env = env.reset(env, subkeys)
num_segments = int(num_envs * env.max_num_agents)
num_minibatches = int(num_minibatches)
if minibatch_size is None:
minibatch_size = num_segments // num_minibatches
minibatch_size = int(minibatch_size)
assert (
minibatch_size <= num_segments
), f"minibatch_size = {minibatch_size} is larger than num_envs * max_num_agents={num_segments}."
model, optimizer, *rest = nnx.merge(graphdef, graphstate)
model.reset(shape=(num_envs, env.max_num_agents))
graphstate = nnx.state((model, optimizer, *rest))
return cls(
key=key,
env=env,
graphdef=graphdef,
graphstate=graphstate,
advantage_gamma=jnp.asarray(advantage_gamma, dtype=float),
advantage_lambda=jnp.asarray(advantage_lambda, dtype=float),
advantage_rho_clip=jnp.asarray(advantage_rho_clip, dtype=float),
advantage_c_clip=jnp.asarray(advantage_c_clip, dtype=float),
ppo_clip_eps=jnp.asarray(ppo_clip_eps, dtype=float),
ppo_value_coeff=jnp.asarray(ppo_value_coeff, dtype=float),
ppo_entropy_coeff=jnp.asarray(ppo_entropy_coeff, dtype=float),
importance_sampling_alpha=jnp.asarray(
importance_sampling_alpha, dtype=float
),
importance_sampling_beta=jnp.asarray(importance_sampling_beta, dtype=float),
anneal_importance_sampling_beta=jnp.asarray(
anneal_importance_sampling_beta, dtype=float
),
num_envs=int(num_envs),
num_epochs=int(num_epochs),
num_steps_epoch=int(num_steps_epoch),
num_minibatches=int(num_minibatches),
minibatch_size=int(minibatch_size),
num_segments=int(num_segments),
)
@staticmethod
def reset_model(
tr: "Trainer",
shape: Sequence[int] | None = None,
mask: jax.Array | None = None,
) -> "Trainer":
"""
Reset a model's persistent recurrent state (e.g., LSTM carry) for all
environments/agents and persist the mutation back into the trainer.
Parameters
----------
tr : Trainer
Trainer carrying the environment and NNX graph state.
shape : Sequence[int], optional
Target carry shape. If ``None``, it is inferred as
``(tr.num_envs, tr.env.max_num_agents)``.
mask : jax.Array, optional
Boolean mask selecting which (env, agent) entries to reset. A value of
``True`` resets that entry. The mask may be shape
``(num_envs, num_agents)`` or any shape broadcastable to it. If
``None``, all entries are reset.
Returns
-------
Trainer
A new trainer with the updated ``graphstate``.
"""
tr = cast("PPOTrainer", tr)
if shape is None:
shape = (tr.num_envs, tr.env.max_num_agents)
model, optimizer, *rest = nnx.merge(tr.graphdef, tr.graphstate)
model.reset(shape=shape, mask=mask)
return replace(
tr,
graphstate=nnx.state((model, optimizer, *rest)),
)
@staticmethod
def train(tr: "PPOTrainer", verbose=True):
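"""
Run the full training loop: executes :meth:`epoch` for ``num_epochs``
iterations and, when ``verbose`` is set, shows a ``tqdm`` progress bar and
logs the averaged metrics to TensorBoard under ``runs/<timestamp>``.
"""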
metrics_history = []
tr, _ = tr.epoch(tr, jnp.asarray(0))
it = trange(1, tr.num_epochs) if verbose else range(1, tr.num_epochs)
start_time = time.perf_counter()
if verbose:
log_folder = "runs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
writer = tensorboard.SummaryWriter(log_folder)
for epoch in it:
model, optimizer, metrics, *rest = nnx.merge(tr.graphdef, tr.graphstate)
metrics.reset()
tr = replace(tr, graphstate=nnx.state((model, optimizer, metrics, *rest)))
tr, _ = tr.epoch(tr, jnp.asarray(epoch))
model, optimizer, metrics, *rest = nnx.merge(tr.graphdef, tr.graphstate)
data = metrics.compute()
data["elapsed"] = time.perf_counter() - start_time
steps_done = (epoch + 1) * tr.num_envs * tr.num_steps_epoch
data["steps_per_sec"] = steps_done / data["elapsed"]
if verbose:
it.set_postfix(
{
"steps/s": f"{data['steps_per_sec']:.2e}",
"avg_score": f"{data['score']:.2f}",
}
)
for k, v in data.items():
writer.scalar(k, v, step=int(epoch))
writer.flush()
print(
f"steps/s: {data['steps_per_sec']:.2e}, final avg_score: {data['score']:.2f}"
)
return tr, metrics_history
@staticmethod
@nnx.jit(donate_argnames=("td", "seg_w"))
def loss_fn(
model: "Model",
td: "TrajectoryData",
seg_w: jax.Array,
advantage_rho_clip: jax.Array,
advantage_c_clip: jax.Array,
advantage_gamma: jax.Array,
advantage_lambda: jax.Array,
ppo_clip_eps: jax.Array,
ppo_value_coeff: jax.Array,
ppo_entropy_coeff: jax.Array,
entropy_key: jax.Array,
):
"""
Compute the PPO minibatch loss.
Parameters
----------
model : Model
Live model rebuilt by NNX for this step.
td : TrajectoryData
Time-stacked trajectory minibatch (e.g., shape ``[T, B, ...]``).
seg_w : jax.Array
PER importance weights applied to the standardized advantages.
advantage_rho_clip, advantage_c_clip, advantage_gamma, advantage_lambda : jax.Array
Hyperparameters forwarded to :meth:`compute_advantages`.
ppo_clip_eps, ppo_value_coeff, ppo_entropy_coeff : jax.Array
PPO clipping parameter and loss coefficients (see the class docstring).
entropy_key : jax.Array
PRNG key used for the Monte Carlo entropy estimate.
Returns
-------
Tuple[jax.Array, dict]
Scalar total loss to be minimized and a dictionary of auxiliary diagnostics.
See Also
--------
Main docs: PPO trainer overview and equations.
"""
# 1) Forward pass
old_value = td.value
model.eval()
pi, value = model(td.obs)
td = replace(
td,
new_log_prob=pi.log_prob(td.action),
value=jnp.squeeze(value, -1),
)
# 2) Recompute advantages and normalize
td = PPOTrainer.compute_advantages(
td,
advantage_rho_clip,
advantage_c_clip,
advantage_gamma,
advantage_lambda,
)
td.advantage = (
(td.advantage - td.advantage.mean()) / (td.advantage.std() + 1e-8) * seg_w
)
td.advantage = jax.lax.stop_gradient(td.advantage) # for policy loss
td.returns = jax.lax.stop_gradient(td.returns) # for value loss
# 3) Value loss (clipped)
value_pred_clipped = old_value + (td.value - old_value).clip(
-ppo_clip_eps, ppo_clip_eps
)
value_losses = jnp.square(td.value - td.returns)
value_losses_clipped = jnp.square(value_pred_clipped - td.returns)
value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
# 4) Policy loss (clipped)
log_ratio = td.new_log_prob - td.log_prob
ratio = jnp.exp(log_ratio)
actor_loss = -jnp.minimum(
td.advantage * ratio,
td.advantage * ratio.clip(1.0 - ppo_clip_eps, 1.0 + ppo_clip_eps),
).mean()
# 5) Estimate entropy by Monte Carlo sampling (analytic entropy is not available
# for distributions transformed by bijectors with a non-constant Jacobian determinant)
# H[π] = E_{a∼π}[−log π(a)] ≈ −(1/K) ∑_{k=1}^{K} log π(a^k)
# entropy_loss = pi.entropy().mean()
K = 2
_, sample_logp = pi.sample_and_log_prob(seed=entropy_key, sample_shape=(K,))
entropy = -jnp.mean(sample_logp, axis=0).mean()
total_loss = (
actor_loss + ppo_value_coeff * value_loss - ppo_entropy_coeff * entropy
)
# ----- diagnostics -----
approx_kl = jnp.mean((jnp.exp(log_ratio) - 1.0) - log_ratio)
explained_var = 1.0 - jnp.var(td.returns - td.value) / (
jnp.var(td.returns) + 1e-8
)
policy_std = jnp.mean(jnp.exp(model.log_std.value))
aux = {
# losses
"actor_loss": actor_loss,
"value_loss": value_loss,
"entropy": entropy,
# policy diagnostics
"approx_KL": approx_kl,
"ratio": ratio,
"policy_std": policy_std,
# value diagnostics
"explained_variance": explained_var,
"returns": td.returns,
"score": td.reward,
}
return total_loss, aux
@staticmethod
@jax.jit
def epoch(tr: "PPOTrainer", epoch: ArrayLike):
r"""
Run one PPO training epoch.
Returns
-------
(PPOTrainer, TrajectoryData)
Updated trainer state and the most recent time-stacked rollout.
"""
beta_t = tr.importance_sampling_beta + tr.anneal_importance_sampling_beta * (
1.0 - tr.importance_sampling_beta
) * (epoch / tr.num_epochs)
key, sample_key, reset_key, entropy_keys_key = jax.random.split(tr.key, 4)
subkeys = jax.random.split(reset_key, tr.num_envs)
entropy_keys = jax.random.split(entropy_keys_key, tr.num_minibatches)
# 0) Reset the environment and LSTM carry
model, optimizer, *rest = nnx.merge(tr.graphdef, tr.graphstate)
model.reset(
shape=(tr.num_envs, tr.env.max_num_agents), mask=tr.env.done(tr.env)
)
tr = replace(
tr,
key=key,
env=jax.vmap(tr.env.reset_if_done)(tr.env, tr.env.done(tr.env), subkeys),
graphstate=nnx.state((model, optimizer, *rest)),
)
# 1) Gather data -> shape: (time, num_envs, num_agents, *)
tr, td = tr.trajectory_rollout(tr, tr.num_steps_epoch)
# Reshape data (time, num_envs, num_agents, *) -> (time, num_envs*num_agents, *)
td = jax.tree_util.tree_map(
lambda x: x.reshape((x.shape[0], x.shape[1] * x.shape[2]) + x.shape[3:]), td
)
# 2) Compute advantages
td = tr.compute_advantages(
td,
tr.advantage_rho_clip,
tr.advantage_c_clip,
tr.advantage_gamma,
tr.advantage_lambda,
)
# 3) Prioritized minibatch sampling (PER-style): per-segment priorities from |advantage|
prio_weights = jnp.nan_to_num(
jnp.power(jnp.abs(td.advantage).sum(axis=0), tr.importance_sampling_alpha),
copy=False,
nan=0.0,
posinf=0.0,
)
prio_probs = prio_weights / (prio_weights.sum() + 1.0e-8)
idxs = jax.random.choice(
sample_key,
a=tr.num_segments,
p=prio_probs,
shape=(tr.num_minibatches, tr.minibatch_size),
)
@partial(jax.jit, donate_argnames=("idx",))
def train_batch(carry: Tuple, idx: jax.Array) -> Tuple:
# 4.0) Unpack model
tr, td, weights = carry
model, optimizer, metrics, *rest = nnx.merge(tr.graphdef, tr.graphstate)
idx, entropy_key = idx
# 4.1) Importance sampling
mb_td = jax.tree_util.tree_map(lambda x: jnp.take(x, idx, axis=1), td)
seg_w = jnp.power(
tr.num_segments * jnp.take(weights[None, :], idx, axis=1), -beta_t
)
# 4.2) Compute gradients
(loss, aux), grads = nnx.value_and_grad(tr.loss_fn, has_aux=True)(
model,
mb_td,
seg_w,
tr.advantage_rho_clip,
tr.advantage_c_clip,
tr.advantage_gamma,
tr.advantage_lambda,
tr.ppo_clip_eps,
tr.ppo_value_coeff,
tr.ppo_entropy_coeff,
entropy_key,
)
# 4.3) Train model
model.train()
optimizer.update(model, grads)
# 4.4) Log metrics
metrics.update(
loss=loss,
grad_norm=optax.global_norm(grads),
**aux,
)
# 4.5) Return updated model
tr = replace(tr, graphstate=nnx.state((model, optimizer, metrics, *rest)))
return (tr, td, weights), loss
# 4) Loop over mini batches
(tr, td, prio_probs), loss = jax.lax.scan(
train_batch, (tr, td, prio_probs), xs=(idxs, entropy_keys), unroll=4
)
return tr, td