# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Implementation of reinforcement learning models based on a single-layer LSTM.
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
from typing import Tuple, Callable, Dict, cast
from functools import partial
from flax import nnx
import flax.nnx.nn.recurrent as rnn
import distrax
from . import Model
from ..actionSpaces import ActionSpace
from ...utils import encode_callable
@Model.register("LSTMActorCritic")
class LSTMActorCritic(Model, nnx.Module):
"""
A recurrent actor–critic with an MLP encoder and an LSTM torso.
This model encodes observations with a small feed-forward network, passes the
features through a single-layer LSTM, and decodes the LSTM hidden state with
linear policy/value heads.
**Calling modes**
- **Sequence mode (training)**: time-major input ``x`` with shape
``(T, B, obs_dim)`` produces a distribution and value for **every** step:
policy outputs ``(T, B, action_space_size)`` and values ``(T, B, 1)``.
The LSTM carry is initialized to zeros for the sequence.
- **Single-step mode (evaluation/rollout)**: input ``x`` with shape
``(..., obs_dim)`` uses and updates a persistent LSTM carry stored on the
module (``self.h``, ``self.c``); outputs have shape
``(..., action_space_size)`` and ``(..., 1)``. Use :meth:`reset` to
clear state between episodes; the carry must be reset at the start of every
new trajectory (see the Examples section below).
Parameters
----------
observation_space_size : int
Flattened observation size (``obs_dim``).
action_space_size : int
Number of action dimensions.
key : nnx.Rngs
Random number generator(s) for parameter initialization.
hidden_features : int
Width of the encoder output (and LSTM input).
lstm_features : int
LSTM hidden/state size. Also the feature size consumed by the heads.
dropout_rate : float, default=0.1
Dropout probability. The dropout layer is currently disabled (commented out), so the value is only stored in :attr:`metadata`.
activation : Callable
Activation function applied inside the encoder.
action_space : distrax.Bijector | ActionSpace | None, default=None
Bijector to constrain the policy probability distribution.
cell_type : type, default=rnn.OptimizedLSTMCell
Recurrent cell class for the torso, constructed with ``in_features=hidden_features`` and ``hidden_features=lstm_features``.
remat : bool, default=False
If True, wrap the cell call in ``nnx.remat`` (gradient checkpointing) to trade recomputation for memory.
actor_sigma_head : bool, default=False
If True, predict state-dependent standard deviations with a linear + softplus head; otherwise use a single learned, state-independent log-std parameter.
carry_leading_shape : tuple[int, ...], default=()
Leading (batch/agent) shape used to initialize the persistent carry ``h``/``c``.
Attributes
----------
obs_dim : int
Observation dimensionality expected on the last axis of ``x``.
lstm_features : int
LSTM hidden/state size.
encoder : nnx.Sequential
MLP that maps ``obs_dim -> hidden_features``.
cell : rnn.OptimizedLSTMCell
LSTM cell with ``in_features = hidden_features`` and ``hidden_features = lstm_features``.
actor_mu : nnx.Linear
Linear layer mapping LSTM features to the policy distribution means.
actor_sigma : Callable
Maps LSTM features to the policy standard deviations via a linear + softplus head when ``actor_sigma_head`` is True; otherwise returns a state-independent learned standard deviation (``exp`` of the ``_log_std`` parameter).
critic : nnx.Linear
Linear head mapping LSTM features to a scalar value.
bij : distrax.Bijector
Action-space bijector; scalar bijectors are automatically lifted with ``Block(ndims=1)`` for vector actions.
h, c : nnx.Variable
Persistent LSTM carry used by single-step evaluation. Shapes are
``(..., lstm_features)`` and are resized lazily to match the leading batch/agent dimensions.
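Examples
--------
A minimal usage sketch; the observation size, batch shape, and RNG seed below are
illustrative assumptions rather than fixed by the API:
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> model = LSTMActorCritic(observation_space_size=6, action_space_size=2, key=nnx.Rngs(0))
>>> obs = jnp.zeros((4, 6))                      # 4 agents, one step each
>>> model.reset(obs.shape)                       # clear the persistent carry
>>> pi, value = model(obs)                       # value has shape (4, 1)
>>> seq = jnp.zeros((16, 4, 6))                  # (T, B, obs_dim), time-major
>>> pi_seq, values = model(seq, sequence=True)   # values have shape (16, 4, 1)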
"""
__slots__ = ()
def __init__(
self,
observation_space_size: int,
action_space_size: int,
key: nnx.Rngs,
hidden_features: int = 64,
lstm_features: int = 128,
dropout_rate: float = 0.1,
activation: Callable = nnx.gelu,
action_space: distrax.Bijector | ActionSpace | None = None,
cell_type=rnn.OptimizedLSTMCell,
remat: bool = False,
actor_sigma_head: bool = False,
carry_leading_shape: Tuple[int, ...] = (),
):
super().__init__()
self.obs_dim = int(observation_space_size)
self.action_space_size = int(action_space_size)
self.hidden_features = int(hidden_features)
self.lstm_features = int(lstm_features)
self.dropout_rate = float(dropout_rate)
self.activation = activation
self.remat = remat
self.cell_type = cell_type
self.actor_sigma_head = actor_sigma_head
self.encoder = nnx.Sequential(
nnx.Linear(
self.obs_dim,
self.hidden_features,
kernel_init=nnx.initializers.orthogonal(jnp.sqrt(2.0)),
bias_init=nnx.initializers.constant(0.0),
rngs=key,
),
self.activation,
)
self.cell = self.cell_type(
in_features=self.hidden_features,
hidden_features=self.lstm_features,
rngs=key,
)
if self.remat:
self.cell.__call__ = nnx.remat(self.cell.__call__)
self.rnn = rnn.RNN(self.cell)
self.actor_mu = nnx.Linear(
in_features=self.lstm_features,
out_features=self.action_space_size,
kernel_init=nnx.initializers.orthogonal(1.0),
bias_init=nnx.initializers.constant(0.0),
rngs=key,
)
self._log_std = nnx.Param(jnp.zeros((1, self.action_space_size)))
self._actor_sigma = nnx.Sequential(
nnx.Linear(
in_features=self.lstm_features,
out_features=self.action_space_size,
kernel_init=nnx.initializers.orthogonal(0.01),
bias_init=nnx.initializers.constant(-1.0),
rngs=key,
),
jax.nn.softplus,
)
if self.actor_sigma_head:
self.actor_sigma = lambda x: self._actor_sigma(x)
else:
self.actor_sigma = lambda x: jnp.exp(self._log_std.value)
self.critic = nnx.Linear(
in_features=self.lstm_features,
out_features=1,
kernel_init=nnx.initializers.orthogonal(0.01),
bias_init=nnx.initializers.constant(0.0),
rngs=key,
)
# self.dropout = nnx.Dropout(self.dropout_rate, rngs=key)
if action_space is None:
action_space = ActionSpace.create("Free")
# Check if bijector is scalar
bij = cast(distrax.Bijector, action_space)
if getattr(bij, "event_ndims_in", 0) == 0:
bij = distrax.Block(bij, ndims=1)
self.bij = bij
# Persistent carry for SINGLE-STEP usage (lives in nnx.State)
# shape will be lazily set to x.shape[:-1] + (lstm_features,)
H = int(self.lstm_features)
lead = tuple(carry_leading_shape)
self.h = nnx.Variable(jnp.zeros(lead + (H,), dtype=float))
self.c = nnx.Variable(jnp.zeros(lead + (H,), dtype=float))
@property
def metadata(self) -> Dict:
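"""Dictionary of the constructor arguments describing this model's configuration."""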
return dict(
observation_space_size=self.obs_dim,
action_space_size=self.action_space_size,
hidden_features=self.hidden_features,
lstm_features=self.lstm_features,
dropout_rate=self.dropout_rate,
activation=encode_callable(self.activation),
action_space_type=self.bij.type_name,
action_space_kws=self.bij.kws,
remat=self.remat,
actor_sigma_head=self.actor_sigma_head,
cell_type=encode_callable(self.cell_type),
carry_leading_shape=self.h.value.shape[:-1],
)
@partial(jax.named_call, name="LSTMActorCritic.reset")
def reset(self, shape: Tuple, mask: jax.Array | None = None):
"""
Reset the persistent LSTM carry.
- If ``self.h.value.shape != (*shape[:-1], lstm_features)``, allocate fresh zeros once.
- Otherwise, zero in place:
* if ``mask`` is None: zero everything
* if ``mask`` is provided: zero only the masked entries along axis 0
Parameters
----------
shape : tuple[int, ...]
Shape of the observation (input) tensor.
mask : optional bool array
Mask per environment (axis=0) to conditionally reset the carry.
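Example (illustrative; ``model``, ``obs``, and ``done`` are assumed to exist):
>>> model.reset(obs.shape)                  # zero the whole carry
>>> model.reset(obs.shape, mask=done)       # zero only the environments flagged in done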
"""
H = self.lstm_features
target_shape = (*shape[:-1], H)
# If shape changed, allocate once and return
if self.h.value.shape != target_shape:
self.h.value = jnp.zeros(target_shape, dtype=float)
self.c.value = jnp.zeros(target_shape, dtype=float)
return
# If shape matches and everything needs resetting
if mask is None:
self.h.value *= 0.0
self.c.value *= 0.0
return
# If shape matches, perform a masked reset
mask = mask.reshape((mask.shape[0],) + (1,) * (self.h.value.ndim - 1))
self.h.value = jnp.where(mask, 0.0, self.h.value)
self.c.value = jnp.where(mask, 0.0, self.c.value)
@partial(jax.named_call, name="LSTMActorCritic.__call__")
def __call__(
self, x: jax.Array, sequence: bool = False
) -> Tuple[distrax.Distribution, jax.Array]:
"""
Remember to reset the carry whenever a new trajectory starts.
Accepts:
- sequence = False: single step: x shape (..., obs_dim) -> uses persistent carry
- sequence = True : x shape (T, B, obs_dim) (time-major) -> starts from zero carry
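Illustrative shape sketch (assuming ``obs_dim == 6`` and an existing ``model``):
>>> pi, v = model(jnp.zeros((4, 6)))                      # v: (4, 1)
>>> pi, v = model(jnp.zeros((16, 4, 6)), sequence=True)   # v: (16, 4, 1)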
"""
if x.shape[-1] != self.obs_dim:
raise ValueError(f"Expected last dim {self.obs_dim}, got {x.shape}")
feats = self.encoder(x) # (..., hidden)
if sequence:
y = self.rnn(feats, time_major=True)
else:
batch = feats.shape[:-1]
target = (*batch, self.lstm_features)
# Lazily allocate carry via the cell API when shape mismatches
if self.h.value.shape != target:
c0, h0 = self.cell.initialize_carry(
input_shape=feats.shape, rngs=nnx.Rngs(0)
)
self.c.value, self.h.value = c0, h0 # carry order = (c, h)
(c1, h1), y = self.cell((self.c.value, self.h.value), feats) # (..., H)
self.c.value, self.h.value = c1, h1
h = y
pi = distrax.MultivariateNormalDiag(self.actor_mu(h), self.actor_sigma(h))
pi = distrax.Transformed(pi, self.bij)
return pi, self.critic(h)
__all__ = ["LSTMActorCritic"]