jaxdem.rl#

JaxDEM reinforcement learning (RL) module. Contains ML models, environments, and trainers implementing RL algorithms such as PPO.

class jaxdem.rl.Environment(state: State, system: System, env_params: Dict[str, Any], max_num_agents: int = 0, action_space_size: int = 0, action_space_shape: Tuple[int, ...] = (), observation_space_size: int = 0)[source]#

Bases: Factory, ABC

Defines the interface for reinforcement-learning environments.

  • Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.

  • Observations and actions are flattened per agent to fixed sizes. Use action_space_shape to reshape inside the environment if needed.

Required shapes

  • Observation: (A, observation_space_size)

  • Action (input to step()): (A, action_space_size)

  • Reward: (A,)

  • Done: scalar boolean for the whole environment

TODO:

  • Truncated data field: per-agent termination flag

  • Render method

Example

To define a custom environment, inherit from Environment and implement the abstract methods:

>>> @Environment.register("MyCustomEnv")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True, frozen=True)
... class MyCustomEnv(Environment):
...     ...
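
A fuller structural sketch is shown below. It is hypothetical: the reset and step bodies are placeholders, since how State and System are constructed and advanced is specific to each environment, but the returned shapes follow the requirements listed above.

>>> import jax
>>> import jax.numpy as jnp
>>> from dataclasses import dataclass
>>> from jaxdem.rl import Environment
>>>
>>> @Environment.register("MySketchEnv")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True, frozen=True)
... class MySketchEnv(Environment):
...     @staticmethod
...     def reset(env, key):
...         # Placeholder: reinitialize env.state / env.system here.
...         return env
...     @staticmethod
...     def step(env, action):
...         # Actions arrive flattened as (A, action_space_size);
...         # reshape with env.action_space_shape if needed.
...         return env
...     @staticmethod
...     def observation(env):
...         return jnp.zeros((env.max_num_agents, env.observation_space_size))
...     @staticmethod
...     def reward(env):
...         return jnp.zeros((env.max_num_agents,))
...     @staticmethod
...     def done(env):
...         return jnp.asarray(False)
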
state: State#

Simulation state.

system: System#

Simulation system configuration.

env_params: Dict[str, Any]#

Environment-specific parameters.

max_num_agents: int#

Maximum number of active agents in the environment.

action_space_size: int#

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

action_space_shape: Tuple[int, ...]#

Original per-agent action shape (useful for reshaping inside the environment).

observation_space_size: int#

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

abstractmethod static done(env: Environment) Array[source][source]#

Returns a boolean indicating whether the environment has ended.

Parameters:

env (Environment) – The current environment.

Returns:

A boolean flag indicating whether the environment has ended.

Return type:

jax.Array

static info(env: Environment) Dict[str, Any][source][source]#

Return auxiliary diagnostic information.

By default, returns an empty dict. Subclasses may override this to provide environment-specific information.

Parameters:

env (Environment) – The current state of the environment.

Returns:

A dictionary with additional information about the environment.

Return type:

Dict

abstractmethod static observation(env: Environment) Array[source][source]#

Returns the per-agent observation vector.

Parameters:

env (Environment) – The current environment.

Returns:

Per-agent observation array of shape (A, observation_space_size).

Return type:

jax.Array

abstractmethod static reset(env: Environment, key: ArrayLike) Environment[source][source]#

Initialize the environment to a valid start state.

Parameters:
  • env (Environment) – Instance of the environment.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static reset_if_done(env: Environment, done: Array, key: ArrayLike) Environment[source][source]#

Conditionally resets the environment when it has reached a terminal state.

This method checks the done flag and, if True, calls the environment’s reset method to reinitialize the state. Otherwise, it returns the current environment unchanged.

Parameters:
  • env (Environment) – The current environment instance.

  • done (jax.Array) – A boolean flag indicating whether the environment has reached a terminal state.

  • key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.

Returns:

Either the freshly reset environment (if done is True) or the unchanged environment (if done is False).

Return type:

Environment
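
Under jit, this conditional is naturally expressed with jax.lax.cond. The following sketch only illustrates the behaviour described above and is not necessarily the library's implementation (reset_if_done_sketch is a hypothetical name):

>>> import jax
>>>
>>> def reset_if_done_sketch(env, done, key):
...     # Both branches must return pytrees with identical structure.
...     return jax.lax.cond(
...         done,
...         lambda e, k: type(e).reset(e, k),  # terminal: freshly reset environment
...         lambda e, k: e,                    # non-terminal: unchanged environment
...         env, key,
...     )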

abstractmethod static reward(env: Environment) Array[source][source]#

Returns the per-agent immediate rewards.

Parameters:

env (Environment) – The current environment.

Returns:

Per-agent reward array of shape (A,), computed from the current environment state.

Return type:

jax.Array

abstractmethod static step(env: Environment, action: Array) Environment[source][source]#

Advance the simulation by one step using per-agent actions.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – Per-agent actions with shape (A, action_space_size).

Returns:

The updated environment.

Return type:

Environment

jaxdem.rl.vectorise_env(env: Environment) Environment[source][source]#

Promote an environment instance to a parallel version by applying jax.vmap(…) to its static methods.

jaxdem.rl.clip_action_env(env: Environment, min_val: float = -1.0, max_val: float = 1.0) Environment[source][source]#

Wrap an environment so that its step method clips the action before calling the original step.
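
A usage sketch, assuming env is an already constructed Environment instance; the composition order shown here is a modelling choice, not a requirement of the API:

>>> from jaxdem.rl import vectorise_env, clip_action_env
>>>
>>> env = vectorise_env(env)                               # batch the static methods with jax.vmap
>>> env = clip_action_env(env, min_val=-1.0, max_val=1.0)  # clip actions before the original step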

jaxdem.rl.is_wrapped(env: Environment) bool[source][source]#

Check whether an environment instance is a wrapped environment.

Parameters:

env (Environment) – The environment instance to check.

Returns:

True if the environment is wrapped (i.e., has a _base_env_cls attribute), False otherwise.

Return type:

bool

jaxdem.rl.unwrap(env: Environment) Environment[source][source]#

Unwrap an environment to its original base class while preserving all current field values.

Parameters:

env (Environment) – The wrapped environment instance.

Returns:

A new instance of the original base environment class with the same field values as the wrapped instance.

Return type:

Environment
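
Continuing the hypothetical example above, the two helpers can be combined to recover the unwrapped environment:

>>> from jaxdem.rl import is_wrapped, unwrap
>>>
>>> if is_wrapped(env):    # True once a wrapper such as clip_action_env has been applied
...     env = unwrap(env)  # back to the original base class, field values preserved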

class jaxdem.rl.Model(*args: Any, **kwargs: Any)[source]#

Bases: Factory, Module, ABC

The base interface for defining reinforcement learning models. Acts as a namespace.

Models map observations to an action distribution and a value estimate.

Example

To define a custom model, inherit from Model and implement its abstract methods:

>>> @Model.register("myCustomModel")
... class MyCustomModel(Model):
...     ...
property log_std: Param[source]#
property metadata: Dict[source]#
reset(shape: Tuple, mask: Array | None = None)[source][source]#

Reset the persistent LSTM carry.

Parameters:
  • shape (tuple[int, ...]) – Leading dims for the carry, e.g. (num_envs, num_agents).

  • mask (jax.Array, optional) – Boolean mask; entries set to True are reset. Shape (num_envs,).
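
A hypothetical usage pattern for a recurrent model, where model, num_envs, num_agents and done are assumed to exist (done holding one boolean flag per environment):

>>> model.reset((num_envs, num_agents))             # clear the carry for every (env, agent) slot
>>> model.reset((num_envs, num_agents), mask=done)  # later, clear only environments that just ended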

class jaxdem.rl.Trainer(env: Environment, graphdef: nnx.GraphDef, graphstate: nnx.GraphState, key: ArrayLike, advantage_gamma: jax.Array, advantage_lambda: jax.Array, advantage_rho_clip: jax.Array, advantage_c_clip: jax.Array)[source]#

Bases: Factory, ABC

Base class for reinforcement learning trainers.

This class holds the environment and model state (Flax NNX GraphDef/GraphState). It provides rollout utilities (step(), trajectory_rollout()) and a general advantage computation method (compute_advantages()). Subclasses must implement algorithm-specific training logic in epoch().

Example

To define a custom trainer, inherit from Trainer and implement its abstract methods:

>>> @Trainer.register("myCustomTrainer")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True, frozen=True)
... class MyCustomTrainer(Trainer):
...     ...
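
A skeletal sketch of such a subclass is shown below; it only wires the documented utilities together, the number of steps and epochs are illustrative, and the algorithm-specific policy/value update is omitted:

>>> import jax
>>> from dataclasses import dataclass
>>> from jaxdem.rl import Trainer
>>>
>>> @Trainer.register("mySketchTrainer")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True, frozen=True)
... class MySketchTrainer(Trainer):
...     @staticmethod
...     def epoch(tr, epoch):
...         # Collect a trajectory, then estimate advantages and return targets.
...         tr, td = Trainer.trajectory_rollout(tr, num_steps_epoch=128)
...         td = Trainer.compute_advantages(
...             td,
...             tr.advantage_rho_clip,
...             tr.advantage_c_clip,
...             tr.advantage_gamma,
...             tr.advantage_lambda,
...         )
...         # ... algorithm-specific policy/value update goes here ...
...         return tr
...     @staticmethod
...     def train(tr):
...         for e in range(10):
...             tr = MySketchTrainer.epoch(tr, e)
...         return tr
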
env: Environment#

Environment object.

graphdef: nnx.GraphDef#

Static graph definition of the model/optimizer.

graphstate: nnx.GraphState#

Mutable state (parameters, optimizer state, RNGs, etc.).

key: ArrayLike#

PRNGKey used to sample actions and for other stochastic operations.

advantage_gamma: jax.Array#

Discount factor \(\gamma \in [0, 1]\).

advantage_lambda: jax.Array#

Generalized Advantage Estimation parameter \(\lambda \in [0, 1]\).

advantage_rho_clip: jax.Array#

V-trace \(\bar{\rho}\) (importance weight clip for the TD term).

advantage_c_clip: jax.Array#

V-trace \(\bar{c}\) (importance weight clip for the recursion/trace term).

static compute_advantages(td: TrajectoryData, advantage_rho_clip: Array, advantage_c_clip: Array, advantage_gamma: Array, advantage_lambda: Array, unroll: int = 8) TrajectoryData[source][source]#

Compute advantages and return targets using a V-trace-style off-policy correction, reducing to standard generalized advantage estimation (GAE) in the on-policy case.

Let the behavior policy be \(\pi_b\) and the target policy be \(\pi\). Define importance ratios per step:

\[\rho_t = \exp\big( \log \pi(a_t \mid s_t) - \log \pi_b(a_t \mid s_t) \big)\]

and their clipped versions \(\bar{\rho}, \bar{c}\):

\[\hat{\rho}_t = \min(\rho_t, \bar{\rho}), \quad \hat{c}_t = \min(\rho_t, \bar{c}).\]

We form a TD-like residual with an off-policy correction:

\[\delta_t = \hat{\rho}_t \, r_t + \gamma V(s_{t+1})(1 - \text{done}_t) - V(s_t)\]

and propagate a GAE-style trace using \(\hat{c}_t\):

\[A_t = \delta_t + \gamma \lambda (1 - \text{done}_t) \hat{c}_t A_{t+1}\]

Finally, the return targets are:

\[\text{returns}_t = A_t + V(s_t)\]

Notes

  • When \(\pi_b = \pi\) (i.e. TrajectoryData.log_prob == TrajectoryData.new_log_prob) and \(\bar{\rho} = \bar{c} = 1\), this function reduces to standard GAE.
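
The backward recursion above maps directly onto a reversed jax.lax.scan. The following standalone sketch mirrors the equations and is illustrative only (all names are hypothetical; values_tp1 stands for the bootstrapped \(V(s_{t+1})\) sequence, and every array carries a leading time dimension):

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def advantages_sketch(rewards, values, values_tp1, dones,
...                       log_prob_b, log_prob, rho_clip, c_clip, gamma, lam):
...     rho = jnp.exp(log_prob - log_prob_b)      # importance ratios rho_t
...     rho_hat = jnp.minimum(rho, rho_clip)      # clipped with rho-bar
...     c_hat = jnp.minimum(rho, c_clip)          # clipped with c-bar
...     delta = rho_hat * rewards + gamma * values_tp1 * (1.0 - dones) - values
...     def backward(adv_next, inputs):
...         d_t, c_t, done_t = inputs
...         adv_t = d_t + gamma * lam * (1.0 - done_t) * c_t * adv_next
...         return adv_t, adv_t
...     _, advantage = jax.lax.scan(
...         backward, jnp.zeros_like(values[-1]),
...         (delta, c_hat, dones), reverse=True,
...     )
...     returns = advantage + values
...     return advantage, returns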

Returns:

TrajectoryData with new advantage and returns.

Return type:

TrajectoryData

References

  • Schulman et al., High-Dimensional Continuous Control Using Generalized Advantage Estimation, 2015/2016

  • Espeholt et al., IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, 2018

abstractmethod static epoch(tr: Trainer, epoch: ArrayLike) Any[source][source]#

Run one training epoch.

Subclasses must implement this with their algorithm-specific logic.

property model[source]#

Return the live model rebuilt from (graphdef, graphstate).

property optimizer[source]#

Return the optimizer rebuilt from (graphdef, graphstate).

static reset_model(tr: Trainer, shape: Sequence[int] | None = None, mask: Array | None = None) Trainer[source][source]#

Reset a model’s persistent recurrent state (e.g., LSTM carry) for all environments/agents and persist the mutation back into the trainer.

Parameters:
  • tr (Trainer) – Trainer carrying the environment and NNX graph state.

  • shape (Sequence[int], optional) – Leading shape of the recurrent carry to reset. If None, it is inferred as (tr.num_envs, tr.env.max_num_agents).

  • mask (jax.Array, optional) – Boolean mask selecting which (env, agent) entries to reset. A value of True resets that entry. The mask may be shape (num_envs, num_agents) or any shape broadcastable to it. If None, all entries are reset.

Returns:

A new trainer with the updated graphstate.

Return type:

Trainer
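
A hypothetical usage pattern, where tr is an existing trainer and done is a boolean array of shape (num_envs, num_agents) or broadcastable to it:

>>> tr = Trainer.reset_model(tr)             # clear the recurrent state for every (env, agent) entry
>>> tr = Trainer.reset_model(tr, mask=done)  # clear only the entries whose episode just terminated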

static step(tr: Trainer) Tuple[Trainer, TrajectoryData][source][source]#

Take one environment step and record a single-step trajectory.

Returns:

Updated trainer and the new single-step trajectory record. Trajectory data shape: (N_envs, N_agents, *)

Return type:

(Trainer, TrajectoryData)

abstractmethod static train(tr) Any[source][source]#

Run the training loop. Subclasses must implement this with their algorithm-specific training procedure.

static trajectory_rollout(tr: Trainer, num_steps_epoch: int, unroll: int = 8) Tuple[Trainer, TrajectoryData][source][source]#

Roll out \(T = \text{num_steps_epoch}\) environment steps using jax.lax.scan().

Parameters:
  • tr (Trainer) – The trainer carrying model state.

  • num_steps_epoch (int) – Number of steps to roll out.

  • unroll (int) – Number of loop iterations to unroll for compilation speed.

Returns:

The final trainer and a TrajectoryData instance whose fields are stacked along time (leading dimension \(T = \text{num_steps_epoch}\)).

Return type:

(Trainer, TrajectoryData)

class jaxdem.rl.TrajectoryData(*, obs: Array, action: Array, reward: Array, done: Array, value: Array, log_prob: Array, new_log_prob: Array, advantage: Array, returns: Array)[source]#

Bases: object

Container for rollout data (single step or stacked across time).

obs: Array#

Observations.

action: Array#

Actions sampled from the policy.

reward: Array#

Immediate rewards \(r_t\).

done: Array#

Episode-termination flags (boolean).

value: Array#

Baseline value estimates \(V(s_t)\).

log_prob: Array#

Behavior-policy log-probabilities \(\log \pi_b(a_t \mid s_t)\) at collection time.

new_log_prob: Array#

Target-policy log-probabilities \(\log \pi(a_t \mid s_t)\) after policy update.

Fill with log_prob during on-policy collection; must be recomputed after updates.

advantage: Array#

Advantages \(A_t\).

returns: Array#

Return targets (e.g., GAE or V-trace targets).

class jaxdem.rl.ActionSpace[source]#

Bases: Factory

Registry/namespace for action-space constraints implemented as distrax.Bijector subclasses.

These bijectors are intended to be wrapped around a base policy distribution (e.g., MultivariateNormalDiag) via distrax.Transformed, so that sampling and log-probabilities are correctly adjusted using the bijector’s forward_and_log_det / inverse_and_log_det methods. See Distrax/TFP bijector interface for details on shape semantics and event_ndims_in/out.

Example

To define a custom action space, inherit from distrax.Bijector and ActionSpace and implement its abstract methods:

>>> @ActionSpace.register("myCustomActionSpace")
... class MyCustomActionSpace(distrax.Bijector, ActionSpace):
...     ...
property kws: Dict[source]#
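
As an illustration, a tanh squashing constraint could be sketched as below and attached to a diagonal Gaussian policy via distrax.Block and distrax.Transformed. The registration name is hypothetical and ActionSpace-specific details (such as the kws property) are omitted:

>>> import distrax
>>> import jax
>>> import jax.numpy as jnp
>>> from jaxdem.rl import ActionSpace
>>>
>>> @ActionSpace.register("mySketchTanh")
... class MySketchTanh(distrax.Bijector, ActionSpace):
...     def __init__(self):
...         super().__init__(event_ndims_in=0)  # scalar bijector, applied element-wise
...     def forward_and_log_det(self, x):
...         y = jnp.tanh(x)
...         # log|d tanh(x)/dx| = log(1 - tanh(x)^2), in a numerically stable form
...         log_det = 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))
...         return y, log_det
...     def inverse_and_log_det(self, y):
...         x = jnp.arctanh(y)
...         return x, -jnp.log1p(-jnp.square(y))
>>>
>>> # Wrap a base Gaussian so sampling and log-probabilities respect the constraint.
>>> base = distrax.MultivariateNormalDiag(loc=jnp.zeros(3), scale_diag=jnp.ones(3))
>>> dist = distrax.Transformed(base, distrax.Block(MySketchTanh(), ndims=1))
>>> action = dist.sample(seed=jax.random.PRNGKey(0))
>>> log_prob = dist.log_prob(action)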

Modules

actionSpaces

Interface for defining bijectors used to constrain the policy probability distribution.

envWrappers

Contains wrappers for modifying RL environments.

environments

Reinforcement-learning environment interface.

models

Interface for defining reinforcement learning models.

trainers

Interface for defining reinforcement learning model trainers.