jaxdem.rl.models#

Interface for defining reinforcement learning models.

Classes

Model(*args, **kwargs)

The base interface for defining reinforcement learning models.

class jaxdem.rl.models.ActorCritic(*args: Any, **kwargs: Any)[source]#

Bases: Model, Module

An actor-critic model with separate networks for the actor and critic.

Unlike SharedActorCritic, this model uses two independent feedforward networks:

  • Actor torso: processes observations into features for the policy.

  • Critic torso: processes observations into features for the value function.

Parameters:
  • observation_space (int) – Shape of the observation space (excluding batch dimension).

  • action_space (ActionSpace) – Shape of the action space (for continuous) or number of discrete actions.

  • key (nnx.Rngs) – Random number generator(s) for parameter initialization.

  • actor_architecture (Sequence[int]) – Sizes of the hidden layers in the actor torso.

  • critic_architecture (Sequence[int]) – Sizes of the hidden layers in the critic torso.

  • in_scale (float) – Scaling factor for orthogonal initialization of hidden layers.

  • actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.

  • critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.

  • activation (Callable) – Activation function applied between hidden layers.

  • action_space (distrax.Bijector | ActionSpace | None) – Bijector to constrain the policy probability distribution (continuous only).

  • discrete (bool) – If True, use a categorical distribution for discrete actions. If False (default), use a Gaussian distribution for continuous actions.

actor_torso#

Feedforward network for the actor.

Type:

nnx.Sequential

critic#

Feedforward network for the critic.

Type:

nnx.Sequential

actor_mu#

Linear layer mapping actor_torso’s features to the policy distribution means (continuous) or logits (discrete).

Type:

nnx.Linear

actor_sigma#

Maps actor torso features to the policy distribution standard deviations: a learned head when actor_sigma_head is True, otherwise an independent log-std parameter. Only used for continuous actions.

Type:

nnx.Sequential

bij#

Bijector for constraining the action space (continuous only).

Type:

distrax.Bijector

discrete#

Whether this model uses discrete actions.

Type:

bool
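The separate-torso structure can be sketched in plain numpy (shapes only; the layer sizes and initialization here are hypothetical stand-ins, not the jaxdem API):

```python
import numpy as np

rng = np.random.default_rng(0)
obs_dim, act_dim, hidden = 8, 2, 32  # hypothetical sizes

def mlp(sizes):
    """Random dense layers standing in for an initialized torso."""
    return [rng.normal(size=(m, n)) / np.sqrt(m) for m, n in zip(sizes[:-1], sizes[1:])]

def forward(layers, x):
    for w in layers:
        x = np.tanh(x @ w)  # activation between hidden layers
    return x

# Two independent torsos: actor and critic share no parameters.
actor_torso = mlp([obs_dim, hidden, hidden])
critic_torso = mlp([obs_dim, hidden, hidden])
actor_mu = rng.normal(size=(hidden, act_dim)) * 0.01  # small actor_scale head
critic_head = rng.normal(size=(hidden, 1))

obs = rng.normal(size=(16, obs_dim))              # batch of observations
mu = forward(actor_torso, obs) @ actor_mu         # policy means (or logits): (16, 2)
value = forward(critic_torso, obs) @ critic_head  # value estimates: (16, 1)
```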

property metadata: dict[str, Any][source]#

class jaxdem.rl.models.LSTMActorCritic(*args: Any, **kwargs: Any)[source]#

Bases: Model, Module

A recurrent actor–critic with an MLP encoder and an LSTM torso.

This model encodes observations with a small feed-forward network, passes the features through a single-layer LSTM, and decodes the LSTM hidden state with linear policy/value heads.

Calling modes

  • Sequence mode (training): time-major input x with shape (T, B, obs_dim) produces a distribution and value for every step: policy outputs (T, B, action_space_size) and values (T, B, 1). The LSTM carry is initialized to zeros.

  • Single-step mode (evaluation/rollout): input x with shape (..., obs_dim) uses and updates a persistent LSTM carry stored on the module (self.h, self.c); outputs have shape (..., action_space_size) and (..., 1). Use reset() to clear state between episodes.
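The time-major sequence mode can be illustrated with a simplified recurrent update in numpy (a stand-in for the LSTM cell, with hypothetical sizes; only the shapes and the zero carry initialization mirror the description above):

```python
import numpy as np

T, B, obs_dim, H, A = 5, 3, 4, 8, 2  # hypothetical sizes; H plays the role of lstm_features
rng = np.random.default_rng(0)
w_in = rng.normal(size=(obs_dim, H)) / np.sqrt(obs_dim)
w_rec = rng.normal(size=(H, H)) / np.sqrt(H)
w_mu = rng.normal(size=(H, A))
w_val = rng.normal(size=(H, 1))

x = rng.normal(size=(T, B, obs_dim))  # time-major input
h = np.zeros((B, H))                  # carry initialized to zeros, as in sequence mode

mus, values = [], []
for t in range(T):  # the scan over time
    h = np.tanh(x[t] @ w_in + h @ w_rec)  # simplified cell in place of the LSTM
    mus.append(h @ w_mu)
    values.append(h @ w_val)

mu = np.stack(mus)        # (T, B, A): one distribution parameter per step
value = np.stack(values)  # (T, B, 1): one value per step
```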

Parameters:
  • observation_space_size (int) – Flattened observation size (obs_dim).

  • action_space_size (int) – Number of action dimensions (for continuous) or number of discrete actions.

  • key (nnx.Rngs) – Random number generator(s) for parameter initialization.

  • hidden_features (int) – Width of the encoder output (and LSTM input).

  • lstm_features (int) – LSTM hidden/state size. Also the feature size consumed by the policy and value heads.

  • activation (Callable) – Activation function applied inside the encoder.

  • action_space (distrax.Bijector | ActionSpace | None) – Bijector to constrain the policy probability distribution (continuous only).

  • cell_type (type[rnn.OptimizedLSTMCell]) – LSTM cell class used for the recurrent layer.

  • remat (bool) – If True, wrap the LSTM scan body with jax.checkpoint to reduce memory in sequence mode.

  • actor_sigma_head (bool) – If True, the standard deviation is produced by a learned head on the LSTM output; otherwise an independent log-std parameter is used. Only used for continuous actions.

  • carry_leading_shape (tuple[int, ...]) – Leading dimensions for the persistent carry tensors h and c. Typically () at construction; resized lazily at runtime to match the batch shape.

  • discrete (bool) – If True, use a categorical distribution for discrete actions. If False (default), use a Gaussian distribution for continuous actions.

obs_dim#

Observation dimensionality expected on the last axis of x.

Type:

int

lstm_features#

LSTM hidden/state size.

Type:

int

encoder#

MLP that maps obs_dim → hidden_features.

Type:

nnx.Sequential

cell#

LSTM cell with in_features = hidden_features and hidden_features = lstm_features.

Type:

rnn.OptimizedLSTMCell

actor_mu#

Linear layer mapping LSTM features to the policy distribution means (continuous) or logits (discrete).

Type:

nnx.Linear

actor_sigma#

Maps LSTM features to the policy standard deviations (learned head when actor_sigma_head=True, else independent parameter). Only used for continuous actions.

Type:

Callable[[jax.Array], jax.Array]

critic#

Linear head mapping LSTM features to a scalar value.

Type:

nnx.Linear

bij#

Action-space bijector; scalar bijectors are automatically lifted with Block(ndims=1) for vector actions (continuous only).

Type:

distrax.Bijector
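What "lifting with Block(ndims=1)" means can be sketched in numpy for a scalar tanh squashing bijector: the scalar transform applies elementwise, and treating the last axis as event dimensions sums the log-determinants over it (an illustration of the convention, not the distrax implementation):

```python
import numpy as np

def tanh_forward(x):
    return np.tanh(x)

def tanh_log_det(x):
    # log |d tanh(x)/dx| = log(1 - tanh(x)^2)
    return np.log1p(-np.tanh(x) ** 2)

raw = np.array([[0.5, -2.0, 3.0]])  # unconstrained sample from the Gaussian policy
action = tanh_forward(raw)          # elementwise squash into (-1, 1)
# Block(ndims=1)-style lifting: the last axis is event dims,
# so the per-component log-dets are summed over it.
log_det = tanh_log_det(raw).sum(axis=-1)
```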

h, c

Persistent LSTM carry used by single-step evaluation. Shapes are (..., lstm_features) and are resized lazily to match the leading batch/agent dimensions.

Type:

nnx.Variable

discrete#

Whether this model uses discrete actions.

Type:

bool

property metadata: dict[str, Any][source]#

reset(shape: tuple[int, ...], mask: Array | None = None) → None[source]#

Reset the persistent LSTM carry.

  • If self.h.value.shape != (*lead_shape, H), allocate fresh zeros once.

  • Otherwise, zero in-place:
    • if mask is None: zero everything

    • if mask is provided: zero masked entries along axis=0

Parameters:
  • shape (tuple[int, ...]) – Shape of the observation (input) tensor.

  • mask (optional bool array) – Mask per environment (axis=0) to conditionally reset the carry.
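The masked branch of the reset can be sketched in numpy: the carry is zeroed only for the environments flagged along axis 0 (shapes are hypothetical):

```python
import numpy as np

num_envs, H = 4, 8
h = np.ones((num_envs, H))  # pretend carry after some rollout steps
mask = np.array([True, False, True, False])  # which environments finished an episode

# Masked reset: zero the carry rows selected by the mask along axis 0,
# leaving the other environments' state untouched.
h = np.where(mask[:, None], 0.0, h)
```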

class jaxdem.rl.models.Model(*args: Any, **kwargs: Any)#

Bases: Factory, Module, ABC

The base interface for defining reinforcement learning models. Acts as a namespace.

Models map observations to an action distribution and a value estimate.

Example:#

To define a custom model, inherit from Model and implement its abstract methods:

>>> @Model.register("myCustomModel")
... class MyCustomModel(Model):
...     ...
property metadata: dict[str, Any][source]#

reset(shape: tuple[int, ...], mask: Array | None = None) → None[source]#

Reset the model's persistent recurrent carry, if it has one.

Parameters:
  • shape (tuple[int, ...]) – Leading dims for the carry, e.g. (num_envs, num_agents).

  • mask (optional bool array) – True where entries should be reset. Shape: (num_envs,).
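The Factory-style registration shown in the example above can be sketched as a minimal class-decorator registry (a hypothetical stand-in for jaxdem's Factory base, not its actual implementation):

```python
class Factory:
    """Minimal registry: subclasses register under a string key."""
    _registry: dict[str, type] = {}

    @classmethod
    def register(cls, name):
        def deco(subcls):
            cls._registry[name] = subcls
            return subcls
        return deco

    @classmethod
    def create(cls, name, *args, **kwargs):
        return cls._registry[name](*args, **kwargs)

class Model(Factory):
    pass

@Model.register("myCustomModel")
class MyCustomModel(Model):
    pass

# Instances can now be built by name.
model = Model.create("myCustomModel")
```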

class jaxdem.rl.models.SharedActorCritic(*args: Any, **kwargs: Any)[source]#

Bases: Model

A shared-parameter dense actor-critic model.

This model uses a common feedforward network (the “shared torso”) to process observations, and then branches into two separate linear heads:

  • Actor head: outputs the mean of a Gaussian action distribution (continuous) or logits for a categorical distribution (discrete).

  • Critic head: outputs a scalar value estimate of the state.

Parameters:
  • observation_space (int) – Shape of the observation space (excluding batch dimension).

  • action_space (ActionSpace) – Shape of the action space (for continuous) or number of discrete actions.

  • key (nnx.Rngs) – Random number generator(s) for parameter initialization.

  • architecture (Sequence[int]) – Sizes of the hidden layers in the shared network.

  • in_scale (float) – Scaling factor for orthogonal initialization of the shared network layers.

  • actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.

  • critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.

  • activation (Callable) – JIT-compatible activation function applied between hidden layers.

  • action_space (distrax.Bijector | ActionSpace | None) – Bijector to constrain the policy probability distribution (continuous only).

  • discrete (bool) – If True, use a categorical distribution for discrete actions. If False (default), use a Gaussian distribution for continuous actions.
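The shared-torso branching can be sketched in plain numpy (hypothetical sizes and initialization; only the structure, one torso feeding two heads, reflects the class described above):

```python
import numpy as np

rng = np.random.default_rng(0)
obs_dim, hidden, act_dim = 8, 32, 2  # hypothetical sizes

w1 = rng.normal(size=(obs_dim, hidden)) / np.sqrt(obs_dim)
w2 = rng.normal(size=(hidden, hidden)) / np.sqrt(hidden)
w_mu = rng.normal(size=(hidden, act_dim)) * 0.01  # small actor_scale head
w_val = rng.normal(size=(hidden, 1))              # critic head

obs = rng.normal(size=(16, obs_dim))
feat = np.tanh(np.tanh(obs @ w1) @ w2)  # one shared torso computes the features ...
mu = feat @ w_mu                        # ... the actor head reads them ...
value = feat @ w_val                    # ... and so does the critic head
```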

network#

The shared feedforward network (torso).

Type:

nnx.Sequential

actor_mu#

Linear layer mapping shared features to the policy distribution means (continuous) or logits (discrete).

Type:

nnx.Linear

actor_sigma#

Maps shared features to the policy distribution standard deviations: a learned head when actor_sigma_head is True, otherwise an independent log-std parameter. Only used for continuous actions.

Type:

nnx.Sequential

critic#

Linear layer mapping shared features to the value estimate.

Type:

nnx.Linear

bij#

Bijector for constraining the action space (continuous only).

Type:

distrax.Bijector

discrete#

Whether this model uses discrete actions.

Type:

bool

property metadata: dict[str, Any][source]#

Modules

LSTM

Implementation of reinforcement learning models based on a single layer LSTM.

MLP

Implementation of reinforcement learning models based on simple MLPs.