jaxdem.rl#

JaxDEM reinforcement learning (RL) module. Contains ML models, environments, and trainers implementing RL algorithms such as PPO.

class jaxdem.rl.Environment(state: State, system: System, env_params: Dict[str, Any])[source]#

Bases: Factory, ABC

Defines the interface for reinforcement-learning environments.

  • Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.

  • Observations and actions are flattened per agent to fixed sizes. Use action_space_shape to reshape inside the environment if needed.

Required shapes

  • Observation: (A, observation_space_size)

  • Action (input to step()): (A, action_space_size)

  • Reward: (A,)

  • Done: scalar boolean for the whole environment

TODO:

  • Truncated data field: per-agent termination flag

  • Render method

Example

To define a custom environment, inherit from Environment and implement the abstract methods:

>>> @Environment.register("MyCustomEnv")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class MyCustomEnv(Environment):
...     ...
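
A more complete sketch (not taken from the jaxdem sources) of a single-agent environment filling in the abstract static methods. The registry name, the env_params keys, and the trick of reusing the incoming state/system via dataclasses.replace are illustrative assumptions:

import jax
import jax.numpy as jnp
from dataclasses import dataclass, replace

from jaxdem.rl import Environment


@Environment.register("RandomWalk1D")        # hypothetical registry name
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class RandomWalk1D(Environment):
    """One agent (A = 1) moving on a line; the reward is -|x|."""

    @staticmethod
    def reset(env, key):
        # Keep state/system as they are; only refresh the env-specific parameters.
        x = jax.random.uniform(key, (1, 1), minval=-1.0, maxval=1.0)
        return replace(env, env_params={**env.env_params, "x": x})

    @staticmethod
    def step(env, action):
        # `action` arrives flattened with shape (A, action_space_size).
        x = env.env_params["x"] + 0.1 * action
        return replace(env, env_params={**env.env_params, "x": x})

    @staticmethod
    def observation(env):
        return env.env_params["x"]                   # (A, observation_space_size)

    @staticmethod
    def reward(env):
        return -jnp.abs(env.env_params["x"][:, 0])   # (A,)

    @staticmethod
    def done(env):
        return jnp.any(jnp.abs(env.env_params["x"]) > 5.0)   # scalar bool
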
state: State#

Simulation state.

system: System#

Simulation system configuration.

env_params: Dict[str, Any]#

Environment-specific parameters.

property action_space_shape: Tuple[int][source]#

Original per-agent action shape (useful for reshaping inside the environment).

property action_space_size: int[source]#

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

abstractmethod static done(env: Environment) Array[source][source]#

Returns a boolean indicating whether the environment has ended.

Parameters:

env (Environment) – The current environment.

Returns:

A boolean flag indicating whether the environment has ended.

Return type:

jax.Array

static info(env: Environment) Dict[str, Any][source][source]#

Return auxiliary diagnostic information.

By default, this returns an empty dict. Subclasses may override it to provide environment-specific information.

Parameters:

env (Environment) – The current state of the environment.

Returns:

A dictionary with additional information about the environment.

Return type:

Dict

property max_num_agents: int[source]#

Maximum number of active agents in the environment.

property num_envs: int[source]#

Number of batched environments.

abstractmethod static observation(env: Environment) Array[source][source]#

Returns the per-agent observation vector.

Parameters:

env (Environment) – The current environment.

Returns:

Vector corresponding to the environment observation.

Return type:

jax.Array

property observation_space_size: int[source]#

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

abstractmethod static reset(env: Environment, key: ArrayLike) Environment[source][source]#

Initialize the environment to a valid start state.

Parameters:
  • env (Environment) – Instance of the environment.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static reset_if_done(env: Environment, done: Array, key: ArrayLike) Environment[source][source]#

Conditionally reset the environment when it has reached a terminal state.

This method checks the done flag and, if it is True, calls the environment’s reset method to reinitialize the state; otherwise it returns the current environment unchanged.

Parameters:
  • env (Environment) – The current environment instance.

  • done (jax.Array) – A boolean flag indicating whether the environment has reached a terminal state.

  • key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.

Returns:

Either the freshly reset environment (if done is True) or the unchanged environment (if done is False).

Return type:

Environment
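
The conditional reset described above can be reproduced with jax.lax.cond so that the branch stays traceable under jit. This is a sketch of the pattern, not the actual implementation:

import jax

def reset_if_done_sketch(env, done, key):
    # Both branches must return a pytree with the same structure as `env`.
    return jax.lax.cond(
        done,
        lambda e: type(e).reset(e, key),   # terminal: reinitialize
        lambda e: e,                       # otherwise: pass through unchanged
        env,
    )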

abstractmethod static reward(env: Environment) Array[source][source]#

Returns the per-agent immediate rewards.

Parameters:

env (Environment) – The current environment.

Returns:

Vector of per-agent rewards based on the current environment state.

Return type:

jax.Array

abstractmethod static step(env: Environment, action: Array) Environment[source][source]#

Advance the simulation by one step using per-agent actions.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – The vector of actions each agent in the environment should take.

Returns:

The updated environment state.

Return type:

Environment

jaxdem.rl.vectorise_env(env: Environment) Environment[source][source]#

Promote an environment instance to a parallel version by applying jax.vmap(…) to its static methods.

jaxdem.rl.clip_action_env(env: Environment, min_val: float = -1.0, max_val: float = 1.0) Environment[source][source]#

Wrap an environment so that its step method clips the action before calling the original step.

jaxdem.rl.is_wrapped(env: Environment) bool[source][source]#

Check whether an environment instance is a wrapped environment.

Parameters:

env (Environment) – The environment instance to check.

Returns:

True if the environment is wrapped (i.e., has a _base_env_cls attribute), False otherwise.

Return type:

bool

jaxdem.rl.unwrap(env: Environment) Environment[source][source]#

Unwrap an environment to its original base class while preserving all current field values.

Parameters:

env (Environment) – The wrapped environment instance.

Returns:

A new instance of the original base environment class with the same field values as the wrapped instance.

Return type:

Environment
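
A short usage sketch of the wrapper helpers, assuming base_env is an already-constructed Environment instance (construction not shown here):

from jaxdem.rl import clip_action_env, is_wrapped, unwrap, vectorise_env

env = vectorise_env(clip_action_env(base_env, min_val=-1.0, max_val=1.0))

assert is_wrapped(env)    # wrapped classes expose a _base_env_cls attribute
original = unwrap(env)    # original base class, same field values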

class jaxdem.rl.Model(*args: Any, **kwargs: Any)[source]#

Bases: Factory, Module, ABC

The base interface for defining reinforcement learning models. Acts as a namespace.

Models map observations to an action distribution and a value estimate.

Example

To define a custom model, inherit from Model and implement its abstract methods:

>>> @Model.register("myCustomModel")
... class MyCustomModel(Model):
...     ...

property metadata: Dict[source]#

reset(shape: Tuple, mask: Array | None = None)[source][source]#

Reset the persistent LSTM carry.

Parameters:
  • shape (tuple[int, ...]) – Leading dimensions for the carry, e.g. (num_envs, num_agents).

  • mask (jax.Array, optional) – Boolean mask marking which entries to reset. Shape (num_envs,).
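
The abstract interface of Model is not listed on this page, so the following free-standing flax.nnx module only illustrates the mapping described above (observation in, action distribution and value estimate out); it deliberately does not subclass Model, and the layer sizes and Gaussian policy head are assumptions:

import distrax
import jax
import jax.numpy as jnp
from flax import nnx


class ActorCriticSketch(nnx.Module):
    """Maps a per-agent observation to (action distribution, value estimate)."""

    def __init__(self, obs_size: int, action_size: int, *, rngs: nnx.Rngs):
        self.body = nnx.Linear(obs_size, 64, rngs=rngs)
        self.mu = nnx.Linear(64, action_size, rngs=rngs)
        self.value_head = nnx.Linear(64, 1, rngs=rngs)
        self.log_std = nnx.Param(jnp.zeros(action_size))

    def __call__(self, obs):
        h = jax.nn.relu(self.body(obs))
        dist = distrax.MultivariateNormalDiag(self.mu(h), jnp.exp(self.log_std.value))
        return dist, self.value_head(h)[..., 0]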

class jaxdem.rl.Trainer(env: Environment, graphdef: nnx.GraphDef, graphstate: nnx.GraphState, key: ArrayLike, advantage_gamma: jax.Array, advantage_lambda: jax.Array, advantage_rho_clip: jax.Array, advantage_c_clip: jax.Array)[source]#

Bases: Factory, ABC

Base class for reinforcement learning trainers.

This class holds the environment and model state (Flax NNX GraphDef/GraphState). It provides rollout utilities (step(), trajectory_rollout()) and a general advantage computation method (compute_advantages()). Subclasses must implement algorithm-specific training logic in epoch().

Example

To define a custom trainer, inherit from Trainer and implement its abstract methods:

>>> @Trainer.register("myCustomTrainer")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class MyCustomTrainer(Trainer):
...     ...

env: Environment#

Environment object.

graphdef: nnx.GraphDef#

Static graph definition of the model/optimizer.

graphstate: nnx.GraphState#

Mutable state (parameters, optimizer state, RNGs, etc.).

key: ArrayLike#

PRNGKey used to sample actions and for other stochastic operations.

advantage_gamma: jax.Array#

Discount factor \(\gamma \in [0, 1]\).

advantage_lambda: jax.Array#

Generalized Advantage Estimation parameter \(\lambda \in [0, 1]\).

advantage_rho_clip: jax.Array#

V-trace \(\bar{\rho}\) (importance weight clip for the TD term).

advantage_c_clip: jax.Array#

V-trace \(\bar{c}\) (importance weight clip for the recursion/trace term).

static compute_advantages(value: Array, reward: Array, ratio: Array, done: Array, advantage_rho_clip: Array, advantage_c_clip: Array, advantage_gamma: Array, advantage_lambda: Array, unroll: int = 8) Tuple[Array, Array][source][source]#

Compute V-trace/GAE advantages and return targets.

Given a policy \(\pi\), define per-step importance ratios and clipped versions:

\[\rho_t = \exp\big( \log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_\text{old}}(a_t \mid s_t) \big)\]

and their clipped versions \(\hat{\rho}, \hat{c}\):

\[\hat{\rho}_t = \min(\rho_t, \bar{\rho}), \quad \hat{c}_t = \min(\rho_t, \bar{c}).\]

We form a TD-like residual with an off-policy correction:

\[\delta_t = \hat{\rho}_t \, r_t + \gamma V(s_{t+1})(1 - \text{done}_t) - V(s_t)\]

and propagate a GAE-style trace using \(\hat{c}_t\):

\[A_t = \delta_t + \gamma \lambda (1 - \text{done}_t) \hat{c}_t A_{t+1}\]

Finally, the return targets are:

\[\text{returns}_t = A_t + V(s_t)\]

Notes

When \(\pi_\theta = \pi_{\theta_\text{old}}\) (i.e. ratio==1) and \(\bar{\rho} = \bar{c} = 1\), this function reduces to standard GAE.

Returns:

Computed advantages and return targets.

Return type:

Tuple[jax.Array, jax.Array]

References

  • Schulman et al., High-Dimensional Continuous Control Using Generalized Advantage Estimation, 2015/2016

  • Espeholt et al., IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, 2018
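
A minimal re-implementation sketch of the recursion above. It assumes value carries T + 1 entries (the last being the bootstrap value for the final state) while reward, ratio and done carry T; the real compute_advantages may organize its inputs differently:

import jax
import jax.numpy as jnp

def vtrace_gae_sketch(value, reward, ratio, done,
                      rho_clip, c_clip, gamma, lam, unroll=8):
    rho = jnp.minimum(ratio, rho_clip)      # clipped rho_t
    c = jnp.minimum(ratio, c_clip)          # clipped c_t
    not_done = 1.0 - done
    # delta_t = rho_t * r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    delta = rho * reward + gamma * value[1:] * not_done - value[:-1]

    def backward(adv_next, inputs):
        delta_t, c_t, nd_t = inputs
        adv_t = delta_t + gamma * lam * nd_t * c_t * adv_next
        return adv_t, adv_t

    _, advantages = jax.lax.scan(
        backward, jnp.zeros_like(delta[-1]), (delta, c, not_done),
        reverse=True, unroll=unroll,
    )
    returns = advantages + value[:-1]       # return targets
    return advantages, returns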

abstractmethod static epoch(tr: Trainer, epoch: ArrayLike) Any[source][source]#

Run one training epoch.

Subclasses must implement this with their algorithm-specific logic.

property model[source]#

Return the live model rebuilt from (graphdef, graphstate).

static step(env: Environment, graphdef: nnx.GraphDef, graphstate: nnx.GraphState, key: jax.Array) Tuple[Tuple['Environment', nnx.GraphState, jax.Array], 'TrajectoryData'][source][source]#

Take one environment step and record a single-step trajectory.

Parameters:
  • env (Environment) – The current environment instance.

  • graphdef (nnx.GraphDef) – Static graph definition of the nnx model.

  • graphstate (nnx.GraphState) – Mutable state of the nnx model.

  • key (jax.Array) – JAX random key.

Returns:

Updated state and the new single-step trajectory. Trajectory data is shaped (N_envs, N_agents, …).

Return type:

Tuple[Tuple[Environment, nnx.GraphState, jax.Array], TrajectoryData]

abstractmethod static train(tr) Any[source][source]#

Run the training loop.

Subclasses must implement this with their algorithm-specific logic.

static trajectory_rollout(env: Environment, graphdef: nnx.GraphDef, graphstate: nnx.GraphState, key: jax.Array, num_steps_epoch: int, unroll: int = 8) Tuple['Environment', nnx.GraphState, jax.Array, 'TrajectoryData'][source][source]#

Roll out \(T = \text{num_steps_epoch}\) environment steps using jax.lax.scan().

Parameters:
  • env (Environment) – The current environment instance.

  • graphdef (nnx.GraphDef) – Static graph definition of the nnx model.

  • graphstate (nnx.GraphState) – Mutable state of the nnx model.

  • key (jax.Array) – JAX random key.

  • num_steps_epoch (int) – Number of steps to roll out.

  • unroll (int) – Number of loop iterations to unroll for compilation speed.

Returns:

The final environment, model state, and RNG key, together with a TrajectoryData instance whose fields are stacked along time (leading dimension \(T = \text{num_steps_epoch}\)).

Return type:

Tuple[Environment, nnx.GraphState, jax.Array, TrajectoryData]
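
A hypothetical call from inside a custom trainer, using the Trainer attributes documented above; the stacked shapes follow from step() and the time-leading convention described here:

env, graphstate, key, traj = Trainer.trajectory_rollout(
    tr.env, tr.graphdef, tr.graphstate, tr.key,
    num_steps_epoch=128, unroll=8,
)
# traj.obs is stacked as (128, N_envs, N_agents, observation_space_size)
# traj.reward as (128, N_envs, N_agents)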

class jaxdem.rl.TrajectoryData(*, obs: Array, action: Array, value: Array, log_prob: Array, ratio: Array, reward: Array, done: Array)[source]#

Bases: object

Container for rollout data (single step or stacked across time).

obs: Array#

Observations.

action: Array#

Actions sampled from the policy.

value: Array#

Baseline value estimates \(V(s_t)\).

log_prob: Array#

Behavior-policy log-probabilities \(\log \pi_b(a_t \mid s_t)\) at collection time.

ratio: Array#

Importance ratio between the current policy and the behavior policy, \(\exp\big( \log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_\text{old}}(a_t \mid s_t) \big)\).

reward: Array#

Immediate rewards \(r_t\).

done: Array#

Episode-termination flags (boolean).

class jaxdem.rl.ActionSpace[source]#

Bases: Factory

Registry/namespace for action-space constraints implemented as distrax.Bijector subclasses.

These bijectors are intended to be wrapped around a base policy distribution (e.g., MultivariateNormalDiag) via distrax.Transformed, so that sampling and log-probabilities are correctly adjusted using the bijector’s forward_and_log_det / inverse_and_log_det methods. See Distrax/TFP bijector interface for details on shape semantics and event_ndims_in/out.
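
For reference, the wrapping pattern described above looks like this in plain distrax (a tanh squashing bijector stands in for a registered action space; nothing below is specific to jaxdem):

import distrax
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mu, sigma = jnp.zeros(3), jnp.ones(3)

base = distrax.MultivariateNormalDiag(loc=mu, scale_diag=sigma)
squash = distrax.Block(distrax.Tanh(), ndims=1)   # R^3 -> (-1, 1)^3
policy = distrax.Transformed(base, squash)

action, log_prob = policy.sample_and_log_prob(seed=key)   # log-prob includes the Jacobian term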

Example

To define a custom action space, inherit from distrax.Bijector and ActionSpace and implement its abstract methods:

>>> @ActionSpace.register("myCustomActionSpace")
>>> class MyCustomActionSpace(distrax.Bijector, ActionSpace):
        ...
property kws: Dict[source]#

Modules

actionSpaces

Interface for defining bijectors used to constrain the policy probability distribution.

envWrappers

Contains wrappers for modifying RL environments.

environments

Reinforcement-learning environment interface.

models

Interface for defining reinforcement learning models.

trainers

Interface for defining reinforcement learning model trainers.