jaxdem.rl.trainers#

Interface for defining reinforcement learning model trainers.

Classes

Trainer(env, graphdef, graphstate, key, ...)

Base class for reinforcement learning trainers.

TrajectoryData(*, obs, action, reward, done, ...)

Container for rollout data (single step or stacked across time).

class jaxdem.rl.trainers.PPOTrainer(env: Environment, graphdef: nnx.GraphDef, graphstate: nnx.GraphState, key: ArrayLike, advantage_gamma: jax.Array, advantage_lambda: jax.Array, advantage_rho_clip: jax.Array, advantage_c_clip: jax.Array, ppo_clip_eps: jax.Array, ppo_value_coeff: jax.Array, ppo_entropy_coeff: jax.Array, importance_sampling_alpha: jax.Array, importance_sampling_beta: jax.Array, anneal_importance_sampling_beta: jax.Array, num_envs: int = 4096, num_epochs: int = 2000, num_steps_epoch: int = 128, num_minibatches: int = 8, minibatch_size: int = 512, num_segments: int = 4096)[source]#

Bases: Trainer

Proximal Policy Optimization (PPO) trainer in PufferLib style.

This trainer implements the PPO algorithm with a clipped surrogate objective, a clipped value-function loss, entropy regularization, and importance-sampling reweighting; a schematic code sketch of the combined loss follows the definitions below.

Loss function

Given a trajectory batch with actions \(a_t\), states \(s_t\), rewards \(r_t\), advantages \(A_t\), and old log-probabilities \(\log \pi_{\theta_\text{old}}(a_t \mid s_t)\), we define:

  • Probability ratio:

    \[r_t(\theta) = \exp\big( \log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_\text{old}}(a_t \mid s_t) \big)\]
  • Clipped policy loss:

    \[L^{\text{policy}}(\theta) = - \mathbb{E}_t \Big[ \min\big( r_t(\theta) A_t,\; \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \big) \Big]\]

    where \(\epsilon\) is the PPO clipping parameter.

  • Value-function loss (with clipping):

    \[L^{\text{value}}(\theta) = \tfrac{1}{2} \mathbb{E}_t \Big[ \max\big( (V_\theta(s_t) - R_t)^2,\; (\text{clip}(V_\theta(s_t), V_{\theta_\text{old}}(s_t) - \epsilon, V_{\theta_\text{old}}(s_t) + \epsilon) - R_t)^2 \big) \Big]\]

    where \(R_t = A_t + V_{\theta_\text{old}}(s_t)\) are the return targets (advantages plus value estimates).

  • Entropy bonus:

    \[L^{\text{entropy}}(\theta) = \mathbb{E}_t \big[ \mathcal{H}[\pi_\theta(\cdot \mid s_t)] \big]\]

    which encourages exploration.

  • Total loss:

    \[L(\theta) = L^{\text{policy}}(\theta) + c_v L^{\text{value}}(\theta) - c_e L^{\text{entropy}}(\theta)\]

    where \(c_v\) and \(c_e\) are coefficients for the value and entropy terms.
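
The following is a minimal jax.numpy sketch of this total loss, mirroring the equations above; it is not the trainer's loss_fn(), and names such as new_log_prob, old_values, and entropy (assumed precomputed per step) are illustrative:

    import jax.numpy as jnp

    def ppo_loss_sketch(new_log_prob, old_log_prob, advantages,
                        values, old_values, returns, entropy,
                        clip_eps=0.2, value_coeff=0.5, entropy_coeff=0.03):
        """Schematic PPO total loss; all inputs are per-step arrays of equal shape."""
        # Probability ratio r_t(theta) from new and old log-probabilities
        ratio = jnp.exp(new_log_prob - old_log_prob)

        # Clipped surrogate policy loss
        unclipped = ratio * advantages
        clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
        policy_loss = -jnp.mean(jnp.minimum(unclipped, clipped))

        # Clipped value-function loss
        v_clipped = old_values + jnp.clip(values - old_values, -clip_eps, clip_eps)
        value_loss = 0.5 * jnp.mean(
            jnp.maximum((values - returns) ** 2, (v_clipped - returns) ** 2)
        )

        # Entropy bonus; subtracted so that maximizing entropy lowers the loss
        entropy_bonus = jnp.mean(entropy)

        return policy_loss + value_coeff * value_loss - entropy_coeff * entropy_bonus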

Prioritized minibatch sampling and importance weighting

This trainer uses a prioritized categorical distribution over environments to form minibatches. For each environment index \(i \in \{1,\dots,N\}\), we define a priority from the trajectory advantages:

\[\tilde{p}_i \;=\; \Big\| A_{\cdot,i} \Big\|_1^{\,\alpha} \quad\text{with}\quad \Big\| A_{\cdot,i} \Big\|_1 \;=\; \sum_{t=1}^{T} \big|A_{t,i}\big|,\]

where \(\alpha \ge 0\) (importance_sampling_alpha) controls the strength of prioritization. We then form a categorical sampling distribution

\[P(i) \;=\; \frac{\tilde{p}_i}{\sum_{k=1}^{N} \tilde{p}_k},\]

and sample indices \(\{i\}\) to create each minibatch (jax.random.choice() with probabilities \(P(i)\)). This mirrors Prioritized Experience Replay (PER), where \(\tilde{p}\) is derived from TD-error magnitude; here we use the per-environment advantage magnitude as a proxy for learning progress. This design is also inspired by recent large-scale self-play systems for autonomous driving.

To correct sampling bias we apply PER-style importance weights (importance_sampling_beta with optional linear annealing):

\[w_i(\beta_t) \;=\; \Big(N \, P(i)\Big)^{-\beta_t}, \qquad \beta_t \in [0,1].\]

In classical PER, \(w_i\) is often normalized by \(\max_j w_j\) to keep the scale bounded; in this implementation we omit that normalization and use \(w_i\) directly. The minibatch advantages are standardized and reweighted with these IS weights before the PPO loss:

\[\hat{A}_{t,i} \;=\; w_i(\beta_t)\; \frac{A_{t,i} - \mu_{\text{mb}}(A)}{\sigma_{\text{mb}}(A)+\varepsilon}.\]
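
A hedged sketch of this sampling and reweighting scheme; the helper names and the [T, N] advantage layout are assumptions for illustration, not the trainer's internals:

    import jax
    import jax.numpy as jnp

    def sample_prioritized_minibatch(key, advantages, minibatch_size,
                                     alpha=0.4, beta=0.1, eps=1e-8):
        """Sample environment indices with PER-style priorities from |advantage| mass.

        advantages: [T, N] per-step, per-environment advantages.
        Returns (env_indices, is_weights) for one minibatch.
        """
        T, N = advantages.shape

        # Priority per environment: L1 norm of its advantages, raised to alpha
        priorities = jnp.sum(jnp.abs(advantages), axis=0) ** alpha
        probs = priorities / (jnp.sum(priorities) + eps)            # P(i)

        # Draw minibatch indices from the categorical distribution P(i)
        idx = jax.random.choice(key, N, shape=(minibatch_size,), p=probs)

        # PER importance weights w_i = (N * P(i))^(-beta); no max-normalization
        is_weights = (N * probs[idx]) ** (-beta)
        return idx, is_weights

    def reweight_advantages(adv_mb, is_weights, eps=1e-8):
        """Standardize minibatch advantages, then scale by the IS weights."""
        adv = (adv_mb - adv_mb.mean()) / (adv_mb.std() + eps)
        return adv * is_weights   # broadcasts over time if adv_mb is [T, B]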

Off-policy correction of advantages (V-trace)

After sampling, we recompute log-probabilities under the current policy (td.new_log_prob = pi.log_prob(td.action)) and compute targets/advantages with a V-trace–style off-policy correction in compute_advantages().
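
A schematic jax version of such a correction, following Espeholt et al. (2018); this illustrates the idea only and is not the library's compute_advantages() (the bootstrap value here is simply the last value repeated, and folding the \(\lambda\) factor into the trace coefficient is one common convention):

    import jax
    import jax.numpy as jnp

    def vtrace_sketch(rewards, values, dones, log_ratio,
                      gamma=0.99, lam=0.95, rho_clip=0.5, c_clip=0.5):
        """Schematic V-trace targets and advantages for [T, B] inputs."""
        ratio = jnp.exp(log_ratio)
        rho = jnp.minimum(ratio, rho_clip)       # clipped IS ratio for the TD error
        c = lam * jnp.minimum(ratio, c_clip)     # clipped trace coefficient

        not_done = 1.0 - dones
        next_values = jnp.concatenate([values[1:], values[-1:]], axis=0)
        deltas = rho * (rewards + gamma * not_done * next_values - values)

        def backward(carry, xs):
            delta_t, c_t, nd_t = xs
            carry = delta_t + gamma * nd_t * c_t * carry   # vs_t - V(s_t)
            return carry, carry

        _, vs_minus_v = jax.lax.scan(
            backward, jnp.zeros_like(values[0]), (deltas, c, not_done), reverse=True
        )
        vs = vs_minus_v + values                 # V-trace value targets
        next_vs = jnp.concatenate([vs[1:], vs[-1:]], axis=0)
        advantages = rho * (rewards + gamma * not_done * next_vs - values)
        return vs, advantages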

References

  • Schulman et al., Proximal Policy Optimization Algorithms, 2017.

  • Espeholt et al., IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, ICML 2018.

  • Schulman et al., High-Dimensional Continuous Control Using Generalized Advantage Estimation, 2015/2016.

  • Schaul et al., Prioritized Experience Replay, ICLR 2016.

  • Cusumano-Towner et al., Robust Autonomy Emerges from Self-Play, ICML 2025.

ppo_clip_eps: jax.Array#

PPO clipping parameter \(\epsilon\), used for both the policy-ratio clip and the value-function clip.

ppo_value_coeff: jax.Array#

Coefficient \(c_v\) scaling the value-function loss term in the total loss.

ppo_entropy_coeff: jax.Array#

Coefficient \(c_e\) scaling the entropy bonus (encourages exploration).

importance_sampling_alpha: jax.Array#

Prioritization strength \(\alpha \ge 0\) for minibatch sampling; higher values put more probability mass on envs with larger advantages.

importance_sampling_beta: jax.Array#

Initial PER importance-weight exponent \(\beta \in [0,1]\) used in \(w_i(\beta) = (N P(i))^{-\beta}\); compensates sampling bias.

anneal_importance_sampling_beta: jax.Array#

If nonzero/True, linearly anneals \(\beta\) toward 1 over the course of training, so the bias correction grows stronger in later epochs.

num_envs: int#

Number of vectorized environments \(N\) running in parallel.

num_epochs: int#

Number of PPO training epochs (outer loop count).

num_steps_epoch: int#

Rollout horizon \(T\) per epoch; total collected steps = \(N \times T\).

num_minibatches: int#

Number of minibatches per epoch used for PPO updates.

minibatch_size: int#

Minibatch size (number of env indices sampled per update); typically \(N / \text{num_minibatches}\).

num_segments: int#

Total number of segments: the number of vectorized environments times the maximum number of agents per environment.

classmethod Create(env: Environment, model: Model, key: ArrayLike = jax.random.key(1), learning_rate: float = 0.02, max_grad_norm: float = 0.4, ppo_clip_eps: float = 0.2, ppo_value_coeff: float = 0.5, ppo_entropy_coeff: float = 0.03, importance_sampling_alpha: float = 0.4, importance_sampling_beta: float = 0.1, advantage_gamma: float = 0.99, advantage_lambda: float = 0.95, advantage_rho_clip: float = 0.5, advantage_c_clip: float = 0.5, num_envs: int = 2048, num_epochs: int = 1000, num_steps_epoch: int = 64, num_minibatches: int = 10, minibatch_size: Optional[int] = None, accumulate_n_gradients: int = 1, clip_actions: bool = False, clip_range: Tuple[float, float] = (-0.2, 0.2), anneal_learning_rate: bool = True, learning_rate_decay_exponent: float = 2.0, learning_rate_decay_min_fraction: float = 0.01, anneal_importance_sampling_beta: bool = True, optimizer=<function muon>) → Self[source][source]#
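
A hedged usage sketch for Create(): env and model are assumed to come from jaxdem's own Environment and Model factories (not shown here), and the overridden keyword values are arbitrary examples:

    import jax
    from jaxdem.rl.trainers import PPOTrainer

    # `env` and `model` are placeholders; build them with jaxdem's Environment
    # and Model factories before this point.
    trainer = PPOTrainer.Create(
        env=env,
        model=model,
        key=jax.random.key(0),
        learning_rate=3e-4,
        num_envs=2048,
        num_epochs=1000,
        num_steps_epoch=64,
        num_minibatches=10,
    )

    # Run the full training loop (see train() below).
    PPOTrainer.train(trainer, verbose=True)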
static epoch(tr: PPOTrainer, epoch: ArrayLike)[source][source]#

Run one PPO training epoch.

Returns:

Updated trainer state and the most recent time-stacked rollout.

Return type:

(PPOTrainer, TrajectoryData)
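
For custom logging or checkpointing it may be preferable to drive epochs manually instead of calling train(); a hedged sketch, continuing the configuration example above:

    import jax.numpy as jnp

    for e in range(trainer.num_epochs):
        # One PPO epoch: collect a rollout of num_steps_epoch steps and update.
        trainer, rollout = PPOTrainer.epoch(trainer, jnp.asarray(e))

        # `rollout` is the most recent time-stacked TrajectoryData, e.g. for logging
        mean_reward = rollout.reward.mean()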

static loss_fn(model: Model, td: TrajectoryData, seg_w: jax.Array, advantage_rho_clip: jax.Array, advantage_c_clip: jax.Array, advantage_gamma: jax.Array, advantage_lambda: jax.Array, ppo_clip_eps: jax.Array, ppo_value_coeff: jax.Array, ppo_entropy_coeff: jax.Array, entropy_key: jax.Array)[source][source]#

Compute the PPO minibatch loss.

Parameters:
  • model (Model) – Live model rebuilt by NNX for this step.

  • td (TrajectoryData) – Time-stacked trajectory minibatch (e.g., shape [T, B, ...]).

  • seg_w (jax.Array) – Per-segment importance-sampling weights applied when normalizing the minibatch advantages.

Returns:

Scalar loss to be minimized.

Return type:

jax.Array

See also

Main

classmethod registry_name() → str[source]#
static reset_model(tr: Trainer, shape: Sequence[int] | None = None, mask: Array | None = None) → Trainer[source][source]#

Reset a model’s persistent recurrent state (e.g., LSTM carry) for all environments/agents and persist the mutation back into the trainer.

Parameters:
  • tr (Trainer) – Trainer carrying the environment and NNX graph state. The target carry shape is inferred as (tr.num_envs, tr.env.max_num_agents) if not specified.

  • mask (jax.Array, optional) – Boolean mask selecting which (env, agent) entries to reset. A value of True resets that entry. The mask may be shape (num_envs, num_agents) or any shape broadcastable to it. If None, all entries are reset.

Returns:

A new trainer with the updated graphstate.

Return type:

Trainer
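
A hedged sketch of resetting recurrent state only for episodes that just ended, assuming a trainer as in the earlier sketches; the done array and its (num_envs, num_agents) reshape are assumptions for illustration:

    # `done` holds boolean termination flags from the last environment step
    # (hypothetical; obtain it from your rollout data).
    mask = done.reshape(trainer.num_envs, trainer.env.max_num_agents)

    # Reset LSTM-style carries only where an episode ended; other entries keep
    # their state. A new trainer with the updated graphstate is returned.
    trainer = PPOTrainer.reset_model(trainer, mask=mask)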

static train(tr: PPOTrainer, verbose=True)[source][source]#

Run the full PPO training loop: iterates epoch() for num_epochs epochs, collecting a rollout and applying minibatch PPO updates each epoch.

property type_name: str[source]#

Modules

PPOtrainer

Implementation of the PPO algorithm.