jaxdem.rl.models.LSTM#
Implementation of reinforcement learning models based on a single layer LSTM.
Classes

LSTMActorCritic — A recurrent actor–critic with an MLP encoder and an LSTM torso.
- class jaxdem.rl.models.LSTM.LSTMActorCritic(*args: Any, **kwargs: Any)[source]#
Bases: Model, Module

A recurrent actor–critic with an MLP encoder and an LSTM torso.
This model encodes observations with a small feed-forward network, passes the features through a single-layer LSTM, and decodes the LSTM hidden state with linear policy/value heads.
Calling modes

- Sequence mode (training): time-major input x with shape (T, B, obs_dim) produces a distribution and value for every step: policy outputs (T, B, action_space_size) and values (T, B, 1). The LSTM carry is initialized to zeros.
- Single-step mode (evaluation/rollout): input x with shape (..., obs_dim) uses and updates a persistent LSTM carry stored on the module (self.h, self.c); outputs have shape (..., action_space_size) and (..., 1). Use reset() to clear state between episodes.
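The sequence-mode semantics (time-major scan with a zero-initialized carry) can be sketched in plain JAX. This is illustrative only: the cell below is a hand-rolled LSTM, not jaxdem's rnn.OptimizedLSTMCell, and the parameter layout is an assumption made for the example.

```python
import jax
import jax.numpy as jnp

def lstm_step(params, carry, x_t):
    """One LSTM step: carry = (h, c); x_t has shape (B, in_features)."""
    h, c = carry
    W, U, b = params                        # input, recurrent, bias weights
    z = x_t @ W + h @ U + b                 # (B, 4 * H): gate pre-activations
    i, f, g, o = jnp.split(z, 4, axis=-1)   # input, forget, cell, output gates
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return (h, c), h

def sequence_mode(params, x):
    """x: (T, B, in_features) -> hidden states (T, B, H); carry starts at zero."""
    T, B, _ = x.shape
    H = params[2].shape[-1] // 4
    carry0 = (jnp.zeros((B, H)), jnp.zeros((B, H)))  # zero-initialized carry
    _, hs = jax.lax.scan(lambda carry, x_t: lstm_step(params, carry, x_t),
                         carry0, x)                  # scan over the time axis
    return hs

key = jax.random.PRNGKey(0)
T, B, D, H = 5, 3, 4, 8
params = (jax.random.normal(key, (D, 4 * H)) * 0.1,
          jax.random.normal(key, (H, 4 * H)) * 0.1,
          jnp.zeros((4 * H,)))
hs = sequence_mode(params, jnp.ones((T, B, D)))
print(hs.shape)  # (5, 3, 8): one hidden state per time step, as in sequence mode
```

In the real model these per-step hidden states would then be fed to the policy and value heads, giving the (T, B, action_space_size) and (T, B, 1) outputs described above.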
- Parameters:
observation_space_size (int) – Flattened observation size (obs_dim).
action_space_size (int) – Number of action dimensions (for continuous actions) or the number of discrete actions.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
hidden_features (int) – Width of the encoder output (and LSTM input).
lstm_features (int) – LSTM hidden/state size. Also the feature size consumed by the policy and value heads.
activation (Callable) – Activation function applied inside the encoder.
action_space (distrax.Bijector | ActionSpace | None) – Bijector to constrain the policy probability distribution (continuous actions only).
cell_type (type[rnn.OptimizedLSTMCell]) – LSTM cell class used for the recurrent layer.
remat (bool) – If True, wrap the LSTM scan body with jax.checkpoint to reduce memory in sequence mode.
actor_sigma_head (bool) – If True, the standard deviation is produced by a learned head on the LSTM output; otherwise an independent log-std parameter is used. Only used for continuous actions.
carry_leading_shape (tuple[int, ...]) – Leading dimensions for the persistent carry tensors h and c. Typically () at construction; resized lazily at runtime to match the batch shape.
discrete (bool) – If True, use a categorical distribution for discrete actions. If False (default), use a Gaussian distribution for continuous actions.
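The actor_sigma_head flag above distinguishes two standard-deviation parameterizations for continuous actions. The sketch below contrasts them in plain JAX; the exact transforms (softplus of a head output, exp of a log-std) are assumptions for illustration, and only the head-versus-shared-parameter distinction comes from the docs.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
B, lstm_features, action_dim = 4, 8, 2
feats = jax.random.normal(key, (B, lstm_features))   # stand-in LSTM outputs

# actor_sigma_head=True: sigma comes from a learned head on the LSTM
# output, so it can vary per input.
W_sigma = jax.random.normal(key, (lstm_features, action_dim)) * 0.1
sigma_head = jax.nn.softplus(feats @ W_sigma)        # (B, action_dim)

# actor_sigma_head=False: a single independent log-std parameter shared
# across inputs, broadcast over the batch.
log_std = jnp.zeros((action_dim,))
sigma_param = jnp.broadcast_to(jnp.exp(log_std), (B, action_dim))

print(sigma_head.shape, sigma_param.shape)  # both (4, 2)
```

The learned head lets the policy be heteroscedastic (state-dependent exploration), while the shared parameter keeps a single global exploration scale.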
- obs_dim#
Observation dimensionality expected on the last axis of x.
- Type:
int
- lstm_features#
LSTM hidden/state size.
- Type:
int
- encoder#
MLP that maps obs_dim → hidden_features.
- Type:
nnx.Sequential
- cell#
LSTM cell with in_features = hidden_features and hidden_features = lstm_features.
- Type:
rnn.OptimizedLSTMCell
- actor_mu#
Linear layer mapping LSTM features to the policy distribution means (continuous) or logits (discrete).
- Type:
nnx.Linear
- actor_sigma#
Maps LSTM features to the policy standard deviations (a learned head when actor_sigma_head=True, otherwise an independent parameter). Only used for continuous actions.
- Type:
Callable[[jax.Array], jax.Array]
- critic#
Linear head mapping LSTM features to a scalar value.
- Type:
nnx.Linear
- bij#
Action-space bijector; scalar bijectors are automatically lifted with Block(ndims=1) for vector actions (continuous only).
- Type:
distrax.Bijector
- h, c#
Persistent LSTM carry used by single-step evaluation. Shapes are (..., lstm_features) and are resized lazily to match the leading batch/agent dimensions.
- Type:
nnx.Variable
- discrete#
Whether this model uses discrete actions.
- Type:
bool
- reset(shape: tuple[int, ...], mask: Array | None = None) None[source]#
Reset the persistent LSTM carry.
- If self.h.value.shape != (*lead_shape, H): allocate fresh zeros once.
- Otherwise, zero in-place:
  - if mask is None: zero everything
  - if mask is provided: zero the masked entries along axis=0
- Parameters:
shape (tuple[int, ...]) – Shape of the observation (input) tensor.
mask (optional bool array) – Mask per environment (axis=0) to conditionally reset the carry.
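The masked-reset semantics can be sketched with a single jnp.where: zero the carry rows for the environments whose mask entry is True and keep the rest. This mirrors the reset(shape, mask) behaviour described above but is not jaxdem's implementation.

```python
import jax.numpy as jnp

num_envs, lstm_features = 4, 3
h = jnp.ones((num_envs, lstm_features))       # persistent carry (h; c is analogous)
mask = jnp.array([True, False, True, False])  # environments that just ended

# Broadcast the per-env mask over the feature axis and zero the masked rows,
# leaving the remaining environments' recurrent state intact.
h_reset = jnp.where(mask[:, None], 0.0, h)
print(h_reset)
```

Resetting per environment this way lets a vectorized rollout clear recurrent state only for episodes that terminated, without disturbing in-flight episodes in the same batch.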