jaxdem.rl#
JaxDEM reinforcement learning (RL) module. It contains ML models, environments, and trainers implementing RL algorithms such as PPO.
- class jaxdem.rl.Environment(state: State, system: System, env_params: Dict[str, Any], max_num_agents: int = 0, action_space_size: int = 0, action_space_shape: Tuple[int, ...] = (), observation_space_size: int = 0)[source]#
Bases: Factory, ABC
Defines the interface for reinforcement-learning environments.
Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.
Observations and actions are flattened per agent to fixed sizes. Use action_space_shape to reshape inside the environment if needed.
Required shapes:
- Observation: (A, observation_space_size)
- Action (input to step()): (A, action_space_size)
- Reward: (A,)
- Done: scalar boolean for the whole environment
TODO:
- Truncated data field: per-agent termination flag
- Render method
Example
To define a custom environment, inherit from Environment and implement the abstract methods:

>>> @Environment.register("MyCustomEnv")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True, frozen=True)
>>> class MyCustomEnv(Environment):
...     ...
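A slightly fuller, still schematic skeleton is sketched below. It spells out the abstract static methods documented on this page; the method bodies are placeholders, and the construction of the State and System fields inside reset() is omitted.

from dataclasses import dataclass

import jax

from jaxdem.rl import Environment


@Environment.register("MyCustomEnv")
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class MyCustomEnv(Environment):
    """Schematic custom environment; method bodies are placeholders."""

    @staticmethod
    def reset(env, key):
        # Build and return a freshly initialized environment (construction omitted).
        ...

    @staticmethod
    def step(env, action):
        # `action` has shape (A, action_space_size); return the advanced environment.
        ...

    @staticmethod
    def observation(env):
        # Return per-agent observations of shape (A, observation_space_size).
        ...

    @staticmethod
    def reward(env):
        # Return per-agent rewards of shape (A,).
        ...

    @staticmethod
    def done(env):
        # Return a scalar boolean for the whole environment.
        ...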
- env_params: Dict[str, Any]#
Environment-specific parameters.
- max_num_agents: int#
Maximum number of active agents in the environment.
- action_space_size: int#
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
- action_space_shape: Tuple[int, ...]#
Original per-agent action shape (useful for reshaping inside the environment).
- observation_space_size: int#
Flattened observation size per agent.
observation() returns shape (A, observation_space_size).
- abstractmethod static done(env: Environment) → Array [source]#
Returns a boolean indicating whether the environment has ended.
- Parameters:
env (Environment) – The current environment.
- Returns:
A boolean indicating whether the environment has ended.
- Return type:
jax.Array
- static info(env: Environment) → Dict[str, Any] [source]#
Return auxiliary diagnostic information.
By default, returns an empty dict. Subclasses may override this to provide environment-specific information.
- Parameters:
env (Environment) – The current state of the environment.
- Returns:
A dictionary with additional information about the environment.
- Return type:
Dict
- abstractmethod static observation(env: Environment) → Array [source]#
Returns the per-agent observation vector.
- Parameters:
env (Environment) – The current environment.
- Returns:
Per-agent observations with shape (A, observation_space_size).
- Return type:
jax.Array
- abstractmethod static reset(env: Environment, key: Array | ndarray | bool_ | number | bool | int | float | complex) → Environment [source]#
Initialize the environment to a valid start state.
- Parameters:
env (Environment) – Instance of the environment.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
- static reset_if_done(env: Environment, done: Array, key: Array | ndarray | bool_ | number | bool | int | float | complex) → Environment [source]#
Conditionally reset the environment if it has reached a terminal state.
This method checks the done flag and, if True, calls the environment’s reset method to reinitialize the state. Otherwise, it returns the current environment unchanged.
- Parameters:
env (Environment) – The current environment instance.
done (jax.Array) – A boolean flag indicating whether the environment has reached a terminal state.
key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.
- Returns:
Either the freshly reset environment (if done is True) or the unchanged environment (if done is False).
- Return type:
Environment
- abstractmethod static reward(env: Environment) → Array [source]#
Returns the per-agent immediate rewards.
- Parameters:
env (Environment) – The current environment.
- Returns:
Per-agent rewards based on the current environment state, with shape (A,).
- Return type:
jax.Array
- abstractmethod static step(env: Environment, action: Array) → Environment [source]#
Advance the simulation by one step using per-agent actions.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
- jaxdem.rl.vectorise_env(env: Environment) → Environment [source]#
Promote an environment instance to a parallel version by applying jax.vmap(…) to its static methods.
- jaxdem.rl.clip_action_env(env: Environment, min_val: float = -1.0, max_val: float = 1.0) → Environment [source]#
Wrap an environment so that its step method clips actions to [min_val, max_val] before calling the original step.
- jaxdem.rl.is_wrapped(env: Environment) → bool [source]#
Check whether an environment instance is a wrapped environment.
- Parameters:
env (Environment) – The environment instance to check.
- Returns:
True if the environment is wrapped (i.e., has a _base_env_cls attribute), False otherwise.
- Return type:
bool
- jaxdem.rl.unwrap(env: Environment) → Environment [source]#
Unwrap an environment to its original base class while preserving all current field values.
- Parameters:
env (Environment) – The wrapped environment instance.
- Returns:
A new instance of the original base environment class with the same field values as the wrapped instance.
- Return type:
Environment
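A hedged usage sketch of the wrapper helpers above. `base_env` stands in for any concrete Environment instance (its construction is not shown), and the particular wrapping order is just one possible composition.

from jaxdem.rl import clip_action_env, is_wrapped, unwrap, vectorise_env

# `base_env` is any concrete Environment instance (construction not shown).
env = clip_action_env(base_env, min_val=-1.0, max_val=1.0)  # clip actions inside step()
env = vectorise_env(env)                                    # vmap the static methods

assert is_wrapped(env)   # wrapped environments expose a _base_env_cls attribute
original = unwrap(env)   # recover the base class, preserving current field values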
- class jaxdem.rl.Model(*args: Any, **kwargs: Any)[source]#
Bases: Factory, Module, ABC
The base interface for defining reinforcement-learning models. Acts as a namespace.
Models map observations to an action distribution and a value estimate.
Example
To define a custom model, inherit from Model and implement its abstract methods:

>>> @Model.register("myCustomModel")
>>> class MyCustomModel(Model):
...     ...
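The abstract methods of Model are not enumerated on this page, so the sketch below only illustrates the documented contract (observations in, action distribution and value estimate out). The __call__ signature, the two linear heads, and the use of distrax are assumptions; a real subclass must also implement whatever abstract methods Model defines.

import distrax
import jax
import jax.numpy as jnp
from flax import nnx

from jaxdem.rl import Model


@Model.register("myCustomModel")
class MyCustomModel(Model):
    """Sketch: a diagonal-Gaussian policy head and a scalar value head."""

    def __init__(self, obs_size: int, action_size: int, *, rngs: nnx.Rngs):
        self.policy_head = nnx.Linear(obs_size, action_size, rngs=rngs)
        self.value_head = nnx.Linear(obs_size, 1, rngs=rngs)
        self.log_std = nnx.Param(jnp.zeros(action_size))

    def __call__(self, obs):
        # Map observations to an action distribution and a value estimate.
        loc = self.policy_head(obs)
        dist = distrax.MultivariateNormalDiag(
            loc=loc, scale_diag=jnp.exp(self.log_std.value)
        )
        value = self.value_head(obs).squeeze(-1)
        return dist, value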
- class jaxdem.rl.Trainer(env: Environment, graphdef: nnx.GraphDef, graphstate: nnx.GraphState, key: ArrayLike, advantage_gamma: jax.Array, advantage_lambda: jax.Array, advantage_rho_clip: jax.Array, advantage_c_clip: jax.Array)[source]#
Bases: Factory, ABC
Base class for reinforcement learning trainers.
This class holds the environment and model state (Flax NNX GraphDef/GraphState). It provides rollout utilities (step(), trajectory_rollout()) and a general advantage computation method (compute_advantages()). Subclasses must implement algorithm-specific training logic in epoch().
Example
To define a custom trainer, inherit from Trainer and implement its abstract methods:

>>> @Trainer.register("myCustomTrainer")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True, frozen=True)
>>> class MyCustomTrainer(Trainer):
...     ...
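A schematic sketch of what an epoch() implementation might look like, wiring together the rollout and advantage utilities documented below. The number of rollout steps is arbitrary and the loss/optimizer update is omitted; a real trainer would fill in that algorithm-specific logic.

from dataclasses import dataclass

import jax

from jaxdem.rl import Trainer


@Trainer.register("myCustomTrainer")
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class MyCustomTrainer(Trainer):
    @staticmethod
    def epoch(tr, epoch):
        # Collect a trajectory and compute advantages/returns.
        tr, td = MyCustomTrainer.trajectory_rollout(tr, num_steps_epoch=128)
        td = MyCustomTrainer.compute_advantages(
            td,
            tr.advantage_rho_clip,
            tr.advantage_c_clip,
            tr.advantage_gamma,
            tr.advantage_lambda,
        )
        # Algorithm-specific loss and parameter update go here (omitted).
        return tr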
- env: Environment#
Environment object.
- graphdef: nnx.GraphDef#
Static graph definition of the model/optimizer.
- graphstate: nnx.GraphState#
Mutable state (parameters, optimizer state, RNGs, etc.).
- key: ArrayLike#
PRNGKey used to sample actions and for other stochastic operations.
- advantage_gamma: jax.Array#
Discount factor \(\gamma \in [0, 1]\).
- advantage_lambda: jax.Array#
Generalized Advantage Estimation parameter \(\lambda \in [0, 1]\).
- advantage_rho_clip: jax.Array#
V-trace \(\bar{\rho}\) (importance weight clip for the TD term).
- advantage_c_clip: jax.Array#
V-trace \(\bar{c}\) (importance weight clip for the recursion/trace term).
- static compute_advantages(td: TrajectoryData, advantage_rho_clip: Array, advantage_c_clip: Array, advantage_gamma: Array, advantage_lambda: Array, unroll: int = 8) → TrajectoryData [source]#
Compute advantages and return targets with V-trace-style off-policy correction or generalized advantage estimation (GAE).
Let the behavior policy be \(\pi_b\) and the target policy be \(\pi\). Define importance ratios per step:
\[\rho_t = \exp\big( \log \pi(a_t \mid s_t) - \log \pi_b(a_t \mid s_t) \big)\]
and their clipped versions \(\bar{\rho}, \bar{c}\):
\[\hat{\rho}_t = \min(\rho_t, \bar{\rho}), \quad \hat{c}_t = \min(\rho_t, \bar{c}).\]
We form a TD-like residual with an off-policy correction:
\[\delta_t = \hat{\rho}_t \, r_t + \gamma V(s_{t+1})(1 - \text{done}_t) - V(s_t)\]
and propagate a GAE-style trace using \(\hat{c}_t\):
\[A_t = \delta_t + \gamma \lambda (1 - \text{done}_t) \hat{c}_t A_{t+1}\]
Finally, the return targets are:
\[\text{returns}_t = A_t + V(s_t)\]
Notes
When \(\pi_b = \pi\) (i.e., TrajectoryData.log_prob == TrajectoryData.new_log_prob) and \(\bar{\rho} = \bar{c} = 1\), this function reduces to standard GAE.
- Returns:
TrajectoryData with new advantage and returns.
- Return type:
TrajectoryData
References
Schulman et al., High-Dimensional Continuous Control Using Generalized Advantage Estimation, 2015/2016
Espeholt et al., IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, 2018
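For reference, the recursion above can be written as a standalone function over a single 1-D trajectory of length T. This is an illustrative sketch only, not the library's compute_advantages (which operates on TrajectoryData and may differ in batching, layout, and bootstrapping); the bootstrap_value argument and default clip values are assumptions.

import jax
import jax.numpy as jnp


def vtrace_gae(rewards, values, bootstrap_value, dones, log_prob, new_log_prob,
               gamma=0.99, lam=0.95, rho_clip=1.0, c_clip=1.0):
    """Advantages and return targets for one trajectory, following the equations above."""
    rho = jnp.exp(new_log_prob - log_prob)   # importance ratios rho_t
    rho_hat = jnp.minimum(rho, rho_clip)     # clipped by rho-bar
    c_hat = jnp.minimum(rho, c_clip)         # clipped by c-bar

    next_values = jnp.append(values[1:], bootstrap_value)
    deltas = rho_hat * rewards + gamma * next_values * (1.0 - dones) - values

    def backward(advantage, inputs):
        # A_t = delta_t + gamma * lambda * (1 - done_t) * c_hat_t * A_{t+1}
        delta_t, done_t, c_t = inputs
        advantage = delta_t + gamma * lam * (1.0 - done_t) * c_t * advantage
        return advantage, advantage

    _, advantages = jax.lax.scan(backward, jnp.zeros(()), (deltas, dones, c_hat), reverse=True)
    returns = advantages + values
    return advantages, returns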
- abstractmethod static epoch(tr: Trainer, epoch: Array | ndarray | bool_ | number | bool | int | float | complex) → Any [source]#
Run one training epoch.
Subclasses must implement this with their algorithm-specific logic.
- static reset_model(tr: Trainer, shape: Sequence[int] | None = None, mask: Array | None = None) → Trainer [source]#
Reset a model’s persistent recurrent state (e.g., LSTM carry) for all environments/agents and persist the mutation back into the trainer.
- Parameters:
tr (Trainer) – Trainer carrying the environment and NNX graph state. The target carry shape is inferred as (tr.num_envs, tr.env.max_num_agents) if not specified.
mask (jax.Array, optional) – Boolean mask selecting which (env, agent) entries to reset. A value of True resets that entry. The mask may be of shape (num_envs, num_agents) or any shape broadcastable to it. If None, all entries are reset.
- Returns:
A new trainer with the updated graphstate.
- Return type:
Trainer
- static step(tr: Trainer) → Tuple[Trainer, TrajectoryData] [source]#
Take one environment step and record a single-step trajectory.
- Returns:
Updated trainer and the new single-step trajectory record. Trajectory data shape: (N_envs, N_agents, *)
- Return type:
Tuple[Trainer, TrajectoryData]
- static trajectory_rollout(tr: Trainer, num_steps_epoch: int, unroll: int = 8) → Tuple[Trainer, TrajectoryData] [source]#
Roll out \(T = \text{num_steps_epoch}\) environment steps using jax.lax.scan().
- Parameters:
tr (Trainer) – The trainer carrying model state.
num_steps_epoch (int) – Number of steps to roll out.
unroll (int) – Number of loop iterations to unroll for compilation speed.
- Returns:
The final trainer and a TrajectoryData instance whose fields are stacked along time (leading dimension \(T = \text{num_steps_epoch}\)).
- Return type:
Tuple[Trainer, TrajectoryData]
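A minimal usage sketch, assuming tr is an already constructed Trainer; the resulting shapes follow the (N_envs, N_agents, *) convention described for step(), stacked along a leading time axis.

# `tr` is an already constructed Trainer (construction not shown).
tr, td = Trainer.trajectory_rollout(tr, num_steps_epoch=128)
# Fields of `td` are stacked along time, e.g. td.obs has shape
# (128, N_envs, N_agents, observation_space_size) and td.reward (128, N_envs, N_agents).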
- class jaxdem.rl.TrajectoryData(*, obs: Array, action: Array, reward: Array, done: Array, value: Array, log_prob: Array, new_log_prob: Array, advantage: Array, returns: Array)[source]#
Bases: object
Container for rollout data (single step or stacked across time).
- obs: Array#
Observations.
- action: Array#
Actions sampled from the policy.
- reward: Array#
Immediate rewards \(r_t\).
- done: Array#
Episode-termination flags (boolean).
- value: Array#
Baseline value estimates \(V(s_t)\).
- log_prob: Array#
Behavior-policy log-probabilities \(\log \pi_b(a_t \mid s_t)\) at collection time.
- new_log_prob: Array#
Target-policy log-probabilities \(\log \pi(a_t \mid s_t)\) after policy update.
Fill with log_prob during on-policy collection; must be recomputed after updates.
- advantage: Array#
Advantages \(A_t\).
- returns: Array#
Return targets (e.g., GAE or V-trace targets).
- class jaxdem.rl.ActionSpace[source]#
Bases: Factory
Registry/namespace for action-space constraints implemented as `distrax.Bijector`s.
These bijectors are intended to be wrapped around a base policy distribution (e.g., MultivariateNormalDiag) via distrax.Transformed, so that sampling and log-probabilities are correctly adjusted using the bijector’s forward_and_log_det / inverse_and_log_det methods. See Distrax/TFP bijector interface for details on shape semantics and event_ndims_in/out.
Example
To define a custom action space, inherit from distrax.Bijector and ActionSpace and implement its abstract methods:

>>> @ActionSpace.register("myCustomActionSpace")
>>> class MyCustomActionSpace(distrax.Bijector, ActionSpace):
...     ...
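As a concrete illustration of the wrapping described above (using plain distrax components rather than a registered ActionSpace subclass), a scalar tanh bijector can be lifted over the event dimension and applied to a diagonal-Gaussian policy; the action size and seed below are arbitrary.

import distrax
import jax
import jax.numpy as jnp

# Base policy distribution over a 3-dimensional action.
base = distrax.MultivariateNormalDiag(loc=jnp.zeros(3), scale_diag=jnp.ones(3))

# Lift the scalar tanh bijector to act on the whole event vector, then wrap.
squash = distrax.Block(distrax.Tanh(), ndims=1)
policy = distrax.Transformed(base, squash)

# Samples lie in (-1, 1); log-probabilities include the bijector's log-det term.
action, log_prob = policy.sample_and_log_prob(seed=jax.random.PRNGKey(0))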
Modules
- Interface for defining bijectors used to constrain the policy probability distribution.
- Contains wrappers for modifying RL environments.
- Reinforcement-learning environment interface.
- Interface for defining reinforcement-learning models.
- Interface for defining reinforcement-learning model trainers.