jaxdem.rl.models.LSTM#

Implementation of reinforcement learning models based on a single-layer LSTM.

Classes

LSTMActorCritic(*args, **kwargs)

A recurrent actor–critic with an MLP encoder and an LSTM torso.

class jaxdem.rl.models.LSTM.LSTMActorCritic(*args: Any, **kwargs: Any)[source]#

Bases: Model, Module

A recurrent actor–critic with an MLP encoder and an LSTM torso.

This model encodes observations with a small feed-forward network, passes the features through a single-layer LSTM, and decodes the LSTM hidden state with linear policy/value heads.
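
Conceptually, the per-step data flow in terms of the attributes documented below (a sketch only; the exact call convention inside the module is an assumption):

    # x                : (..., obs_dim)
    # f = encoder(x)   : (..., hidden_features)
    # carry, y = cell(carry, f)    # y: (..., lstm_features)
    # mu    = actor_mu(y)          # (..., action_space_size)
    # sigma = actor_sigma(y)       # (..., action_space_size)
    # value = critic(y)            # (..., 1)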

Calling modes

  • Sequence mode (training): time-major input x with shape (T, B, obs_dim) produces a distribution and value for every step: policy outputs (T, B, action_space_size) and values (T, B, 1). The LSTM carry is initialized to zeros for the sequence.

  • Single-step mode (evaluation/rollout): input x with shape (..., obs_dim) uses and updates a persistent LSTM carry stored on the module (self.h, self.c); outputs have shape (..., action_space_size) and (..., 1). Use reset() to clear the carry at the start of every new trajectory; see the sketch after this list.
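
A minimal usage sketch of both modes, assuming the model call returns a (distribution, value) pair as described above; the sizes are illustrative:

    import jax
    import jax.numpy as jnp
    from flax import nnx

    from jaxdem.rl.models.LSTM import LSTMActorCritic

    T, B, obs_dim, act_dim = 16, 8, 12, 3  # illustrative sizes

    model = LSTMActorCritic(
        observation_space_size=obs_dim,
        action_space_size=act_dim,
        key=nnx.Rngs(0),
    )

    # Sequence mode: time-major (T, B, obs_dim); the carry starts at zeros.
    dist, values = model(jnp.zeros((T, B, obs_dim)))  # values: (T, B, 1)

    # Single-step mode: (..., obs_dim); the persistent carry is used and updated.
    model.reset(shape=(B, obs_dim))                   # new trajectory
    dist, value = model(jnp.zeros((B, obs_dim)))      # value: (B, 1)
    action = dist.sample(seed=jax.random.PRNGKey(0))  # (B, act_dim)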

Parameters:
  • observation_space_size (int) – Flattened observation size (obs_dim).

  • action_space_size (int) – Number of action dimensions.

  • key (nnx.Rngs) – Random number generator(s) for parameter initialization.

  • hidden_features (int) – Width of the encoder output (and LSTM input).

  • lstm_features (int) – LSTM hidden/state size. Also the feature size consumed by the heads.

  • activation (Callable) – Activation function applied inside the encoder.

  • action_space (distrax.Bijector | ActionSpace | None, default=None) – Bijector used to constrain the support of the policy's probability distribution (e.g., squashing actions to a bounded range).
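
A construction sketch spelling out the optional arguments; the hyperparameter values below are illustrative choices, not documented defaults:

    import distrax
    import jax
    from flax import nnx

    from jaxdem.rl.models.LSTM import LSTMActorCritic

    model = LSTMActorCritic(
        observation_space_size=12,
        action_space_size=3,
        key=nnx.Rngs(0),
        hidden_features=64,           # encoder output width = LSTM input width
        lstm_features=128,            # LSTM state size consumed by the heads
        activation=jax.nn.tanh,       # applied inside the encoder
        action_space=distrax.Tanh(),  # scalar bijector, lifted for vector actions
    )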

obs_dim#

Observation dimensionality expected on the last axis of x.

Type:

int

lstm_features#

LSTM hidden/state size.

Type:

int

encoder#

MLP that maps obs_dim -> hidden_features.

Type:

nnx.Sequential

cell#

LSTM cell with in_features = hidden_features and hidden_features = lstm_features.

Type:

rnn.OptimizedLSTMCell

actor_mu#

Linear layer mapping LSTM features to the policy distribution means.

Type:

nnx.Linear

actor_sigma#

Linear layer mapping LSTM features to the policy distribution's standard deviations when actor_sigma_head is true; otherwise, an independent learned parameter.

Type:

nnx.Sequential

critic#

Linear head mapping LSTM features to a scalar value.

Type:

nnx.Linear

bij#

Action-space bijector; scalar bijectors are automatically lifted with Block(ndims=1) for vector actions.

Type:

distrax.Bijector
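
In distrax terms, the automatic lifting amounts to the following (a sketch):

    import distrax

    scalar_bij = distrax.Tanh()                      # acts elementwise
    vector_bij = distrax.Block(scalar_bij, ndims=1)
    # Block(ndims=1) treats the trailing action axis as a single event, so
    # log-determinants are summed over that axis when transforming the policy.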

h, c#

Persistent LSTM carry used by single-step evaluation. Shapes are (..., lstm_features) and are resized lazily to match the leading batch/agent dimensions.

Type:

nnx.Variable

property metadata: Dict[source]#
reset(shape: Tuple, mask: Array | None = None)[source]#

Reset the persistent LSTM carry.

  • If self.h.value.shape != (*lead_shape, H), where lead_shape are the leading batch/agent dimensions of shape and H = lstm_features, allocate fresh zeros once.

  • Otherwise, zero in-place:
    • if mask is None: zero everything;

    • if mask is provided: zero only the masked entries along axis=0.

Parameters:
  • shape (tuple[int, ...]) – Shape of the observation (input) tensor.

  • mask (optional bool array) – Per-environment boolean mask (axis=0); only the masked entries have their carry zeroed.
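
A rollout-loop sketch of the masked reset, reusing the model constructed above; the done flags are assumed to come from the environment step:

    import jax.numpy as jnp

    num_envs, obs_dim = 8, 12
    model.reset(shape=(num_envs, obs_dim))       # full reset before the rollout

    # ... step the environments, collecting per-env done flags ...
    done = jnp.zeros((num_envs,), dtype=bool)    # placeholder flags

    # Zero the carry only for environments whose episode just ended.
    model.reset(shape=(num_envs, obs_dim), mask=done)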