jaxdem.rl.models

Interface for defining reinforcement learning models.

Classes

Model(*args, **kwargs)

The base interface for defining reinforcement learning models.

class jaxdem.rl.models.Model(*args: Any, **kwargs: Any)

Bases: Factory, Module, ABC

The base interface for defining reinforcement learning models. Also acts as a namespace.

Models map observations to an action distribution and a value estimate.

Example

To define a custom model, inherit from Model and implement its abstract methods:

>>> from jaxdem.rl.models import Model
>>> @Model.register("myCustomModel")
... class MyCustomModel(Model):
...     ...
property log_std: Param
property metadata: Dict
reset(shape: Tuple, mask: Array | None = None)

Reset the persistent LSTM carry.

Parameters:
  • shape (tuple[int, ...]) – Leading dims for the carry, e.g. (num_envs, num_agents).

  • mask (optional bool array) – True where entries should be reset. Shape (num_envs,).

class jaxdem.rl.models.SharedActorCritic(*args: Any, **kwargs: Any)

Bases: Model

A shared-parameter dense actor-critic model.

This model uses a common feedforward network (the “shared torso”) to process observations and then branches into two separate linear heads:

  • Actor head: outputs the mean of a Gaussian action distribution.

  • Critic head: outputs a scalar value estimate of the state.
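The shared-torso layout can be sketched in plain JAX. This is a minimal sketch, not the jaxdem implementation: it uses explicit parameter dictionaries instead of nnx modules, and the helper names (init_dense, init_shared_actor_critic, apply), default widths, scales, and tanh activation are hypothetical choices. It shows one torso feeding both heads, orthogonal initialization scaled per part (in_scale, actor_scale, critic_scale), and a state-independent log_std that sets the Gaussian's standard deviation.

```python
import jax
import jax.numpy as jnp


def init_dense(key, n_in, n_out, scale):
    # Hypothetical helper: orthogonal weights scaled by `scale`, zero bias.
    w = jax.nn.initializers.orthogonal(scale)(key, (n_in, n_out), jnp.float32)
    return {"w": w, "b": jnp.zeros((n_out,), jnp.float32)}


def dense(p, x):
    return x @ p["w"] + p["b"]


def init_shared_actor_critic(key, obs_dim, action_dim, architecture=(64, 64),
                             in_scale=2.0 ** 0.5, actor_scale=0.01,
                             critic_scale=1.0):
    keys = jax.random.split(key, len(architecture) + 2)
    sizes = (obs_dim, *architecture)
    # Shared torso: one orthogonally initialized layer per hidden size.
    torso = [init_dense(k, a, b, in_scale)
             for k, a, b in zip(keys[:-2], sizes[:-1], sizes[1:])]
    return {
        "torso": torso,
        "actor": init_dense(keys[-2], sizes[-1], action_dim, actor_scale),
        "critic": init_dense(keys[-1], sizes[-1], 1, critic_scale),
        # State-independent log standard deviation (trained with the rest).
        "log_std": jnp.zeros((action_dim,), jnp.float32),
    }


def apply(params, obs, activation=jnp.tanh):
    h = obs
    for layer in params["torso"]:
        h = activation(dense(layer, h))
    mean = dense(params["actor"], h)    # actor head: Gaussian mean
    value = dense(params["critic"], h)  # critic head: shape (..., 1)
    std = jnp.exp(params["log_std"])    # broadcasts against `mean`
    return mean, std, value


params = init_shared_actor_critic(jax.random.PRNGKey(0), obs_dim=8, action_dim=3)
mean, std, value = apply(params, jnp.ones((5, 8)))
# Reparameterized Gaussian sample from the policy.
action = mean + std * jax.random.normal(jax.random.PRNGKey(1), mean.shape)
```

Sharing the torso halves the number of hidden-layer parameters, at the cost of letting actor and critic gradients interfere in the same features.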

Parameters:
  • observation_space (Sequence[int]) – Shape of the observation space (excluding batch dimension).

  • action_space (ActionSpace) – Shape of the action space.

  • key (nnx.Rngs) – Random number generator(s) for parameter initialization.

  • architecture (Sequence[int]) – Sizes of the hidden layers in the shared network.

  • in_scale (float) – Scaling factor for orthogonal initialization of the shared network layers.

  • actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.

  • critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.

  • activation (Callable) – Activation function applied between hidden layers.

  • action_space – Bijector used to constrain the policy probability distribution.

network

The shared feedforward network (torso).

Type:

nnx.Sequential

actor

Linear layer mapping shared features to action means.

Type:

nnx.Linear

critic

Linear layer mapping shared features to a scalar value estimate.

Type:

nnx.Linear

log_std

Learnable log standard deviation for the Gaussian action distribution.

Type:

nnx.Param

bij

Bijector for constraining the action space.

Type:

distrax.Bijector

property log_std: Param
property metadata: Dict
classmethod registry_name() → str
property type_name: str
class jaxdem.rl.models.ActorCritic(*args: Any, **kwargs: Any)

Bases: Model, Module

An actor-critic model with separate networks for the actor and critic.

Unlike SharedActorCritic, this model uses two independent feedforward networks:

  • Actor torso: processes observations into features for the policy.

  • Critic torso: processes observations into features for the value function.
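Keeping the torsos separate means the value loss cannot change the policy features. A minimal plain-JAX sketch, assuming explicit parameter dictionaries rather than the library's nnx modules (init_mlp, run_mlp, init_actor_critic, apply, and all defaults are hypothetical), builds the two torsos with possibly different widths and checks that the value output has zero gradient with respect to the actor torso.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes, scale):
    # Hypothetical helper: list of orthogonally initialized dense layers.
    keys = jax.random.split(key, len(sizes) - 1)
    init = jax.nn.initializers.orthogonal(scale)
    return [{"w": init(k, (a, b), jnp.float32), "b": jnp.zeros((b,))}
            for k, a, b in zip(keys, sizes[:-1], sizes[1:])]


def run_mlp(layers, x, activation):
    for p in layers:
        x = activation(x @ p["w"] + p["b"])
    return x


def init_actor_critic(key, obs_dim, action_dim, actor_arch=(64, 64),
                      critic_arch=(64, 64), in_scale=2.0 ** 0.5,
                      actor_scale=0.01, critic_scale=1.0):
    k_at, k_ct, k_a, k_c = jax.random.split(key, 4)
    return {
        "actor_torso": init_mlp(k_at, (obs_dim, *actor_arch), in_scale),
        "critic_torso": init_mlp(k_ct, (obs_dim, *critic_arch), in_scale),
        "actor": init_mlp(k_a, (actor_arch[-1], action_dim), actor_scale)[0],
        "critic": init_mlp(k_c, (critic_arch[-1], 1), critic_scale)[0],
        "log_std": jnp.zeros((action_dim,)),
    }


def apply(params, obs, activation=jnp.tanh):
    a_feat = run_mlp(params["actor_torso"], obs, activation)   # policy features
    c_feat = run_mlp(params["critic_torso"], obs, activation)  # value features
    mean = a_feat @ params["actor"]["w"] + params["actor"]["b"]
    value = c_feat @ params["critic"]["w"] + params["critic"]["b"]
    return mean, jnp.exp(params["log_std"]), value


params = init_actor_critic(jax.random.PRNGKey(0), obs_dim=6, action_dim=2,
                           actor_arch=(32,), critic_arch=(16, 16))
obs = jnp.ones((4, 6))
mean, std, value = apply(params, obs)
# Gradient of the value w.r.t. all parameters: the actor torso gets none.
grads = jax.grad(lambda p: apply(p, obs)[2].sum())(params)
```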

Parameters:
  • observation_space (Sequence[int]) – Shape of the observation space (excluding batch dimension).

  • action_space (ActionSpace) – Shape of the action space.

  • key (nnx.Rngs) – Random number generator(s) for parameter initialization.

  • actor_architecture (Sequence[int]) – Sizes of the hidden layers in the actor torso.

  • critic_architecture (Sequence[int]) – Sizes of the hidden layers in the critic torso.

  • in_scale (float) – Scaling factor for orthogonal initialization of hidden layers.

  • actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.

  • critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.

  • activation (Callable) – Activation function applied between hidden layers.

  • action_space – Bijector used to constrain the policy probability distribution.

actor_torso

Feedforward network for the actor.

Type:

nnx.Sequential

critic_torso

Feedforward network for the critic.

Type:

nnx.Sequential

actor

Linear layer mapping actor features to action means.

Type:

nnx.Linear

critic

Linear layer mapping critic features to a scalar value estimate.

Type:

nnx.Linear

log_std

Learnable log standard deviation for the Gaussian action distribution.

Type:

nnx.Param

bij

Bijector for constraining the action space.

Type:

distrax.Bijector

property log_std: Param
property metadata: Dict
classmethod registry_name() → str
property type_name: str
class jaxdem.rl.models.LSTMActorCritic(*args: Any, **kwargs: Any)

Bases: Model, Module

A recurrent actor–critic with an MLP encoder and an LSTM torso.

This model encodes observations with a small feed-forward network, passes the features through a single-layer LSTM, and decodes the LSTM hidden state with linear policy/value heads.

Calling modes

  • Sequence mode (training): a time-major input x of shape (T, B, obs_dim) produces a distribution and a value for every step: policy outputs of shape (T, B, action_space_size) and values of shape (T, B, 1). The LSTM carry is initialized to zeros for the sequence.

  • Single-step mode (evaluation/rollout): an input x of shape (..., obs_dim) uses and updates a persistent LSTM carry stored on the module (self.h, self.c); outputs have shapes (..., action_space_size) and (..., 1). Call reset() to clear the carry at the start of every new trajectory (for example, between episodes).
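The two calling modes differ mainly in who owns the carry. The sketch covers only the recurrent core (no encoder or heads) in plain JAX: sequence mode scans a zero-initialized carry over a time-major input with jax.lax.scan, while single-step mode threads a carry kept by the caller. The cell is a hand-written textbook LSTM, not flax's OptimizedLSTMCell; the helper names (init_lstm, lstm_step, sequence_mode), the gate ordering, and the initialization are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_lstm(key, in_dim, hidden):
    # Hypothetical parameters: fused input/recurrent weights for the 4 gates.
    k_i, k_h = jax.random.split(key)
    return {
        "wi": 0.1 * jax.random.normal(k_i, (in_dim, 4 * hidden)),
        "wh": 0.1 * jax.random.normal(k_h, (hidden, 4 * hidden)),
        "b": jnp.zeros((4 * hidden,)),
    }


def lstm_step(p, carry, x):
    """One LSTM step. carry = (h, c), each (..., hidden); x is (..., in_dim)."""
    h, c = carry
    z = x @ p["wi"] + h @ p["wh"] + p["b"]
    i, f, g, o = jnp.split(z, 4, axis=-1)  # gate order is an arbitrary choice
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return (h, c), h


def sequence_mode(p, xs):
    """Time-major xs of shape (T, B, in_dim); the carry starts at zeros."""
    batch = xs.shape[1]
    hidden = p["wh"].shape[0]
    carry0 = (jnp.zeros((batch, hidden)), jnp.zeros((batch, hidden)))
    _, hs = jax.lax.scan(lambda carry, x: lstm_step(p, carry, x), carry0, xs)
    return hs  # (T, B, hidden): features for the heads at every step


p = init_lstm(jax.random.PRNGKey(0), in_dim=4, hidden=8)
xs = jax.random.normal(jax.random.PRNGKey(1), (5, 3, 4))
hs = sequence_mode(p, xs)

# Single-step mode: the caller keeps the carry between calls and zeroes it
# at the start of each new trajectory.
carry = (jnp.zeros((3, 8)), jnp.zeros((3, 8)))
steps = []
for t in range(xs.shape[0]):
    carry, h_t = lstm_step(p, carry, xs[t])
    steps.append(h_t)
```

Stepping from a zero carry reproduces the sequence-mode features up to floating-point noise, which is one reason the persistent carry has to be cleared at trajectory boundaries: otherwise rollouts would see state that training sequences never do.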

Parameters:
  • observation_space_size (int) – Flattened observation size (obs_dim).

  • action_space_size (int) – Number of action dimensions.

  • key (nnx.Rngs) – Random number generator(s) for parameter initialization.

  • hidden_features (int) – Width of the encoder output (and LSTM input).

  • lstm_features (int) – LSTM hidden/state size. Also the feature size consumed by the heads.

  • activation (Callable) – Activation function applied inside the encoder.

  • action_space (distrax.Bijector | ActionSpace | None, default=None) – Bijector to constrain the policy probability distribution.

obs_dim

Observation dimensionality expected on the last axis of x.

Type:

int

lstm_features

LSTM hidden/state size.

Type:

int

encoder

MLP that maps obs_dim -> hidden_features.

Type:

nnx.Sequential

cell

LSTM cell with in_features = hidden_features and hidden_features = lstm_features.

Type:

rnn.OptimizedLSTMCell

actor

Linear head mapping LSTM features to action means.

Type:

nnx.Linear

critic

Linear head mapping LSTM features to a scalar value.

Type:

nnx.Linear

log_std

Learnable log standard deviation for the policy.

Type:

nnx.Param

bij

Action-space bijector; scalar bijectors are automatically lifted with Block(ndims=1) for vector actions.

Type:

distrax.Bijector
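Lifting a scalar bijector with Block(ndims=1) treats the last axis as the event axis, so per-dimension log-Jacobian terms are summed into one log-density per action vector. The sketch assumes a tanh bijector purely as an example (the documentation does not say which bijector is used) and writes the resulting log-density out in plain JAX instead of calling distrax; tanh_gaussian_log_prob is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tanh_gaussian_log_prob(u, mean, log_std):
    """Log-density of a = tanh(u), u ~ N(mean, diag(exp(log_std) ** 2)).

    `u` is the pre-squash sample. Each action dimension contributes a Gaussian
    term and a log-Jacobian term; summing over the last axis is what wrapping
    the scalar bijector in Block(ndims=1) amounts to.
    """
    std = jnp.exp(log_std)
    base = -0.5 * ((u - mean) / std) ** 2 - log_std - 0.5 * jnp.log(2.0 * jnp.pi)
    # log|d tanh(u)/du| = log(1 - tanh(u)^2), in a numerically stable form.
    log_det = 2.0 * (jnp.log(2.0) - u - jax.nn.softplus(-2.0 * u))
    return jnp.sum(base - log_det, axis=-1)


mean = jnp.array([0.1, 0.2])
log_std = jnp.array([-0.5, 0.3])
# Pre-squash samples (fixed here for reproducibility), shape (3, 2).
u = jnp.array([[0.3, -0.7], [1.2, 0.0], [-1.5, 0.5]])
action = jnp.tanh(u)                             # each entry lies in (-1, 1)
logp = tanh_gaussian_log_prob(u, mean, log_std)  # one value per action vector
```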

h, c

Persistent LSTM carry used by single-step evaluation. Shapes are (..., lstm_features) and are resized lazily to match the leading batch/agent dimensions.

Type:

nnx.Variable

property log_std: Param
property metadata: Dict
classmethod registry_name() → str
reset(shape: Tuple, mask: Array | None = None)

Reset the persistent LSTM carry.

  • If self.h.value.shape != (*shape, H), where H = lstm_features, allocate fresh zeros once.

  • Otherwise, zero in place:
    • if mask is None: zero everything (without materializing a zeros tensor)

    • if mask is provided: zero only the masked entries (mask may match shape exactly, or have shape (*shape, 1) or (*shape, H))

Parameters:
  • shape (tuple[int, ...]) – Leading dims for the carry, e.g. (num_envs, num_agents).

  • mask (optional bool array) – True where entries should be reset. If mask.shape equals shape, it is expanded across the feature dimension.
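The reset rules can be sketched as a pure function. This is a minimal sketch under stated assumptions: reset_carry, its hidden argument, and the done mask are hypothetical, and because JAX arrays are immutable the sketch returns new arrays rather than updating h and c in place as the module does.

```python
import jax.numpy as jnp


def reset_carry(h, c, shape, hidden, mask=None):
    """Hypothetical functional version of the reset rules described above."""
    target = (*shape, hidden)
    if h.shape != target:
        # Shape mismatch: allocate a fresh zero carry of the requested shape.
        return jnp.zeros(target, h.dtype), jnp.zeros(target, c.dtype)
    if mask is None:
        # No mask: clear every entry.
        return jnp.zeros_like(h), jnp.zeros_like(c)
    mask = jnp.asarray(mask, dtype=bool)
    if mask.shape == tuple(shape):
        mask = mask[..., None]  # expand across the feature dimension
    # Masks of shape (*shape, 1) or (*shape, hidden) broadcast directly.
    return jnp.where(mask, 0.0, h), jnp.where(mask, 0.0, c)


h = jnp.ones((2, 3, 4))
c = jnp.ones((2, 3, 4))
# e.g. (num_envs=2, num_agents=3): reset agent 0 of env 0 and agent 2 of env 1.
done = jnp.array([[True, False, False], [False, False, True]])
h_new, c_new = reset_carry(h, c, shape=(2, 3), hidden=4, mask=done)
```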

property type_name: str

Modules

LSTM

Implementation of reinforcement learning models based on a single-layer LSTM.

MLP

Implementation of reinforcement learning models based on simple MLPs.