jaxdem.rl.models.LSTM#
Implementation of reinforcement learning models based on a single-layer LSTM.
Classes

LSTMActorCritic(*args, **kwargs) | A recurrent actor–critic with an MLP encoder and an LSTM torso.
- class jaxdem.rl.models.LSTM.LSTMActorCritic(*args: Any, **kwargs: Any)[source]#
Bases: Model, Module
A recurrent actor–critic with an MLP encoder and an LSTM torso.
This model encodes observations with a small feed-forward network, passes the features through a single-layer LSTM, and decodes the LSTM hidden state with linear policy/value heads.
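Conceptually, the per-step dataflow looks like the following schematic (written with the attribute names documented below; this is an illustration of the shapes, not the literal implementation):

```python
# x: (..., obs_dim)
# feats = encoder(x)                        # (..., hidden_features)
# (h, c), y = cell((h, c), feats)           # carry and output: (..., lstm_features)
# mu, sigma = actor_mu(y), actor_sigma(y)   # (..., action_space_size)
# value = critic(y)                         # (..., 1)
```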
Calling modes
- Sequence mode (training): time-major input x with shape (T, B, obs_dim) produces a distribution and value for every step: policy outputs (T, B, action_space_size) and values (T, B, 1). The LSTM carry is initialized to zeros for the sequence.
- Single-step mode (evaluation/rollout): input x with shape (..., obs_dim) uses and updates a persistent LSTM carry stored on the module (self.h, self.c); outputs have shape (..., action_space_size) and (..., 1). Use reset() to clear state between episodes; the carry must be reset at the start of every new trajectory. See the usage sketch after the parameter list.
- Parameters:
observation_space_size (int) – Flattened observation size (obs_dim).
action_space_size (int) – Number of action dimensions.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
hidden_features (int) – Width of the encoder output (and LSTM input).
lstm_features (int) – LSTM hidden/state size; also the feature size consumed by the heads.
activation (Callable) – Activation function applied inside the encoder.
action_space (distrax.Bijector | ActionSpace | None, default=None) – Bijector to constrain the policy probability distribution.
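A minimal usage sketch of both calling modes. The keyword names follow the parameter list above; the widths, the tanh bijector, and the assumption that calling the model returns a (distribution, value) pair are illustrative, not guaranteed by this page:

```python
import jax
import jax.numpy as jnp
import distrax
from flax import nnx

from jaxdem.rl.models.LSTM import LSTMActorCritic

T, B, obs_dim, act_dim = 8, 4, 6, 2  # illustrative sizes

model = LSTMActorCritic(
    observation_space_size=obs_dim,
    action_space_size=act_dim,
    key=nnx.Rngs(0),
    hidden_features=64,
    lstm_features=64,
    activation=jax.nn.tanh,
    action_space=distrax.Tanh(),  # scalar bijector, lifted with Block(ndims=1) per `bij` below
)

# Sequence mode: time-major input; the carry starts at zeros for the sequence.
dist, value = model(jnp.zeros((T, B, obs_dim)))  # policy over (T, B, act_dim); value (T, B, 1)

# Single-step mode: uses the persistent carry (self.h, self.c) on the module.
model.reset((B, obs_dim))                        # clear the carry before a new trajectory
dist, value = model(jnp.zeros((B, obs_dim)))     # policy over (B, act_dim); value (B, 1)
```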
- obs_dim#
Observation dimensionality expected on the last axis of x.
- Type: int
- lstm_features#
LSTM hidden/state size.
- Type: int
- encoder#
MLP that maps obs_dim -> hidden_features.
- Type: nnx.Sequential
- cell#
LSTM cell with in_features = hidden_features and hidden_features = lstm_features.
- Type: rnn.OptimizedLSTMCell
- actor_mu#
Linear layer mapping LSTM features to the policy distribution means.
- Type: nnx.Linear
- actor_sigma#
Linear layer mapping LSTM features to the policy distribution standard deviations when actor_sigma_head is true; otherwise an independent learned parameter.
- Type: nnx.Sequential
- critic#
Linear head mapping LSTM features to a scalar value.
- Type: nnx.Linear
- bij#
Action-space bijector; scalar bijectors are automatically lifted with Block(ndims=1) for vector actions.
- Type: distrax.Bijector
- h, c#
Persistent LSTM carry used by single-step evaluation. Shapes are (..., lstm_features) and are resized lazily to match the leading batch/agent dimensions.
- Type: nnx.Variable
- reset(shape: Tuple, mask: Array | None = None)[source]#
Reset the persistent LSTM carry.
- If self.h.value.shape != (*lead_shape, H), allocate fresh zeros once.
- Otherwise, zero in-place:
  - if mask is None, zero everything;
  - if mask is provided, zero the masked entries along axis=0.
- Parameters:
shape (tuple[int, ...]) – Shape of the observation (input) tensor.
mask (optional bool array) – Per-environment mask (along axis=0) to conditionally reset the carry.
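Continuing the sketch above, a vectorized rollout would typically reset only the environments whose episode just ended (done is a hypothetical per-environment flag array, not part of this API):

```python
import jax.numpy as jnp

num_envs, obs_dim = 4, 6
obs = jnp.zeros((num_envs, obs_dim))
done = jnp.array([True, False, False, True])  # hypothetical termination flags

# Zero the carry only for the terminated environments (axis=0); others keep state.
model.reset(obs.shape, mask=done)
```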