jaxdem.rl.models#
Interface for defining reinforcement learning models.
Classes
Model – The base interface for defining reinforcement learning models.
SharedActorCritic – A shared-parameter dense actor-critic model.
ActorCritic – An actor-critic model with separate networks for the actor and critic.
LSTMActorCritic – A recurrent actor–critic with an MLP encoder and an LSTM torso.
DeepOnetActorCritic – A DeepOnet-based actor-critic model with a dynamic weighted combiner.
- class jaxdem.rl.models.Model(*args: Any, **kwargs: Any)[source]#
Bases: Factory, Module, ABC
The base interface for defining reinforcement learning models. Acts as a namespace.
Models map observations to an action distribution and a value estimate.
Example
To define a custom model, inherit from Model and implement its abstract methods:

>>> @Model.register("myCustomModel")
>>> class MyCustomModel(Model):
...     ...
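A slightly fuller sketch of a registered model, assuming the interface maps observations to an (action distribution, value estimate) pair as described above; the constructor and __call__ signatures here are illustrative, not the library's exact abstract methods:

>>> import distrax
>>> import jax.numpy as jnp
>>> from flax import nnx
>>>
>>> @Model.register("myCustomModel")
>>> class MyCustomModel(Model):
...     def __init__(self, obs_dim: int, action_dim: int, key: nnx.Rngs):
...         self.mu = nnx.Linear(obs_dim, action_dim, rngs=key)
...         self.value = nnx.Linear(obs_dim, 1, rngs=key)
...
...     def __call__(self, x):
...         # Observations in, (Gaussian action distribution, value estimate) out.
...         loc = self.mu(x)
...         dist = distrax.MultivariateNormalDiag(loc, jnp.ones_like(loc))
...         return dist, self.value(x)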
- class jaxdem.rl.models.SharedActorCritic(*args: Any, **kwargs: Any)[source][source]#
Bases: Model, Module
A shared-parameter dense actor-critic model.
This model uses a common feedforward network (the “shared torso”) to process observations, then branches into two separate linear heads:
- Actor head: outputs the mean of a Gaussian action distribution.
- Critic head: outputs a scalar value estimate of the state.
- Parameters:
observation_space (int) – Shape of the observation space (excluding batch dimension).
action_space (ActionSpace) – Shape of the action space.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
architecture (Sequence[int]) – Sizes of the hidden layers in the shared network.
in_scale (float) – Scaling factor for orthogonal initialization of the shared network layers.
actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.
critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.
activation (Callable) – JIT-compatible activation function applied between hidden layers.
action_space – Bijector to constrain the policy probability distribution.
The shared feedforward network (torso).
- Type:
nnx.Sequential
- actor_mu#
Linear layer mapping shared features to the policy distribution means.
- Type:
nnx.Linear
- actor_sigma#
Linear layer mapping shared features to the policy distribution standard deviations if actor_sigma_head is true; otherwise an independent parameter.
- Type:
nnx.Sequential
- critic#
Linear layer mapping shared features to the value estimate.
- Type:
nnx.Linear
- bij#
Bijector for constraining the action space.
- Type:
distrax.Bijector
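A minimal construction sketch, assuming the class is instantiated directly with the keyword names from the parameter list above; the action_space value and the call signature are placeholders, not the library's exact API:

>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from jaxdem.rl.models import SharedActorCritic
>>>
>>> model = SharedActorCritic(
...     observation_space=8,            # flat observation size
...     action_space=2,                 # hypothetical placeholder for your ActionSpace
...     key=nnx.Rngs(0),
...     architecture=(64, 64),          # two hidden layers in the shared torso
...     activation=jax.nn.tanh,
... )
>>> dist, value = model(jnp.zeros((4, 8)))   # assumed: observations in, (distribution, value) out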
- class jaxdem.rl.models.ActorCritic(*args: Any, **kwargs: Any)[source][source]#
Bases: Model, Module
An actor-critic model with separate networks for the actor and critic.
Unlike SharedActorCritic, this model uses two independent feedforward networks:
- Actor torso: processes observations into features for the policy.
- Critic torso: processes observations into features for the value function.
- Parameters:
observation_space (int) – Shape of the observation space (excluding batch dimension).
action_space (ActionSpace) – Shape of the action space.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
actor_architecture (Sequence[int]) – Sizes of the hidden layers in the actor torso.
critic_architecture (Sequence[int]) – Sizes of the hidden layers in the critic torso.
in_scale (float) – Scaling factor for orthogonal initialization of hidden layers.
actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.
critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.
activation (Callable) – Activation function applied between hidden layers.
action_space – Bijector to constrain the policy probability distribution.
- actor_torso#
Feedforward network for the actor.
- Type:
nnx.Sequential
- critic#
Feedforward network for the critic.
- Type:
nnx.Sequential
- actor_mu#
Linear layer mapping actor_torso’s features to the policy distribution means.
- Type:
nnx.Linear
- actor_sigma#
Linear layer mapping actor_torso’s features to the policy distribution standard deviations if actor_sigma_head is true; otherwise an independent parameter.
- Type:
nnx.Sequential
- bij#
Bijector for constraining the action space.
- Type:
distrax.Bijector
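Construction mirrors SharedActorCritic but with separate widths for the two torsos; a minimal sketch under the same assumptions as the SharedActorCritic example above:

>>> from flax import nnx
>>> from jaxdem.rl.models import ActorCritic
>>>
>>> model = ActorCritic(
...     observation_space=8,
...     action_space=2,                  # hypothetical placeholder for your ActionSpace
...     key=nnx.Rngs(0),
...     actor_architecture=(64, 64),     # policy torso
...     critic_architecture=(128, 128),  # value torso
... )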
- class jaxdem.rl.models.LSTMActorCritic(*args: Any, **kwargs: Any)[source][source]#
Bases: Model, Module
A recurrent actor–critic with an MLP encoder and an LSTM torso.
This model encodes observations with a small feed-forward network, passes the features through a single-layer LSTM, and decodes the LSTM hidden state with linear policy/value heads.
Calling modes
- Sequence mode (training): time-major input x with shape (T, B, obs_dim) produces a distribution and value for every step: policy outputs have shape (T, B, action_space_size) and values (T, B, 1). The LSTM carry is initialized to zeros for the sequence.
- Single-step mode (evaluation/rollout): input x with shape (..., obs_dim) uses and updates a persistent LSTM carry stored on the module (self.h, self.c); outputs have shape (..., action_space_size) and (..., 1). Use reset() to clear the carry at the start of every new trajectory (see the usage sketch below).
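A usage sketch of both calling modes, assuming the constructor keywords from the parameter list below and the (distribution, value) return pair described above:

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from jaxdem.rl.models import LSTMActorCritic
>>>
>>> model = LSTMActorCritic(
...     observation_space_size=8, action_space_size=2,
...     key=nnx.Rngs(0), hidden_features=32, lstm_features=64,
... )
>>> # Sequence mode (training): time-major (T, B, obs_dim); the carry starts at zeros.
>>> dist, value = model(jnp.zeros((16, 4, 8)))
>>> # Single-step mode (rollout): (..., obs_dim); the persistent carry (self.h, self.c) is updated.
>>> model.reset(shape=(4, 8))                # clear the carry at the start of a new trajectory
>>> dist, value = model(jnp.zeros((4, 8)))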
- Parameters:
observation_space_size (int) – Flattened observation size (obs_dim).
action_space_size (int) – Number of action dimensions.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
hidden_features (int) – Width of the encoder output (and LSTM input).
lstm_features (int) – LSTM hidden/state size. Also the feature size consumed by the heads.
activation (Callable) – Activation function applied inside the encoder.
action_space (distrax.Bijector | ActionSpace | None, default=None) – Bijector to constrain the policy probability distribution.
- obs_dim#
Observation dimensionality expected on the last axis of x.
- Type:
int
- lstm_features#
LSTM hidden/state size.
- Type:
int
- encoder#
MLP that maps obs_dim -> hidden_features.
- Type:
nnx.Sequential
- cell#
LSTM cell with in_features = hidden_features and hidden_features = lstm_features.
- Type:
rnn.OptimizedLSTMCell
- actor_mu#
Linear layer mapping LSTM features to the policy distribution means.
- Type:
nnx.Linear
- actor_sigma#
Linear layer mapping LSTM features to the policy distribution standard deviations if actor_sigma_head is true; otherwise an independent parameter.
- Type:
nnx.Sequential
- critic#
Linear head mapping LSTM features to a scalar value.
- Type:
nnx.Linear
- bij#
Action-space bijector; scalar bijectors are automatically lifted with Block(ndims=1) for vector actions.
- Type:
distrax.Bijector
- h, c
Persistent LSTM carry used by single-step evaluation. Shapes are (..., lstm_features) and are resized lazily to match the leading batch/agent dimensions.
- Type:
nnx.Variable
- reset(shape: Tuple, mask: Array | None = None)[source][source]#
Reset the persistent LSTM carry.
- If self.h.value.shape != (*lead_shape, H), allocate fresh zeros once.
- Otherwise, zero in place:
  - If mask is None, zero everything.
  - If mask is provided, zero masked entries along axis=0.
- Parameters:
shape (tuple[int, ...]) – Shape of the observation (input) tensor.
mask (optional bool array) – Mask per environment (axis=0) to conditionally reset the carry.
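For vectorized rollouts, the mask zeroes only the carries of the environments that just terminated, leaving the rest untouched. A short sketch; the done flags are hypothetical and would come from your environment step:

>>> import jax.numpy as jnp
>>> done = jnp.array([True, False, False, True])   # per-environment termination flags
>>> model.reset(shape=(4, 8), mask=done)           # zero the carry only where done is True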
- class jaxdem.rl.models.DeepOnetActorCritic(*args: Any, **kwargs: Any)[source][source]#
Bases: Model, Module
A DeepOnet-based actor-critic model with a dynamic weighted combiner.
Architecture:
1. Trunk (T): MLP encoding goal + velocity.
2. Branch (B): MLP encoding lidar/sensor data.
3. Weighted combiner: T gates B features dynamically (see the sketch after the parameter list).
4. Actor/critic heads.
- Parameters:
observation_space_size (int) – Total size of the observation space.
action_space_size (int) – Size of the action space.
key (nnx.Rngs) – Random number generator.
trunk_architecture (Sequence[int]) – Hidden layers for the trunk (Goal/Vel).
branch_architecture (Sequence[int]) – Hidden layers for the branch (Lidar features).
combiner_architecture (Sequence[int]) – Hidden layers for the merging network.
critic_architecture (Sequence[int]) – Hidden layers for the critic network (after combiner).
basis_dim (int) – Output size of Trunk and Branch before combination.
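“T gates B” can be read as an elementwise weighting of branch features by trunk features before the heads. A minimal sketch of that idea; the sigmoid gating form is an assumption, not the library's exact computation:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> basis_dim = 32
>>> trunk_out = jnp.ones((4, basis_dim))     # trunk features (goal + velocity)
>>> branch_out = jnp.ones((4, basis_dim))    # branch features (lidar/sensor)
>>> gated = jax.nn.sigmoid(trunk_out) * branch_out   # dynamic weighting: T gates B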
Modules