jaxdem.rl.models.MLP#
Implementation of reinforcement learning models based on simple MLPs.
Classes

ActorCritic – An actor-critic model with separate networks for the actor and critic.

SharedActorCritic – A shared-parameter dense actor-critic model.
- class jaxdem.rl.models.MLP.SharedActorCritic(*args: Any, **kwargs: Any)[source]#
Bases:
Model

A shared-parameter dense actor-critic model.
This model uses a common feedforward network (the “shared torso”) to process observations, and then branches into two separate linear heads:
- Actor head: outputs the mean of a Gaussian action distribution.
- Critic head: outputs a scalar value estimate of the state.
A minimal illustrative sketch of this layout follows the attribute list below.
- Parameters:
observation_space (int) – Shape of the observation space (excluding batch dimension).
action_space (ActionSpace) – Shape of the action space.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
architecture (Sequence[int]) – Sizes of the hidden layers in the shared network.
in_scale (float) – Scaling factor for orthogonal initialization of the shared network layers.
actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.
critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.
activation (Callable) – JIT-compatible activation function applied between hidden layers.
action_space – Bijector used to constrain the policy probability distribution.
The shared feedforward network (torso).
- Type:
nnx.Sequential
Linear layer mapping shared features to the policy distribution means.
- Type:
nnx.Linear
Linear layer mapping shared features to the policy distribution standard deviations if actor_sigma_head is true; otherwise an independent parameter.
- Type:
nnx.Sequential
Linear layer mapping shared features to the value estimate.
- Type:
nnx.Linear
Bijector for constraining the action space.
- Type:
distrax.Bijector
- class jaxdem.rl.models.MLP.ActorCritic(*args: Any, **kwargs: Any)[source]#
Bases:
Model, Module

An actor-critic model with separate networks for the actor and critic.
Unlike SharedActorCritic, this model uses two independent feedforward networks:
- Actor torso: processes observations into features for the policy.
- Critic torso: processes observations into features for the value function.
A minimal illustrative sketch follows the attribute list below.
- Parameters:
observation_space (int) – Shape of the observation space (excluding batch dimension).
action_space (ActionSpace) – Shape of the action space.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
actor_architecture (Sequence[int]) – Sizes of the hidden layers in the actor torso.
critic_architecture (Sequence[int]) – Sizes of the hidden layers in the critic torso.
in_scale (float) – Scaling factor for orthogonal initialization of hidden layers.
actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.
critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.
activation (Callable) – Activation function applied between hidden layers.
action_space – Bijector used to constrain the policy probability distribution.
- actor_torso#
Feedforward network for the actor.
- Type:
nnx.Sequential
- critic#
Feedforward network for the critic.
- Type:
nnx.Sequential
- actor_mu#
Linear layer mapping actor_torso’s features to the policy distribution means.
- Type:
nnx.Linear
- actor_sigma#
Linear layer mapping actor_torso’s features to the policy distribution standard deviations if actor_sigma_head is true; otherwise an independent parameter.
- Type:
nnx.Sequential
- bij#
Bijector for constraining the action space.
- Type:
distrax.Bijector