jaxdem.rl.models#
Interface for defining reinforcement learning models.
Classes

Model | The base interface for defining reinforcement learning models.
SharedActorCritic | A shared-parameter dense actor-critic model.
ActorCritic | An actor-critic model with separate networks for the actor and critic.
LSTMActorCritic | A recurrent actor–critic with an MLP encoder and an LSTM torso.
- class jaxdem.rl.models.Model(*args: Any, **kwargs: Any)[source]#
Bases: Factory, Module, ABC
The base interface for defining reinforcement learning models. Acts as a namespace.
Models map observations to an action distribution and a value estimate.
Example
To define a custom model, inherit from Model and implement its abstract methods:

>>> @Model.register("myCustomModel")
>>> class MyCustomModel(Model):
...     ...
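A slightly fuller sketch, hedged: the abstract interface is not spelled out on this page, so this assumes a __call__ that maps observations to an action distribution and a value estimate, as described above; MyCustomModel's layers and constructor arguments are hypothetical:

import distrax
import jax.numpy as jnp
from flax import nnx
from jaxdem.rl.models import Model

@Model.register("myCustomModel")
class MyCustomModel(Model):
    """Hypothetical minimal model: one linear layer per head."""

    def __init__(self, obs_dim: int, action_dim: int, *, rngs: nnx.Rngs):
        self.actor = nnx.Linear(obs_dim, action_dim, rngs=rngs)
        self.critic = nnx.Linear(obs_dim, 1, rngs=rngs)
        self.log_std = nnx.Param(jnp.zeros(action_dim))

    def __call__(self, x: jnp.ndarray):
        # Gaussian policy over actions plus a scalar value estimate.
        dist = distrax.MultivariateNormalDiag(self.actor(x), jnp.exp(self.log_std.value))
        return dist, self.critic(x)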
- class jaxdem.rl.models.SharedActorCritic(*args: Any, **kwargs: Any)[source]#
Bases: Model
A shared-parameter dense actor-critic model.
This model uses a common feedforward network (the “shared torso”) to process observations, and then branches into two separate linear heads:
- Actor head: outputs the mean of a Gaussian action distribution.
- Critic head: outputs a scalar value estimate of the state.
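A minimal sketch of this layout in flax nnx (illustrative only: the layer sizes, variable names, and initializer scales below are assumptions, not the class's actual internals):

import jax.numpy as jnp
from flax import nnx

rngs = nnx.Rngs(0)
obs_dim, act_dim, hidden = 8, 2, 64
ortho = nnx.initializers.orthogonal

# One shared trunk feeds both heads, so actor and critic reuse the same features.
torso = nnx.Sequential(
    nnx.Linear(obs_dim, hidden, kernel_init=ortho(jnp.sqrt(2.0)), rngs=rngs), nnx.tanh,
    nnx.Linear(hidden, hidden, kernel_init=ortho(jnp.sqrt(2.0)), rngs=rngs), nnx.tanh,
)
actor = nnx.Linear(hidden, act_dim, kernel_init=ortho(0.01), rngs=rngs)  # action means
critic = nnx.Linear(hidden, 1, kernel_init=ortho(1.0), rngs=rngs)        # value estimate

feats = torso(jnp.ones((4, obs_dim)))      # (batch, hidden)
mean, value = actor(feats), critic(feats)  # (4, act_dim), (4, 1)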
- Parameters:
observation_space (Sequence[int]) – Shape of the observation space (excluding batch dimension).
action_space (ActionSpace) – Shape of the action space; also provides the bijector used to constrain the policy probability distribution.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
architecture (Sequence[int]) – Sizes of the hidden layers in the shared network.
in_scale (float) – Scaling factor for orthogonal initialization of the shared network layers.
actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.
critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.
activation (Callable) – Activation function applied between hidden layers.
- torso#
The shared feedforward network (torso).
- Type:
nnx.Sequential
- actor#
Linear layer mapping shared features to action means.
- Type:
nnx.Linear
- critic#
Linear layer mapping shared features to a scalar value estimate.
- Type:
nnx.Linear
- log_std#
Learnable log standard deviation for the Gaussian action distribution.
- Type:
nnx.Param
- bij#
Bijector for constraining the action space.
- Type:
distrax.Bijector
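A hedged construction example using the documented parameter names (the exact signature may differ, and env is a hypothetical environment supplying an ActionSpace):

from flax import nnx
from jaxdem.rl.models import SharedActorCritic

model = SharedActorCritic(
    observation_space=(8,),          # per-agent observation shape, excluding batch dims
    action_space=env.action_space,   # hypothetical: taken from your environment
    key=nnx.Rngs(0),
    architecture=(64, 64),           # two hidden layers in the shared torso
    in_scale=2.0**0.5,
    actor_scale=0.01,
    critic_scale=1.0,
    activation=nnx.tanh,
)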
- class jaxdem.rl.models.ActorCritic(*args: Any, **kwargs: Any)[source]#
Bases: Model, Module
An actor-critic model with separate networks for the actor and critic.
Unlike SharedActorCritic, this model uses two independent feedforward networks:
- Actor torso: processes observations into features for the policy.
- Critic torso: processes observations into features for the value function.
- Parameters:
observation_space (Sequence[int]) – Shape of the observation space (excluding batch dimension).
action_space (ActionSpace) – Shape of the action space; also provides the bijector used to constrain the policy probability distribution.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
actor_architecture (Sequence[int]) – Sizes of the hidden layers in the actor torso.
critic_architecture (Sequence[int]) – Sizes of the hidden layers in the critic torso.
in_scale (float) – Scaling factor for orthogonal initialization of hidden layers.
actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.
critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.
activation (Callable) – Activation function applied between hidden layers.
- actor_torso#
Feedforward network for the actor.
- Type:
nnx.Sequential
- critic_torso#
Feedforward network for the critic.
- Type:
nnx.Sequential
- actor#
Linear layer mapping actor features to action means.
- Type:
nnx.Linear
- critic#
Linear layer mapping critic features to a scalar value estimate.
- Type:
nnx.Linear
- log_std#
Learnable log standard deviation for the Gaussian action distribution.
- Type:
nnx.Param
- bij#
Bijector for constraining the action space.
- Type:
distrax.Bijector
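The design trade-off versus SharedActorCritic: separate torsos keep the actor and critic losses from pulling on the same features, at the cost of roughly twice the parameters. A hedged sketch of the layout (names and sizes are illustrative, not the class's actual code):

import jax.numpy as jnp
from flax import nnx

rngs = nnx.Rngs(0)
obs_dim, act_dim = 8, 2

# Two independent trunks: no parameter sharing between policy and value.
actor_torso = nnx.Sequential(nnx.Linear(obs_dim, 64, rngs=rngs), nnx.tanh)
critic_torso = nnx.Sequential(nnx.Linear(obs_dim, 64, rngs=rngs), nnx.tanh)
actor = nnx.Linear(64, act_dim, rngs=rngs)
critic = nnx.Linear(64, 1, rngs=rngs)

x = jnp.ones((4, obs_dim))
mean = actor(actor_torso(x))     # gradients here touch only actor parameters
value = critic(critic_torso(x))  # gradients here touch only critic parameters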
- class jaxdem.rl.models.LSTMActorCritic(*args: Any, **kwargs: Any)[source]#
Bases: Model, Module
A recurrent actor–critic with an MLP encoder and an LSTM torso.
This model encodes observations with a small feed-forward network, passes the features through a single-layer LSTM, and decodes the LSTM hidden state with linear policy/value heads.
Calling modes
- Sequence mode (training): time-major input x with shape (T, B, obs_dim) produces a distribution and value for every step: policy outputs (T, B, action_space_size) and values (T, B, 1). The LSTM carry is initialized to zeros for the sequence.
- Single-step mode (evaluation/rollout): input x with shape (..., obs_dim) uses and updates a persistent LSTM carry stored on the module (self.h, self.c); outputs have shape (..., action_space_size) and (..., 1). Use reset() to clear state between episodes; the carry must be reset at the start of every new trajectory.
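A hedged usage sketch of both modes (assumes construction with the documented parameter names and that a call returns a (distribution, value) pair as described; exact details may differ):

import jax.numpy as jnp
from flax import nnx
from jaxdem.rl.models import LSTMActorCritic

T, B, obs_dim = 16, 8, 12
model = LSTMActorCritic(
    observation_space_size=obs_dim,
    action_space_size=2,
    key=nnx.Rngs(0),
    hidden_features=64,
    lstm_features=64,
)

# Sequence mode (training): time-major (T, B, obs_dim), carry starts at zeros.
dist, values = model(jnp.ones((T, B, obs_dim)))  # values: (T, B, 1)

# Single-step mode (rollout): persistent carry lives on the module.
model.reset((B,))                                # clear carry at trajectory start
for _ in range(T):
    dist, value = model(jnp.ones((B, obs_dim)))  # uses and updates self.h, self.c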
- Parameters:
observation_space_size (int) – Flattened observation size (obs_dim).
action_space_size (int) – Number of action dimensions.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
hidden_features (int) – Width of the encoder output (and LSTM input).
lstm_features (int) – LSTM hidden/state size. Also the feature size consumed by the heads.
activation (Callable) – Activation function applied inside the encoder.
action_space (distrax.Bijector | ActionSpace | None, default=None) – Bijector to constrain the policy probability distribution.
- obs_dim#
Observation dimensionality expected on the last axis of x.
- Type:
int
- lstm_features#
LSTM hidden/state size.
- Type:
int
- encoder#
MLP that maps obs_dim -> hidden_features.
- Type:
nnx.Sequential
- cell#
LSTM cell with in_features = hidden_features and hidden_features = lstm_features.
- Type:
rnn.OptimizedLSTMCell
- actor#
Linear head mapping LSTM features to action means.
- Type:
nnx.Linear
- critic#
Linear head mapping LSTM features to a scalar value.
- Type:
nnx.Linear
- bij#
Action-space bijector; scalar bijectors are automatically lifted with Block(ndims=1) for vector actions.
- Type:
distrax.Bijector
- h, c#
Persistent LSTM carry used by single-step evaluation. Shapes are (..., lstm_features) and are resized lazily to match the leading batch/agent dimensions.
- Type:
nnx.Variable
- reset(shape: Tuple, mask: Array | None = None)[source]#
Reset the persistent LSTM carry.
- If self.h.value.shape != (*shape, H), allocate fresh zeros once.
- Otherwise, zero in place:
if mask is None: zero everything (without materializing a zeros tensor)
if mask is provided: zero only masked entries (mask may be shaped shape, (*shape, 1), or (*shape, H))
- Parameters:
shape (tuple[int, ...]) – Leading dims for the carry, e.g. (num_envs, num_agents).
mask (optional bool array) – True where entries should be reset. If its shape equals shape, it is expanded across the feature dimension.
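Continuing the sketch above, a hedged example of a masked reset between environment steps (done is a hypothetical per-env termination flag):

import jax.numpy as jnp

done = jnp.array([True, False, False, True, False, False, False, False])  # (num_envs,)

# Zero the carry only for envs whose episode just ended; others keep their state.
model.reset((done.shape[0],), mask=done)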