Source code for jaxdem.rl.trainers

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Interface for defining reinforcement learning model trainers.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from typing import TYPE_CHECKING, Tuple, Any, Sequence

from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from functools import partial

from flax import nnx

from ...factory import Factory

if TYPE_CHECKING:
    from ..environments import Environment



@jax.tree_util.register_dataclass
@dataclass(kw_only=True, slots=True)
class TrajectoryData:
    """
    Container for rollout data (single step or stacked across time).
    """

    obs: jax.Array
    """
    Observations.
    """

    action: jax.Array
    """
    Actions sampled from the policy.
    """

    reward: jax.Array
    r"""
    Immediate rewards :math:`r_t`.
    """

    done: jax.Array
    """
    Episode-termination flags (boolean).
    """

    value: jax.Array
    r"""
    Baseline value estimates :math:`V(s_t)`.
    """

    log_prob: jax.Array
    r"""
    Behavior-policy log-probabilities :math:`\log \pi_b(a_t \mid s_t)` at collection time.
    """

    new_log_prob: jax.Array
    r"""
    Target-policy log-probabilities :math:`\log \pi(a_t \mid s_t)` after policy update.
    Fill with ``log_prob`` during on-policy collection; must be recomputed after updates.
    """

    advantage: jax.Array
    r"""
    Advantages :math:`A_t`.
    """

    returns: jax.Array
    """
    Return targets (e.g., GAE or V-trace targets).
    """
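

# A minimal sketch (illustrative only, not part of the original module) of how
# TrajectoryData behaves as a registered pytree: single-step records such as those
# produced by Trainer.step can be stacked along a new leading time axis with
# jax.tree_util.tree_map, which mirrors what trajectory_rollout produces via
# jax.lax.scan. The (N_envs=2, N_agents=3) shapes below are assumptions.
def _example_stack_trajectory_steps() -> TrajectoryData:
    zeros = jnp.zeros((2, 3))  # assumed (N_envs, N_agents)
    step = TrajectoryData(
        obs=zeros,
        action=zeros,
        reward=zeros,
        done=jnp.zeros((2, 3), dtype=bool),
        value=zeros,
        log_prob=zeros,
        new_log_prob=zeros,
        advantage=zeros,
        returns=zeros,
    )
    # Stack two single-step records along a new leading time axis (T = 2).
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), step, step)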


@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class Trainer(Factory, ABC):
    """
    Base class for reinforcement learning trainers.

    This class holds the environment and model state (Flax NNX GraphDef/GraphState).
    It provides rollout utilities (:meth:`step`, :meth:`trajectory_rollout`) and a
    general advantage computation method (:meth:`compute_advantages`). Subclasses
    must implement algorithm-specific training logic in :meth:`epoch`.

    Example
    -------
    To define a custom trainer, inherit from :class:`Trainer` and implement its
    abstract methods:

    >>> @Trainer.register("myCustomTrainer")
    ... @jax.tree_util.register_dataclass
    ... @dataclass(slots=True, frozen=True)
    ... class MyCustomTrainer(Trainer):
    ...     ...
    """

    env: "Environment"
    """
    Environment object.
    """

    graphdef: nnx.GraphDef
    """
    Static graph definition of the model/optimizer.
    """

    graphstate: nnx.GraphState
    """
    Mutable state (parameters, optimizer state, RNGs, etc.).
    """

    key: ArrayLike
    """
    PRNG key used to sample actions and for other stochastic operations.
    """

    advantage_gamma: jax.Array
    r"""
    Discount factor :math:`\gamma \in [0, 1]`.
    """

    advantage_lambda: jax.Array
    r"""
    Generalized Advantage Estimation parameter :math:`\lambda \in [0, 1]`.
    """

    advantage_rho_clip: jax.Array
    r"""
    V-trace :math:`\bar{\rho}` (importance-weight clip for the TD term).
    """

    advantage_c_clip: jax.Array
    r"""
    V-trace :math:`\bar{c}` (importance-weight clip for the recursion/trace term).
    """

    @property
    def model(self):
        """Return the live model rebuilt from (graphdef, graphstate)."""
        model, *rest = nnx.merge(self.graphdef, self.graphstate)
        return model

    @property
    def optimizer(self):
        """Return the optimizer rebuilt from (graphdef, graphstate)."""
        model, optimizer, *rest = nnx.merge(self.graphdef, self.graphstate)
        return optimizer
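
    # Hedged sketch (illustrative only, not part of the original class): the
    # (graphdef, graphstate) pair stored on this trainer is the flax.nnx split of a
    # live (model, optimizer, ...) tuple, which is why nnx.merge above unpacks
    # several objects. nnx.split separates static structure from mutable state and
    # nnx.merge rebuilds the live objects; the single nnx.Linear used here is an
    # assumed stand-in model purely for illustration.
    @staticmethod
    def _example_graphdef_graphstate_roundtrip() -> jax.Array:
        model = nnx.Linear(3, 2, rngs=nnx.Rngs(0))
        graphdef, graphstate = nnx.split(model)    # static structure + mutable state
        rebuilt = nnx.merge(graphdef, graphstate)  # functionally identical module
        return rebuilt(jnp.ones((1, 3)))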

    @staticmethod
    @jax.jit
    def step(tr: "Trainer") -> Tuple["Trainer", "TrajectoryData"]:
        """
        Take one environment step and record a single-step trajectory.

        Returns
        -------
        (Trainer, TrajectoryData)
            Updated trainer and the new single-step trajectory record.
            Trajectory data shape: (N_envs, N_agents, *).
        """
        key, subkey = jax.random.split(tr.key)
        tr = replace(tr, key=key)

        model, *rest = nnx.merge(tr.graphdef, tr.graphstate)
        model.eval()

        obs = tr.env.observation(tr.env)
        pi, value = model(obs, sequence=False)
        action, log_prob = pi.sample_and_log_prob(seed=subkey)

        tr = replace(tr, env=tr.env.step(tr.env, action))
        reward = tr.env.reward(tr.env)
        done = tr.env.done(tr.env)

        # new_log_prob, advantage, and returns need to be computed later.
        # Shape -> (N_agents, *)
        traj = TrajectoryData(
            obs=obs,
            action=action,
            reward=reward,
            done=jnp.broadcast_to(done[..., None], reward.shape),
            value=jnp.squeeze(value, -1),
            log_prob=log_prob,
            new_log_prob=log_prob,
            advantage=log_prob,
            returns=log_prob,
        )

        tr = replace(tr, graphstate=nnx.state((model, *rest)))
        return tr, traj
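
    # Shape walk-through (illustrative only, not part of the original class) of the
    # reshaping done at the end of step: per-environment done flags are broadcast to
    # the per-agent reward shape, and the critic's trailing singleton dimension is
    # squeezed away. The concrete (N_envs=4, N_agents=3) sizes are assumptions.
    @staticmethod
    def _example_step_shapes() -> Tuple[jax.Array, jax.Array]:
        reward = jnp.zeros((4, 3))          # (N_envs, N_agents)
        done = jnp.zeros((4,), dtype=bool)  # (N_envs,)
        value = jnp.zeros((4, 3, 1))        # critic output with trailing dim
        done_per_agent = jnp.broadcast_to(done[..., None], reward.shape)  # (4, 3)
        value_squeezed = jnp.squeeze(value, -1)                           # (4, 3)
        return done_per_agent, value_squeezed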

    @staticmethod
    def reset_model(
        tr: "Trainer",
        shape: Sequence[int] | None = None,
        mask: jax.Array | None = None,
    ) -> "Trainer":
        """
        Reset the model's persistent recurrent state (e.g., LSTM carry) for all,
        or a masked subset of, (env, agent) entries and persist the mutation back
        into the trainer.

        Parameters
        ----------
        tr : Trainer
            Trainer carrying the environment and NNX graph state.
        shape : Sequence[int], optional
            Target carry shape. If ``None``, it is inferred as
            ``(tr.num_envs, tr.env.max_num_agents)``.
        mask : jax.Array, optional
            Boolean mask selecting which (env, agent) entries to reset. A value of
            ``True`` resets that entry. The mask may be shape
            ``(num_envs, num_agents)`` or any shape broadcastable to it.
            If ``None``, all entries are reset.

        Returns
        -------
        Trainer
            A new trainer with the updated ``graphstate``.
        """
        ...
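
    # Hypothetical usage sketch (not part of the original class): building a reset
    # mask from per-environment done flags so that only finished environments have
    # their recurrent carry cleared, following the mask convention documented above
    # (``True`` means "reset this (env, agent) entry"). Shapes are assumptions.
    @staticmethod
    def _example_done_reset_mask(
        done: jax.Array, num_envs: int, num_agents: int
    ) -> jax.Array:
        # done: (num_envs,) boolean flags -> (num_envs, num_agents) mask.
        return jnp.broadcast_to(done[:, None], (num_envs, num_agents))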

    @staticmethod
    @partial(jax.jit, static_argnames=("num_steps_epoch", "unroll"))
    def trajectory_rollout(
        tr: "Trainer", num_steps_epoch: int, unroll: int = 8
    ) -> Tuple["Trainer", "TrajectoryData"]:
        r"""
        Roll out :math:`T = \text{num_steps_epoch}` environment steps using
        :func:`jax.lax.scan`.

        Parameters
        ----------
        tr : Trainer
            The trainer carrying model state.
        num_steps_epoch : int
            Number of steps to roll out.
        unroll : int
            Number of loop iterations to unroll for compilation speed.

        Returns
        -------
        (Trainer, TrajectoryData)
            The final trainer and a :class:`TrajectoryData` instance whose fields
            are stacked along time (leading dimension :math:`T = \text{num_steps_epoch}`).
        """
        return jax.lax.scan(
            lambda tr, _: Trainer.step(tr),
            tr,
            None,
            length=num_steps_epoch,
            unroll=unroll,
        )
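
    # Toy sketch (illustrative only, not part of the original class): how
    # jax.lax.scan threads a carry through T steps and stacks every per-step output
    # along a new leading time axis, which is how trajectory_rollout turns
    # single-step TrajectoryData records into time-stacked fields.
    @staticmethod
    def _example_scan_stacks_outputs() -> jax.Array:
        def body(carry, _):
            carry = carry + 1
            return carry, carry  # (new carry, per-step output)

        _, stacked = jax.lax.scan(body, jnp.asarray(0), None, length=5)
        return stacked  # shape (5,): one entry per rolled-out step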

    @staticmethod
    @partial(jax.jit, static_argnames=("unroll",))
    def compute_advantages(
        td: "TrajectoryData",
        advantage_rho_clip: jax.Array,
        advantage_c_clip: jax.Array,
        advantage_gamma: jax.Array,
        advantage_lambda: jax.Array,
        unroll: int = 8,
    ) -> "TrajectoryData":
        r"""
        Compute advantages and return targets with V-trace-style off-policy
        correction or generalized advantage estimation (GAE).

        Let the behavior policy be :math:`\pi_b` and the target policy be
        :math:`\pi`. Define importance ratios per step:

        .. math::
            \rho_t = \exp\big( \log \pi(a_t \mid s_t) - \log \pi_b(a_t \mid s_t) \big)

        and their clipped versions, using the clip thresholds
        :math:`\bar{\rho}` and :math:`\bar{c}`:

        .. math::
            \hat{\rho}_t = \min(\rho_t, \bar{\rho}), \quad
            \hat{c}_t = \min(\rho_t, \bar{c}).

        We form a TD-like residual with an off-policy correction:

        .. math::
            \delta_t = \hat{\rho}_t \, r_t + \gamma V(s_{t+1})(1 - \text{done}_t) - V(s_t)

        and propagate a GAE-style trace using :math:`\hat{c}_t`:

        .. math::
            A_t = \delta_t + \gamma \lambda (1 - \text{done}_t) \hat{c}_t A_{t+1}

        Finally, the return targets are:

        .. math::
            \text{returns}_t = A_t + V(s_t)

        Notes
        -----
        When :math:`\pi_b = \pi` (i.e.
        ``TrajectoryData.log_prob == TrajectoryData.new_log_prob``) and
        :math:`\bar{\rho} = \bar{c} = 1`, this function reduces to standard GAE.

        Returns
        -------
        TrajectoryData
            :class:`TrajectoryData` with new ``advantage`` and ``returns``.

        References
        ----------
        - Schulman et al., *High-Dimensional Continuous Control Using Generalized
          Advantage Estimation*, 2015/2016.
        - Espeholt et al., *IMPALA: Scalable Distributed Deep-RL with Importance
          Weighted Actor-Learner Architectures*, 2018.
        """
        last_value = td.value[-1]
        gae0 = jnp.zeros_like(last_value)

        def calculate_advantage(
            gae_and_next_value: Tuple, td: TrajectoryData
        ) -> Tuple:
            gae, next_value = gae_and_next_value
            ratio = jnp.exp(td.new_log_prob - td.log_prob)
            rho = jnp.minimum(ratio, advantage_rho_clip)
            c = jnp.minimum(ratio, advantage_c_clip)
            delta = (
                rho * td.reward
                + advantage_gamma * next_value * (1 - td.done)
                - td.value
            )
            gae = delta + advantage_gamma * advantage_lambda * (1 - td.done) * c * gae
            return (gae, td.value), gae

        _, adv = jax.lax.scan(
            calculate_advantage, (gae0, last_value), td, reverse=True, unroll=unroll
        )
        return replace(td, advantage=adv, returns=adv + td.value)
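
    # Cross-checking sketch (illustrative only, not part of the original class): a
    # plain-Python reversed loop that mirrors compute_advantages in the on-policy
    # limit described in the Notes above (``log_prob == new_log_prob`` and
    # ``rho_clip = c_clip = 1``), i.e. standard GAE. Like the scan above, the last
    # recorded value bootstraps itself. Inputs are assumed time-leading arrays.
    @staticmethod
    def _example_reference_gae(
        reward: jax.Array,
        value: jax.Array,
        done: jax.Array,
        gamma: float,
        lam: float,
    ) -> Tuple[jax.Array, jax.Array]:
        num_steps = reward.shape[0]
        advantages = [None] * num_steps
        gae = jnp.zeros_like(value[-1])
        next_value = value[-1]
        for t in reversed(range(num_steps)):
            delta = reward[t] + gamma * next_value * (1.0 - done[t]) - value[t]
            gae = delta + gamma * lam * (1.0 - done[t]) * gae
            advantages[t] = gae
            next_value = value[t]
        advantage = jnp.stack(advantages)
        return advantage, advantage + value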

    @staticmethod
    @abstractmethod
    @jax.jit
    def epoch(tr: "Trainer", epoch: ArrayLike) -> Any:
        """
        Run one training epoch.

        Subclasses must implement this with their algorithm-specific logic.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def train(tr) -> Any:
        """
        Run the full training loop.

        Subclasses must implement this with their algorithm-specific logic.
        """
        raise NotImplementedError


from .PPOtrainer import PPOTrainer

__all__ = ["PPOTrainer"]