jaxdem.rl#
JaxDEM reinforcement learning (RL) module. Contains ML models, environments, and trainers implementing RL algorithms such as PPO.
- class jaxdem.rl.Environment(state: State, system: System, env_params: Dict[str, Any])[source]#
Bases: Factory, ABC
Defines the interface for reinforcement-learning environments.
Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.
Observations and actions are flattened per agent to fixed sizes. Use
action_space_shape to reshape inside the environment if needed.
Required shapes
- Observation: (A, observation_space_size)
- Action (input to step()): (A, action_space_size)
- Reward: (A,)
- Done: scalar boolean for the whole environment
TODO:
- Truncated data field: per-agent termination flag
- Render method
Example
To define a custom environment, inherit from Environment and implement the abstract methods:
>>> @Environment.register("MyCustomEnv")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True)
>>> class MyCustomEnv(Environment):
...     ...
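A slightly fuller, hedged sketch of how the abstract methods and required shapes fit together for a single-agent (A = 1) environment; the method bodies are intentionally left as stubs because constructing State and System is environment specific, and nothing below beyond the documented method names and shapes is part of the JaxDEM API:
>>> import jax
>>> from dataclasses import dataclass
>>> from jaxdem.rl import Environment
>>>
>>> @Environment.register("SketchEnv")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True)
>>> class SketchEnv(Environment):
...     @staticmethod
...     def reset(env, key):
...         # Rebuild env.state / env.system into a valid start configuration.
...         ...
...
...     @staticmethod
...     def step(env, action):
...         # `action` arrives flattened as (A, action_space_size); reshape with
...         # env.action_space_shape if the physics needs the original layout.
...         ...
...
...     @staticmethod
...     def observation(env):
...         # Must return shape (A, observation_space_size).
...         ...
...
...     @staticmethod
...     def reward(env):
...         # Must return shape (A,).
...         ...
...
...     @staticmethod
...     def done(env):
...         # Single boolean for the whole environment.
...         ...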
- env_params: Dict[str, Any]#
Environment-specific parameters.
- property action_space_shape: Tuple[int][source]#
Original per-agent action shape (useful for reshaping inside the environment).
- property action_space_size: int[source]#
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
- abstractmethod static done(env: Environment) Array[source][source]#
Returns a boolean indicating whether the environment has ended.
- Parameters:
env (Environment) – The current environment.
- Returns:
A boolean indicating whether the environment has ended.
- Return type:
jax.Array
- static info(env: Environment) Dict[str, Any][source][source]#
Return auxiliary diagnostic information.
By default, returns an empty dict. Subclasses may override to provide environment specific information.
- Parameters:
env (Environment) – The current state of the environment.
- Returns:
A dictionary with additional information about the environment.
- Return type:
Dict
- abstractmethod static observation(env: Environment) Array[source][source]#
Returns the per-agent observation vector.
- Parameters:
env (Environment) – The current environment.
- Returns:
Vector corresponding to the environment observation.
- Return type:
jax.Array
- property observation_space_size: int[source]#
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
- abstractmethod static reset(env: Environment, key: ArrayLike) Environment[source][source]#
Initialize the environment to a valid start state.
- Parameters:
env (Environment) – Instance of the environment.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
- static reset_if_done(env: Environment, done: Array, key: ArrayLike) Environment[source][source]#
Conditionally reset the environment when it has reached a terminal state.
This method checks the done flag and, if it is True, calls the environment’s reset method to reinitialize the state. Otherwise, it returns the current environment unchanged.
- Parameters:
env (Environment) – The current environment instance.
done (jax.Array) – A boolean flag indicating whether the environment has reached a terminal state.
key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.
- Returns:
Either the freshly reset environment (if done is True) or the unchanged environment (if done is False).
- Return type:
Environment
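For intuition, the conditional reset can be expressed with jax.lax.cond so that it stays traceable under jit. This is a hedged sketch of the pattern, not necessarily the library’s exact implementation:
>>> import jax
>>>
>>> def reset_if_done_sketch(env, done, key):
...     # Both branches return an Environment pytree with identical structure,
...     # so the conditional remains compatible with jax.jit.
...     return jax.lax.cond(
...         done,
...         lambda: type(env).reset(env, key),
...         lambda: env,
...     )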
- abstractmethod static reward(env: Environment) Array[source][source]#
Returns the per-agent immediate rewards.
- Parameters:
env (Environment) – The current environment.
- Returns:
Vector of per-agent rewards based on the current environment state, shape (A,).
- Return type:
jax.Array
- abstractmethod static step(env: Environment, action: Array) Environment[source][source]#
Advance the simulation by one step using per-agent actions.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
- jaxdem.rl.vectorise_env(env: Environment) Environment[source][source]#
Promote an environment instance to a parallel version by applying jax.vmap(…) to its static methods.
- jaxdem.rl.clip_action_env(env: Environment, min_val: float = -1.0, max_val: float = 1.0) Environment[source][source]#
Wrap an environment so that its step method clips the action before calling the original step.
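A hedged usage sketch for the two wrappers above; `my_env` is assumed to be an already-constructed Environment instance, and the expectation that the vmapped static methods take a leading batch axis is an assumption, not documented behavior:
>>> import jax
>>> from jaxdem.rl import vectorise_env, clip_action_env
>>>
>>> batched_env = vectorise_env(my_env)      # static methods are now vmapped
>>> safe_env = clip_action_env(batched_env)  # actions clipped to [-1.0, 1.0] by default
>>> keys = jax.random.split(jax.random.PRNGKey(0), 8)  # e.g. one key per parallel env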
- jaxdem.rl.is_wrapped(env: Environment) bool[source][source]#
Check whether an environment instance is a wrapped environment.
- Parameters:
env (Environment) – The environment instance to check.
- Returns:
True if the environment is wrapped (i.e., has a _base_env_cls attribute), False otherwise.
- Return type:
bool
- jaxdem.rl.unwrap(env: Environment) Environment[source][source]#
Unwrap an environment to its original base class while preserving all current field values.
- Parameters:
env (Environment) – The wrapped environment instance.
- Returns:
A new instance of the original base environment class with the same field values as the wrapped instance.
- Return type:
Environment
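Continuing the hedged sketch above (with `safe_env` being a wrapped instance), the introspection helpers behave as follows:
>>> from jaxdem.rl import is_wrapped, unwrap
>>>
>>> is_wrapped(safe_env)   # clip_action_env produced a wrapper class
True
>>> base_env = unwrap(safe_env)   # back to the original base class, same field values
>>> is_wrapped(base_env)
False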
- class jaxdem.rl.Model(*args: Any, **kwargs: Any)[source]#
Bases: Factory, Module, ABC
The base interface for defining reinforcement learning models. Acts as a namespace.
Models map observations to an action distribution and a value estimate.
Example
To define a custom model, inherit from Model and implement its abstract methods:
>>> @Model.register("myCustomModel")
>>> class MyCustomModel(Model):
...     ...
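As a hedged illustration only: the layer layout, constructor signature, and the assumption that the model is invoked via __call__ returning an action distribution and a value estimate are sketch-level assumptions, not the documented abstract interface. A small actor-critic model could look like:
>>> import jax.numpy as jnp
>>> import distrax
>>> from flax import nnx
>>> from jaxdem.rl import Model
>>>
>>> @Model.register("sketchActorCritic")
>>> class SketchActorCritic(Model):
...     def __init__(self, obs_size: int, action_size: int, *, rngs: nnx.Rngs):
...         self.policy = nnx.Linear(obs_size, action_size, rngs=rngs)
...         self.value_head = nnx.Linear(obs_size, 1, rngs=rngs)
...         self.log_std = nnx.Param(jnp.zeros(action_size))
...
...     def __call__(self, obs):
...         # Map per-agent observations to a Gaussian action distribution
...         # and a scalar value estimate.
...         loc = self.policy(obs)
...         dist = distrax.MultivariateNormalDiag(loc, jnp.exp(self.log_std.value))
...         value = self.value_head(obs).squeeze(-1)
...         return dist, value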
- class jaxdem.rl.Trainer(env: Environment, graphdef: nnx.GraphDef, graphstate: nnx.GraphState, key: ArrayLike, advantage_gamma: jax.Array, advantage_lambda: jax.Array, advantage_rho_clip: jax.Array, advantage_c_clip: jax.Array)[source]#
Bases: Factory, ABC
Base class for reinforcement learning trainers.
This class holds the environment and model state (Flax NNX GraphDef/GraphState). It provides rollout utilities (step(), trajectory_rollout()) and a general advantage computation method (compute_advantages()). Subclasses must implement algorithm-specific training logic in epoch().
Example
To define a custom trainer, inherit from Trainer and implement its abstract methods:
>>> @Trainer.register("myCustomTrainer")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True)
>>> class MyCustomTrainer(Trainer):
...     ...
- env: Environment#
Environment object.
- graphdef: nnx.GraphDef#
Static graph definition of the model/optimizer.
- graphstate: nnx.GraphState#
Mutable state (parameters, optimizer state, RNGs, etc.).
- key: ArrayLike#
PRNGKey used to sample actions and for other stochastic operations.
- advantage_gamma: jax.Array#
Discount factor \(\gamma \in [0, 1]\).
- advantage_lambda: jax.Array#
Generalized Advantage Estimation parameter \(\lambda \in [0, 1]\).
- advantage_rho_clip: jax.Array#
V-trace \(\bar{\rho}\) (importance weight clip for the TD term).
- advantage_c_clip: jax.Array#
V-trace \(\bar{c}\) (importance weight clip for the recursion/trace term).
- static compute_advantages(value: Array, reward: Array, ratio: Array, done: Array, advantage_rho_clip: Array, advantage_c_clip: Array, advantage_gamma: Array, advantage_lambda: Array, unroll: int = 8) Tuple[Array, Array][source][source]#
Compute V-trace/GAE advantages and return targets.
Given a policy \(\pi\), define per-step importance ratios:
\[\rho_t = \exp\big( \log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_\text{old}}(a_t \mid s_t) \big)\]
and their clipped versions \(\hat{\rho}_t, \hat{c}_t\):
\[\hat{\rho}_t = \min(\rho_t, \bar{\rho}), \quad \hat{c}_t = \min(\rho_t, \bar{c}).\]
We form a TD-like residual with an off-policy correction:
\[\delta_t = \hat{\rho}_t \, r_t + \gamma V(s_{t+1})(1 - \text{done}_t) - V(s_t)\]
and propagate a GAE-style trace using \(\hat{c}_t\):
\[A_t = \delta_t + \gamma \lambda (1 - \text{done}_t) \hat{c}_t A_{t+1}\]
Finally, the return targets are:
\[\text{returns}_t = A_t + V(s_t)\]
Notes
When \(\pi_\theta = \pi_{\theta_\text{old}}\) (i.e. ratio == 1) and \(\bar{\rho} = \bar{c} = 1\), this function reduces to standard GAE.
- Returns:
Computed advantages and return targets.
- Return type:
Tuple[jax.Array, jax.Array]
References
Schulman et al., High-Dimensional Continuous Control Using Generalized Advantage Estimation, 2015/2016
Espeholt et al., IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, 2018
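The recursion above maps naturally onto a reverse jax.lax.scan. The following is a hedged, simplified sketch that mirrors the documented equations for 1-D (time-indexed) inputs; it assumes `value` carries \(T + 1\) entries so \(V(s_{t+1})\) is available at every step, which is an assumption about the layout rather than the documented signature:
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def vtrace_gae_sketch(value, reward, ratio, done,
...                       rho_clip, c_clip, gamma, lam):
...     rho_hat = jnp.minimum(ratio, rho_clip)
...     c_hat = jnp.minimum(ratio, c_clip)
...     not_done = 1.0 - done.astype(value.dtype)
...     # delta_t = rho_hat_t * r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
...     delta = rho_hat * reward + gamma * value[1:] * not_done - value[:-1]
...
...     def backward(adv_next, inputs):
...         d, nd, c = inputs
...         # A_t = delta_t + gamma * lambda * (1 - done_t) * c_hat_t * A_{t+1}
...         adv = d + gamma * lam * nd * c * adv_next
...         return adv, adv
...
...     _, advantages = jax.lax.scan(
...         backward, jnp.zeros_like(value[0]), (delta, not_done, c_hat), reverse=True
...     )
...     # returns_t = A_t + V(s_t)
...     return advantages, advantages + value[:-1]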
- abstractmethod static epoch(tr: Trainer, epoch: ArrayLike) Any[source][source]#
Run one training epoch.
Subclasses must implement this with their algorithm-specific logic.
- static step(env: Environment, graphdef: nnx.GraphDef, graphstate: nnx.GraphState, key: jax.Array) Tuple[Tuple['Environment', nnx.GraphState, jax.Array], 'TrajectoryData'][source][source]#
Take one environment step and record a single-step trajectory.
- Parameters:
env (Environment) – The environment to step (typically vectorized over N_envs).
graphdef (nnx.GraphDef) – Static Python part of the nnx model.
graphstate (nnx.GraphState) – Mutable state of the nnx model.
key (jax.Array) – JAX random key.
- Returns:
Updated state and the new single-step trajectory. Trajectory data is shaped (N_envs, N_agents, …).
- Return type:
Tuple[Tuple[Environment, nnx.GraphState, jax.Array], TrajectoryData]
- abstractmethod static train(tr) Any[source][source]#
Run the training loop.
Subclasses must implement this with their algorithm-specific logic.
- static trajectory_rollout(env: Environment, graphdef: nnx.GraphDef, graphstate: nnx.GraphState, key: jax.Array, num_steps_epoch: int, unroll: int = 8) Tuple['Environment', nnx.GraphState, jax.Array, 'TrajectoryData'][source][source]#
Roll out \(T = \text{num_steps_epoch}\) environment steps using jax.lax.scan().
- Parameters:
env (Environment) – The environment to roll out (typically vectorized over N_envs).
graphdef (nnx.GraphDef) – Static Python part of the nnx model.
graphstate (nnx.GraphState) – Mutable state of the nnx model.
key (jax.Array) – JAX random key.
num_steps_epoch (int) – Number of steps to roll out.
unroll (int) – Number of loop iterations to unroll for compilation speed.
- Returns:
The final environment, updated graph state, PRNG key, and a TrajectoryData instance whose fields are stacked along time (leading dimension \(T = \text{num_steps_epoch}\)).
- Return type:
Tuple[Environment, nnx.GraphState, jax.Array, TrajectoryData]
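A hedged sketch of how an algorithm-specific epoch() might call the rollout utility; `tr` is assumed to be a fully constructed Trainer instance, and the shape comment follows the (N_envs, N_agents, …) per-step layout described in step(), stacked along time:
>>> from jaxdem.rl import Trainer
>>>
>>> env, graphstate, key, traj = Trainer.trajectory_rollout(
...     tr.env, tr.graphdef, tr.graphstate, tr.key, num_steps_epoch=128
... )
>>> traj.reward.shape   # (T, N_envs, N_agents) under the layout described above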
- class jaxdem.rl.TrajectoryData(*, obs: Array, action: Array, value: Array, log_prob: Array, ratio: Array, reward: Array, done: Array)[source]#
Bases: object
Container for rollout data (single step or stacked across time).
- obs: Array#
Observations.
- action: Array#
Actions sampled from the policy.
- value: Array#
Baseline value estimates \(V(s_t)\).
- log_prob: Array#
Behavior-policy log-probabilities \(\log \pi_b(a_t \mid s_t)\) at collection time.
- ratio: Array#
Importance ratio between current-policy and behavior-policy probabilities, \(\exp\big( \log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_\text{old}}(a_t \mid s_t) \big)\).
- reward: Array#
Immediate rewards \(r_t\).
- done: Array#
Episode-termination flags (boolean).
- class jaxdem.rl.ActionSpace[source]#
Bases: Factory
Registry/namespace for action-space constraints implemented as `distrax.Bijector`s.
These bijectors are intended to be wrapped around a base policy distribution (e.g., MultivariateNormalDiag) via distrax.Transformed, so that sampling and log-probabilities are correctly adjusted using the bijector’s forward_and_log_det / inverse_and_log_det methods. See Distrax/TFP bijector interface for details on shape semantics and event_ndims_in/out.
Example
To define a custom action space, inherit from distrax.Bijector and ActionSpace and implement its abstract methods:
>>> @ActionSpace.register("myCustomActionSpace")
>>> class MyCustomActionSpace(distrax.Bijector, ActionSpace):
...     ...
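A hedged sketch of a tanh-squashing action space that bounds each action dimension to (-1, 1); the class name, the elementwise event_ndims_in=0 choice, and the numerically stable log-determinant are illustrative assumptions rather than part of JaxDEM:
>>> import jax
>>> import jax.numpy as jnp
>>> import distrax
>>> from jaxdem.rl import ActionSpace
>>>
>>> @ActionSpace.register("tanhBoxSketch")
>>> class TanhBoxSketch(distrax.Bijector, ActionSpace):
...     def __init__(self):
...         super().__init__(event_ndims_in=0)  # elementwise bijector
...
...     def forward_and_log_det(self, x):
...         y = jnp.tanh(x)
...         # log|dy/dx| = log(1 - tanh(x)^2), written in a numerically stable form.
...         log_det = 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))
...         return y, log_det
...
...     def inverse_and_log_det(self, y):
...         x = jnp.arctanh(y)
...         log_det = 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))
...         return x, -log_det
As described above, such a bijector would typically be wrapped around a base policy distribution via distrax.Transformed (using distrax.Block to lift the elementwise bijector to the action vector).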
Modules
- Interface for defining bijectors used to constrain the policy probability distribution.
- Contains wrappers for modifying RL environments.
- Reinforcement-learning environment interface.
- Interface for defining reinforcement learning models.
- Interface for defining reinforcement learning model trainers.