jaxdem.rl.models.LSTM
Implementation of reinforcement learning models based on a single layer LSTM.
Classes
LSTMActorCritic – A recurrent actor–critic with an MLP encoder and an LSTM torso.
- class jaxdem.rl.models.LSTM.LSTMActorCritic(*args: Any, **kwargs: Any)
Bases: Model, Module
A recurrent actor–critic with an MLP encoder and an LSTM torso.
This model encodes observations with a small feed-forward network, passes the features through a single-layer LSTM, and decodes the LSTM hidden state with linear policy/value heads.
Calling modes
- Sequence mode (training): time-major input x with shape (T, B, obs_dim) produces a distribution and a value for every step: policy outputs of shape (T, B, action_space_size) and values of shape (T, B, 1). The LSTM carry is initialized to zeros at the start of the sequence.
- Single-step mode (evaluation/rollout): input x with shape (..., obs_dim) uses and updates a persistent LSTM carry stored on the module (self.h, self.c); outputs have shapes (..., action_space_size) and (..., 1). Use reset_carry() to clear the carry between episodes; it must be reset at the start of every new trajectory.
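The two calling modes can be illustrated with a minimal sketch in plain JAX. It does not use the jaxdem or Flax classes: the parameter layout, the tanh encoder activation, the sizes, and the helper names (init_params, lstm_step, sequence_mode, single_step) are all assumptions made for illustration. It shows that scanning over a time-major sequence from a zero carry gives the same outputs as stepping one observation at a time while keeping the carry.

```python
import jax
import jax.numpy as jnp

OBS_DIM, HIDDEN, LSTM_F, ACT_DIM = 6, 8, 5, 3  # illustrative sizes, not library defaults


def init_params(key):
    """Random weights for a toy encoder -> LSTM -> heads stack (hypothetical layout)."""
    k = jax.random.split(key, 5)
    s = 0.3
    return {
        "enc": s * jax.random.normal(k[0], (OBS_DIM, HIDDEN)),
        "w_x": s * jax.random.normal(k[1], (HIDDEN, 4 * LSTM_F)),
        "w_h": s * jax.random.normal(k[2], (LSTM_F, 4 * LSTM_F)),
        "b": jnp.zeros(4 * LSTM_F),
        "actor": s * jax.random.normal(k[3], (LSTM_F, ACT_DIM)),
        "critic": s * jax.random.normal(k[4], (LSTM_F, 1)),
    }


def lstm_step(p, carry, feat):
    """One standard LSTM update; carry is (h, c), each of shape (..., LSTM_F)."""
    h, c = carry
    gates = feat @ p["w_x"] + h @ p["w_h"] + p["b"]
    i, f, g, o = jnp.split(gates, 4, axis=-1)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return (h, c), h


def heads(p, h):
    """Linear actor and critic heads applied to the LSTM hidden state."""
    return h @ p["actor"], h @ p["critic"]


def sequence_mode(p, x):
    """x: (T, B, OBS_DIM). The carry starts at zero; scan runs over the time axis."""
    feats = jnp.tanh(x @ p["enc"])
    zeros = jnp.zeros((x.shape[1], LSTM_F))
    _, hs = jax.lax.scan(lambda carry, f: lstm_step(p, carry, f), (zeros, zeros), feats)
    return heads(p, hs)


def single_step(p, carry, x):
    """x: (..., OBS_DIM). The caller keeps the carry between calls."""
    carry, h = lstm_step(p, carry, jnp.tanh(x @ p["enc"]))
    return carry, heads(p, h)


params = init_params(jax.random.PRNGKey(0))
T, B = 4, 2
x = jax.random.normal(jax.random.PRNGKey(1), (T, B, OBS_DIM))

# Sequence mode: one call over the whole time-major batch.
seq_means, seq_values = sequence_mode(params, x)

# Single-step mode: same inputs fed one step at a time with a persistent carry.
carry = (jnp.zeros((B, LSTM_F)), jnp.zeros((B, LSTM_F)))
step_means = []
for t in range(T):
    carry, (m, v) = single_step(params, carry, x[t])
    step_means.append(m)
step_means = jnp.stack(step_means)
```

In the documented class the carry lives on the module rather than being passed explicitly; the sketch threads it by hand only to keep the example functional and self-contained.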
- Parameters:
observation_space_size (int) – Flattened observation size (obs_dim).
action_space_size (int) – Number of action dimensions.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
hidden_features (int) – Width of the encoder output (and LSTM input).
lstm_features (int) – LSTM hidden/state size. Also the feature size consumed by the heads.
activation (Callable) – Activation function applied inside the encoder.
action_space (distrax.Bijector | ActionSpace | None, default=None) – Bijector to constrain the policy probability distribution.
- obs_dim
Observation dimensionality expected on the last axis of x.
Type: int
- lstm_features
LSTM hidden/state size.
Type: int
- encoder
MLP that maps obs_dim -> hidden_features.
Type: nnx.Sequential
- cell
LSTM cell with in_features = hidden_features and hidden_features = lstm_features.
Type: rnn.OptimizedLSTMCell
- actor
Linear head mapping LSTM features to action means.
Type: nnx.Linear
- critic
Linear head mapping LSTM features to a scalar value.
Type: nnx.Linear
- bij
Action-space bijector; scalar bijectors are automatically lifted with Block(ndims=1) for vector actions.
Type: distrax.Bijector
- h, c
Persistent LSTM carry used by single-step evaluation. Shapes are (..., lstm_features) and are resized lazily to match the leading batch/agent dimensions.
Type: nnx.Variable
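To make the role of the action-space bijector concrete, here is a minimal sketch in plain JAX of what a scalar bijector lifted over a vector action does to the policy log-probability. The choice of tanh as the bijector, the Gaussian base distribution, and the helper names are assumptions for illustration; the source does not state which bijector jaxdem uses. Treating the last axis as a single event means the per-dimension log-determinant terms are summed over that axis.

```python
import jax
import jax.numpy as jnp


def gaussian_log_prob(x, mean, std):
    """Per-dimension log-density of an independent Normal(mean, std)."""
    z = (x - mean) / std
    return -0.5 * z**2 - jnp.log(std) - 0.5 * jnp.log(2.0 * jnp.pi)


def tanh_forward_log_det(x):
    """log|d tanh(x)/dx| = log(1 - tanh(x)^2), written in a numerically stable form."""
    return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))


def squashed_log_prob(x, mean, std):
    """Log-prob of y = tanh(x) as one vector event: sum over the last axis."""
    per_dim = gaussian_log_prob(x, mean, std) - tanh_forward_log_det(x)
    return per_dim.sum(axis=-1)


mean = jnp.zeros((2, 3))  # (batch, action_dim), illustrative values
std = jnp.ones((2, 3))
x = jax.random.normal(jax.random.PRNGKey(0), (2, 3))  # unconstrained sample
action = jnp.tanh(x)  # constrained to (-1, 1)
logp = squashed_log_prob(x, mean, std)  # one scalar per batch element
```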
- reset(shape: Tuple, mask: Array | None = None)
Reset the persistent LSTM carry.
- If self.h.value.shape != (*lead_shape, H), where H is the LSTM feature size (lstm_features), allocate fresh zeros once.
- Otherwise, zero in-place:
  - if mask is None: zero everything (without materializing a zeros tensor)
  - if mask is provided: zero only masked entries (mask may have shape lead_shape, (*lead_shape, 1), or (*lead_shape, H))
- Parameters:
lead_shape (tuple[int, ...]) – Leading dims for the carry, e.g. (num_envs, num_agents).
mask (optional bool array) – True where entries should be reset. If its shape is lead_shape, it is expanded across the feature dimension.
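The reset rules can be sketched as a pure function in JAX. This is a functional approximation under stated assumptions, not the method's actual code: the name reset_carry_sketch is hypothetical, it returns a new array instead of mutating module state, and it does not reproduce the "without materializing a zeros tensor" optimization mentioned above.

```python
import jax.numpy as jnp


def reset_carry_sketch(h, lead_shape, mask=None):
    """Return a reset copy of carry h, whose shape is (*lead_shape, H) when sizes match."""
    feat = h.shape[-1]
    target = (*lead_shape, feat)
    if h.shape != target:
        # Leading dims changed: allocate a fresh zero carry of the new shape.
        return jnp.zeros(target, dtype=h.dtype)
    if mask is None:
        # No mask: zero every entry.
        return jnp.zeros_like(h)
    mask = jnp.asarray(mask, dtype=bool)
    if mask.ndim == len(lead_shape):
        # Mask of shape lead_shape: add a trailing axis so it broadcasts over features.
        mask = mask[..., None]
    # Masks of shape (*lead_shape, 1) or (*lead_shape, H) broadcast directly.
    return jnp.where(mask, 0.0, h)


h = jnp.ones((2, 3, 4))  # (num_envs, num_agents, H), illustrative sizes
done = jnp.array([[True, False, False], [False, False, True]])
h_masked = reset_carry_sketch(h, (2, 3), done)  # only finished agents are zeroed
h_all = reset_carry_sketch(h, (2, 3))  # everything zeroed
h_resized = reset_carry_sketch(h, (5,))  # shape mismatch: fresh zeros of shape (5, 4)
```

A typical rollout loop would pass the environment's per-agent done flags as the mask so that only agents starting a new trajectory lose their recurrent state.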