jaxdem.rl#

JaxDEM reinforcement learning (RL) module. Contains ML models, environments, and trainers implementing RL algorithms such as PPO.

class jaxdem.rl.ActionSpace#

Bases: Factory

Registry/namespace for action-space constraints implemented as distrax.Bijector objects.

These bijectors are intended to be wrapped around a base policy distribution (e.g., MultivariateNormalDiag) via distrax.Transformed, so that sampling and log-probabilities are correctly adjusted using the bijector’s ‘forward_and_log_det’ / ‘inverse_and_log_det’ methods. See Distrax/TFP bijector interface for details on shape semantics and ‘event_ndims_in/out’.
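
For instance, a registered action-space bijector can be combined with a diagonal-Gaussian policy head roughly as follows (a sketch; mean, std, key, and bijector are assumed to be defined by the caller):

>>> import distrax
>>> base = distrax.MultivariateNormalDiag(loc=mean, scale_diag=std)
>>> policy = distrax.Transformed(base, bijector)
>>> action, log_prob = policy.sample_and_log_prob(seed=key)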

Example:#

To define a custom action space, inherit from distrax.Bijector and ActionSpace and implement its abstract methods:

>>> @ActionSpace.register("myCustomActionSpace")
>>> class MyCustomActionSpace(distrax.Bijector, ActionSpace):
        ...
property kws: dict[str, Any][source]#
log_det_expectation(mean: Array, std: Array) Array[source]#

Compute \(\mathbb{E}_{X}[\log|\det J_f(X)|]\) where \(X \sim \mathcal{N}(\text{mean}, \text{diag}(\text{std}^2))\).

Subclasses should override this to enable Transformed.entropy().
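
As an illustration, the sketch below registers a hypothetical elementwise-scaling action space whose Jacobian log-determinant is constant, making log_det_expectation() trivial. The class name and its scale parameter are illustrative only and not part of jaxdem:

import jax.numpy as jnp
import distrax

@ActionSpace.register("scaledActionSpace")
class ScaledActionSpace(distrax.Bijector, ActionSpace):
    """Hypothetical action space that scales each action component by a fixed factor."""

    def __init__(self, scale: float = 2.0):
        distrax.Bijector.__init__(self, event_ndims_in=1)
        self.scale = jnp.asarray(scale)

    def forward_and_log_det(self, x):
        # For an elementwise scaling, log|det J| = d * log|scale|, independent of x.
        log_det = jnp.full(x.shape[:-1], x.shape[-1] * jnp.log(jnp.abs(self.scale)))
        return self.scale * x, log_det

    def inverse_and_log_det(self, y):
        log_det = jnp.full(y.shape[:-1], -y.shape[-1] * jnp.log(jnp.abs(self.scale)))
        return y / self.scale, log_det

    def log_det_expectation(self, mean, std):
        # The log-determinant is constant, so its expectation under any Gaussian
        # equals that same constant.
        return jnp.full(mean.shape[:-1], mean.shape[-1] * jnp.log(jnp.abs(self.scale)))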

class jaxdem.rl.Environment(state: State, system: System, env_params: dict[str, Any])#

Bases: Factory, ABC

Defines the interface for reinforcement-learning environments.

  • Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.

  • Observations and actions are flattened per agent to fixed sizes. Use action_space_shape to reshape inside the environment if needed.

Required shapes

  • Observation: (A, observation_space_size)

  • Action (input to step()): (A, action_space_size)

  • Reward: (A,)

  • Done: scalar boolean for the whole environment
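
For example, with A = 3 agents (assuming env is an instance of a concrete subclass MyCustomEnv and actions has shape (3, action_space_size)):

>>> obs = MyCustomEnv.observation(env)    # shape (3, observation_space_size)
>>> rew = MyCustomEnv.reward(env)         # shape (3,)
>>> done = MyCustomEnv.done(env)          # scalar bool
>>> env = MyCustomEnv.step(env, actions)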

Todo:

  • Truncated data field: per-agent termination flag

  • Render method

Example:#

To define a custom environment, inherit from Environment and implement the abstract methods:

>>> @Environment.register("myCustomEnv")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True)
>>> class MyCustomEnv(Environment):
    ...
state: State#

Simulation state.

system: System#

Simulation system configuration.

env_params: dict[str, Any]#

Environment-specific parameters.

abstractmethod static reset(env: Environment, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Environment[source]#

Initialize the environment to a valid start state.

Parameters:
  • env (Environment) – Instance of the environment.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static reset_if_done(env: Environment, done: Array, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Environment[source]#

Conditionally reset the environment when it has reached a terminal state.

This method checks the done flag and, if it is True, calls the environment’s reset() method to reinitialize the state; otherwise, it returns the current environment unchanged.

Parameters:
  • env (Environment) – The current environment instance.

  • done (jax.Array) – A boolean flag indicating whether the environment has reached a terminal state.

  • key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.

Returns:

Either the freshly reset environment (if done is True) or the unchanged environment (if done is False).

Return type:

Environment
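
A minimal sketch of this conditional reset using jax.lax.cond (an assumed implementation; the actual jaxdem logic may differ):

import jax

def reset_if_done(env, done, key):
    # Re-initialize only when the episode has terminated; otherwise pass the
    # environment through unchanged.
    return jax.lax.cond(
        done,
        lambda: type(env).reset(env, key),
        lambda: env,
    )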

abstractmethod static step(env: Environment, action: Array) Environment[source]#

Advance the simulation by one step using per-agent actions.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – The vector of actions each agent in the environment should take.

Returns:

The updated environment state.

Return type:

Environment

abstractmethod static observation(env: Environment) Array[source]#

Returns the per-agent observation vector.

Parameters:

env (Environment) – The current environment.

Returns:

Per-agent observation vectors with shape (A, observation_space_size).

Return type:

jax.Array

abstractmethod static reward(env: Environment) Array[source]#

Returns the per-agent immediate rewards.

Parameters:

env (Environment) – The current environment.

Returns:

Vector of per-agent rewards with shape (A,), computed from the current environment state.

Return type:

jax.Array

abstractmethod static done(env: Environment) Array[source]#

Returns a boolean indicating whether the environment has ended.

Parameters:

env (Environment) – The current environment.

Returns:

A boolean indicating whether the environment has ended.

Return type:

jax.Array

static info(env: Environment) dict[str, Any][source]#

Return auxiliary diagnostic information.

By default, returns an empty dict. Subclasses may override this to provide environment-specific information.

Parameters:

env (Environment) – The current state of the environment.

Returns:

A dictionary with additional information about the environment.

Return type:

Dict

property num_envs: int[source]#

Number of batched environments.

property max_num_agents: int[source]#

Maximum number of active agents in the environment.

property action_space_size: int[source]#

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int][source]#

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int[source]#

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

class jaxdem.rl.Model(*args: Any, **kwargs: Any)#

Bases: Factory, Module, ABC

The base interface for defining reinforcement learning models. Acts as a namespace.

Models map observations to an action distribution and a value estimate.

Example:#

To define a custom model, inherit from Model and implement its abstract methods:

>>> @Model.register("myCustomModel")
>>> class MyCustomModel(Model):
        ...
property metadata: dict[str, Any][source]#
reset(shape: tuple[int, ...], mask: Array | None = None) None[source]#

Reset the persistent LSTM carry.

Parameters:
  • shape (tuple[int, ...]) – Leading dims for the carry, e.g. (num_envs, num_agents).

  • mask (bool array, optional) – True where entries should be reset. Shape (num_envs,).
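
For example (assuming model is a Model instance and done_flags is a boolean array of shape (num_envs,)):

>>> model.reset((num_envs, num_agents))                   # clear the carry everywhere
>>> model.reset((num_envs, num_agents), mask=done_flags)  # reset only finished environments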

class jaxdem.rl.Trainer(env: Environment, graphdef: nnx.GraphDef[Any], graphstate: nnx.GraphState, key: ArrayLike, advantage_gamma: jax.Array, advantage_lambda: jax.Array, advantage_rho_clip: jax.Array, advantage_c_clip: jax.Array)#

Bases: Factory, ABC

Base class for reinforcement learning trainers.

This class holds the environment and model state (Flax NNX GraphDef/GraphState). It provides rollout utilities (step(), trajectory_rollout()) and a general advantage computation method (compute_advantages()). Subclasses must implement algorithm-specific training logic in epoch().

Example:#

To define a custom trainer, inherit from Trainer and implement its abstract methods:

>>> @Trainer.register("myCustomTrainer")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True)
>>> class MyCustomTrainer(Trainer):
        ...
env: Environment#

Environment object.

graphdef: nnx.GraphDef[Any]#

Static graph definition of the model/optimizer.

graphstate: nnx.GraphState#

Mutable state (parameters, optimizer state, RNGs, etc.).

key: ArrayLike#

PRNGKey used to sample actions and for other stochastic operations.

advantage_gamma: jax.Array#

Discount factor \(\gamma \in [0, 1]\).

advantage_lambda: jax.Array#

Generalized Advantage Estimation parameter \(\lambda \in [0, 1]\).

advantage_rho_clip: jax.Array#

V-trace \(\bar{\rho}\) (importance weight clip for the TD term).

advantage_c_clip: jax.Array#

V-trace \(\bar{c}\) (importance weight clip for the recursion/trace term).

property model: Any[source]#

Return the live model rebuilt from (graphdef, graphstate).

static step(env: Environment, graphdef: nnx.GraphDef[Any], graphstate: nnx.GraphState, key: jax.Array, skip_frames: int = 0) tuple[tuple[Environment, nnx.GraphState, jax.Array], TrajectoryData][source]#

Take one environment step (possibly repeating the action) and record a single-step trajectory.

Parameters:
  • env (Environment) – The current environment.

  • graphdef (nnx.GraphDef) – Static graph definition of the nnx model.

  • graphstate (nnx.GraphState) – Mutable state of the nnx model.

  • key (jax.Array) – JAX random key.

  • skip_frames (int) – Number of additional frames to repeat the action.

Returns:

Updated state and the new single-step trajectory. Trajectory data is shaped (N_envs, N_agents, …).

Return type:

Tuple[Tuple[Environment, nnx.GraphState, jax.Array], TrajectoryData]

static trajectory_rollout(env: Environment, graphdef: nnx.GraphDef[Any], graphstate: nnx.GraphState, key: jax.Array, num_steps_epoch: int, unroll: int = 8, skip_frames: int = 0) tuple[Environment, nnx.GraphState, jax.Array, TrajectoryData][source]#

Roll out \(T = \text{num_steps_epoch}\) environment steps using jax.lax.scan().

Parameters:
  • env (Environment) – The current environment.

  • graphdef (nnx.GraphDef) – Static graph definition of the nnx model.

  • graphstate (nnx.GraphState) – Mutable state of the nnx model.

  • key (jax.Array) – JAX random key.

  • num_steps_epoch (int) – Number of steps to roll out.

  • unroll (int) – Number of loop iterations to unroll for compilation speed.

  • skip_frames (int) – Number of frames to skip (repeat the action) per observation.

Returns:

The final trainer and a TrajectoryData instance whose fields are stacked along time (leading dimension \(T = \text{num_steps_epoch}\)).

Return type:

Tuple[Environment, nnx.GraphState, jax.Array, TrajectoryData]
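
A minimal sketch of how such a rollout can be expressed with jax.lax.scan() (an assumed structure; the actual jaxdem implementation may differ):

import jax

def trajectory_rollout(env, graphdef, graphstate, key, num_steps_epoch, unroll=8, skip_frames=0):
    def body(carry, _):
        env, graphstate, key = carry
        (env, graphstate, key), traj = Trainer.step(
            env, graphdef, graphstate, key, skip_frames=skip_frames
        )
        return (env, graphstate, key), traj

    (env, graphstate, key), trajectory = jax.lax.scan(
        body, (env, graphstate, key), xs=None, length=num_steps_epoch, unroll=unroll
    )
    # TrajectoryData fields are stacked along the leading time dimension T = num_steps_epoch.
    return env, graphstate, key, trajectory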

static compute_advantages(value: Array, reward: Array, ratio: Array, done: Array, advantage_rho_clip: Array, advantage_c_clip: Array, advantage_gamma: Array, advantage_lambda: Array, unroll: int = 8) tuple[Array, Array][source]#

Compute V-trace/GAE advantages and return targets.

Given the current policy \(\pi_\theta\) and the behavior policy \(\pi_{\theta_\text{old}}\), define the per-step importance ratio:

\[\rho_t = \exp\big( \log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_\text{old}}(a_t \mid s_t) \big)\]

and its clipped versions \(\hat{\rho}_t, \hat{c}_t\):

\[\hat{\rho}_t = \min(\rho_t, \bar{\rho}), \quad \hat{c}_t = \min(\rho_t, \bar{c}).\]

We form a TD-like residual with an off-policy correction:

\[\delta_t = \hat{\rho}_t \, r_t + \gamma V(s_{t+1})(1 - \text{done}_t) - V(s_t)\]

and propagate a GAE-style trace using \(\hat{c}_t\):

\[A_t = \delta_t + \gamma \lambda (1 - \text{done}_t) \hat{c}_t A_{t+1}\]

Finally, the return targets are:

\[\text{returns}_t = A_t + V(s_t)\]

Notes

When \(\pi_\theta = \pi_{\theta_\text{old}}\) (i.e. ratio==1) and \(\bar{\rho} = \bar{c} = 1\), this function reduces to standard GAE.

Returns:

Computed advantage and returns.

Return type:

Tuple[jax.Array, jax.Array]

References

  • Schulman et al., High-Dimensional Continuous Control Using Generalized Advantage Estimation, 2015/2016

  • Espeholt et al., IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, 2018
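
A minimal sketch of this backward recursion for time-major inputs that all share the same shape (T, …). It bootstraps the final value with value[-1] and initializes the trace with zeros; the actual jaxdem implementation may handle bootstrapping and broadcasting differently:

import jax
import jax.numpy as jnp

def compute_advantages(value, reward, ratio, done, rho_clip, c_clip, gamma, lam, unroll=8):
    rho = jnp.minimum(ratio, rho_clip)   # clipped importance weight for the TD term
    c = jnp.minimum(ratio, c_clip)       # clipped importance weight for the trace term
    next_value = jnp.concatenate([value[1:], value[-1:]], axis=0)
    delta = rho * reward + gamma * next_value * (1.0 - done) - value

    def body(adv_next, inputs):
        delta_t, done_t, c_t = inputs
        adv_t = delta_t + gamma * lam * (1.0 - done_t) * c_t * adv_next
        return adv_t, adv_t

    _, advantages = jax.lax.scan(
        body, jnp.zeros_like(value[0]), (delta, done, c), reverse=True, unroll=unroll
    )
    returns = advantages + value
    return advantages, returns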

abstractmethod static epoch(tr: Trainer, epoch: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Any[source]#

Run one training epoch.

Subclasses must implement this with their algorithm-specific logic.

abstractmethod static train(tr: Trainer, *args: Any, **kwargs: Any) Any[source]#

Training loop.

Subclasses must implement this with their algorithm-specific logic.

class jaxdem.rl.TrajectoryData(*, obs: Array, action: Array, value: Array, log_prob: Array, ratio: Array, reward: Array, done: Array)#

Bases: object

Container for rollout data (single step or stacked across time).

obs: Array#

Observations.

action: Array#

Actions sampled from the policy.

value: Array#

Baseline value estimates \(V(s_t)\).

log_prob: Array#

Behavior-policy log-probabilities \(\log \pi_b(a_t \mid s_t)\) at collection time.

ratio: Array#

Importance ratio between the current policy and the behavior policy, \(\exp\big( \log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_\text{old}}(a_t \mid s_t) \big)\).

reward: Array#

Immediate rewards \(r_t\).

done: Array#

Episode-termination flags (boolean).

jaxdem.rl.clip_action_env(env: Environment, min_val: float = -1.0, max_val: float = 1.0) Environment[source]#

Wrap an environment so that its step method clips the action before calling the original step.
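
For instance (assuming my_env is an existing Environment instance):

>>> clipped_env = clip_action_env(my_env, min_val=-1.0, max_val=1.0)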

jaxdem.rl.is_wrapped(env: Environment) bool[source]#

Check whether an environment instance is a wrapped environment.

Parameters:

env (Environment) – The environment instance to check.

Returns:

True if the environment is wrapped (i.e., has a _base_env_cls attribute), False otherwise.

Return type:

bool

jaxdem.rl.unwrap(env: Environment) Environment[source]#

Unwrap an environment to its original base class while preserving all current field values.

Parameters:

env (Environment) – The wrapped environment instance.

Returns:

A new instance of the original base environment class with the same field values as the wrapped instance.

Return type:

Environment

jaxdem.rl.vectorise_env(env: Environment) Environment[source]#

Promote an environment instance to a parallel version by applying jax.vmap(…) to its static methods.
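
For instance, applying the wrapper and later recovering the base environment (assuming my_env is an Environment instance whose state is batched over environments):

>>> batched_env = vectorise_env(my_env)
>>> original_env = unwrap(batched_env)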

Modules

actionSpaces

Interface for defining bijectors used to constrain the policy probability distribution.

envWrappers

Contains wrappers for modifying RL environments.

environments

Reinforcement-learning environment interface.

models

Interface for defining reinforcement learning models.

trainers

Interface for defining reinforcement learning model trainers.