jaxdem.rl.models.MLP
Implementation of reinforcement learning models based on simple MLPs.
Classes
- ActorCritic – An actor-critic model with separate networks for the actor and critic.
- SharedActorCritic – A shared-parameter dense actor-critic model.
- class jaxdem.rl.models.MLP.SharedActorCritic
Bases: Model
A shared-parameter dense actor-critic model.
This model uses a common feedforward network (the “shared torso”) to process observations and then branches into two separate linear heads:
- Actor head: outputs the mean of a Gaussian action distribution.
- Critic head: outputs a scalar value estimate of the state.
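To make the data flow concrete, here is a minimal sketch of the shared-torso forward pass written in plain JAX. It is not jaxdem's implementation (which, per the types listed below, is built from flax nnx modules). The helper names init_shared_actor_critic and shared_forward, the parameter-pytree layout, the default scale values, the gelu default, and the choice to apply the activation after every torso layer are all illustrative assumptions.

```python
import math

import jax
import jax.numpy as jnp


def init_shared_actor_critic(key, obs_dim, action_dim, architecture=(64, 64),
                             in_scale=math.sqrt(2.0), actor_scale=0.01,
                             critic_scale=1.0):
    """Parameter pytree: a shared torso plus separate actor and critic heads.

    Hypothetical helper; the scale defaults are illustrative, not jaxdem's.
    """
    sizes = [obs_dim, *architecture]
    keys = jax.random.split(key, len(architecture) + 2)
    orth = jax.nn.initializers.orthogonal
    torso = [
        (orth(in_scale)(k, (n_in, n_out)), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys[:-2], sizes[:-1], sizes[1:])
    ]
    actor = (orth(actor_scale)(keys[-2], (sizes[-1], action_dim)), jnp.zeros(action_dim))
    critic = (orth(critic_scale)(keys[-1], (sizes[-1], 1)), jnp.zeros(1))
    return {"torso": torso, "actor": actor, "critic": critic}


def shared_forward(params, obs, activation=jax.nn.gelu):
    """One torso pass feeds both heads; returns (action means, state values)."""
    h = obs
    for w, b in params["torso"]:
        h = activation(h @ w + b)
    w_a, b_a = params["actor"]
    w_c, b_c = params["critic"]
    mean = h @ w_a + b_a                 # actor head: Gaussian means
    value = (h @ w_c + b_c).squeeze(-1)  # critic head: one scalar per state
    return mean, value


params = init_shared_actor_critic(jax.random.PRNGKey(0), obs_dim=4, action_dim=2)
mean, value = shared_forward(params, jnp.ones((5, 4)))
print(mean.shape, value.shape)  # (5, 2) (5,)
```

Because both heads read the same features h, gradients from the policy loss and from the value loss both update the shared torso weights.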
- Parameters:
observation_space (Sequence[int]) – Shape of the observation space (excluding batch dimension).
action_space (ActionSpace) – Shape of the action space.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
architecture (Sequence[int]) – Sizes of the hidden layers in the shared network.
in_scale (float) – Scaling factor for orthogonal initialization of the shared network layers.
actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.
critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.
activation (Callable) – Activation function applied between hidden layers.
action_space – Bijector used to constrain the policy's probability distribution.
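The three scale parameters set the gain of the orthogonal initializer. Assuming they behave like the scale argument of jax.nn.initializers.orthogonal (an assumption about jaxdem's internals, not something this page states), the sketch below checks the defining property: for a fan_in × fan_out weight with fan_in ≥ fan_out, the columns are orthogonal and each has norm equal to the scale.

```python
import math

import jax
import jax.numpy as jnp

# Assumption: in_scale / actor_scale / critic_scale act as the gain of an
# orthogonal initializer, like the `scale` argument of jax.nn.initializers.orthogonal.
in_scale = math.sqrt(2.0)  # illustrative value only
init = jax.nn.initializers.orthogonal(scale=in_scale)
w = init(jax.random.PRNGKey(0), (8, 4))  # (fan_in, fan_out) weight matrix

# With fan_in >= fan_out the columns are orthogonal with norm in_scale,
# i.e. W^T W = in_scale**2 * I.
print(bool(jnp.allclose(w.T @ w, in_scale**2 * jnp.eye(4), atol=1e-4)))  # True
```

Many PPO-style implementations give the actor head a small scale (for example 0.01) so that initial action means start near zero; the values used here are only illustrative.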
The shared feedforward network (torso).
- Type:
nnx.Sequential
Linear layer mapping shared features to action means.
- Type:
nnx.Linear
Linear layer mapping shared features to a scalar value estimate.
- Type:
nnx.Linear
Learnable log standard deviation for the Gaussian action distribution.
- Type:
nnx.Param
Bijector for constraining the action space.
- Type:
distrax.Bijector
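The log_std parameter and the bijector combine into the policy's action distribution. The following plain-JAX sketch shows one common construction, assumed here for illustration: a diagonal Gaussian with a state-independent log_std, squashed by tanh so actions lie in (-1, 1), with the tanh log-det-Jacobian subtracted from the Gaussian log-density (change of variables). The helper names are hypothetical. In distrax the same pattern is usually expressed by wrapping a base distribution in distrax.Transformed with a bijector such as distrax.Tanh; which bijector jaxdem actually applies depends on the action_space argument.

```python
import jax
import jax.numpy as jnp


def gaussian_log_prob(x, mean, log_std):
    """Diagonal-Gaussian log-density, summed over the last (action) axis."""
    z = (x - mean) * jnp.exp(-log_std)
    return jnp.sum(-0.5 * z**2 - log_std - 0.5 * jnp.log(2.0 * jnp.pi), axis=-1)


def tanh_log_det(u):
    """log|d tanh(u)/du| = log(1 - tanh(u)^2), in a numerically stable form."""
    return jnp.sum(2.0 * (jnp.log(2.0) - u - jax.nn.softplus(-2.0 * u)), axis=-1)


def sample_tanh_gaussian(key, mean, log_std):
    """Sample a squashed action and its log-probability."""
    u = mean + jnp.exp(log_std) * jax.random.normal(key, mean.shape)
    action = jnp.tanh(u)  # the (assumed) bijector maps R -> (-1, 1)
    return action, gaussian_log_prob(u, mean, log_std) - tanh_log_det(u)


# Usage: action means from an actor head for a batch of 3, 2-D actions.
mean = jnp.zeros((3, 2))
log_std = jnp.full((2,), -0.5)  # one learnable entry per action dimension
action, log_prob = sample_tanh_gaussian(jax.random.PRNGKey(0), mean, log_std)
print(action.shape, log_prob.shape)  # (3, 2) (3,)
```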
- class jaxdem.rl.models.MLP.ActorCritic(*args: Any, **kwargs: Any)
Bases: Model, Module
An actor-critic model with separate networks for the actor and critic.
Unlike SharedActorCritic, this model uses two independent feedforward networks:
- Actor torso: processes observations into features for the policy.
- Critic torso: processes observations into features for the value function.
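A plain-JAX sketch of the two-torso layout follows. As with the shared variant, this is an illustration under assumptions (hypothetical helper names, illustrative scales, gelu applied after every torso layer) rather than jaxdem's nnx code. The gradient check at the end shows the practical consequence of separate torsos: the value output does not depend on the actor torso, so value-loss gradients never reach it.

```python
import math

import jax
import jax.numpy as jnp


def init_dense_stack(key, sizes, scale):
    """Orthogonally initialized (weight, bias) pairs for consecutive widths."""
    keys = jax.random.split(key, len(sizes) - 1)
    orth = jax.nn.initializers.orthogonal(scale)
    return [(orth(k, (n_in, n_out)), jnp.zeros(n_out))
            for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])]


def run_torso(layers, x, activation=jax.nn.gelu):
    for w, b in layers:
        x = activation(x @ w + b)
    return x


def init_actor_critic(key, obs_dim, action_dim,
                      actor_architecture=(64, 64), critic_architecture=(64, 64)):
    """Two independent torsos, each with its own head (illustrative scales)."""
    k_at, k_ct, k_a, k_c = jax.random.split(key, 4)
    return {
        "actor_torso": init_dense_stack(k_at, [obs_dim, *actor_architecture], math.sqrt(2.0)),
        "critic_torso": init_dense_stack(k_ct, [obs_dim, *critic_architecture], math.sqrt(2.0)),
        "actor": init_dense_stack(k_a, [actor_architecture[-1], action_dim], 0.01)[0],
        "critic": init_dense_stack(k_c, [critic_architecture[-1], 1], 1.0)[0],
    }


def actor_critic_forward(params, obs):
    """The same observation is processed by both torsos independently."""
    w_a, b_a = params["actor"]
    w_c, b_c = params["critic"]
    mean = run_torso(params["actor_torso"], obs) @ w_a + b_a
    value = (run_torso(params["critic_torso"], obs) @ w_c + b_c).squeeze(-1)
    return mean, value


params = init_actor_critic(jax.random.PRNGKey(0), obs_dim=4, action_dim=2)
obs = jnp.ones((5, 4))
mean, value = actor_critic_forward(params, obs)

# A loss on the value output produces exactly zero gradient for the actor torso.
grads = jax.grad(lambda p: actor_critic_forward(p, obs)[1].sum())(params)
print(all(bool(jnp.all(w == 0)) for w, _ in grads["actor_torso"]))  # True
```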
- Parameters:
observation_space (Sequence[int]) – Shape of the observation space (excluding batch dimension).
action_space (ActionSpace) – Shape of the action space.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
actor_architecture (Sequence[int]) – Sizes of the hidden layers in the actor torso.
critic_architecture (Sequence[int]) – Sizes of the hidden layers in the critic torso.
in_scale (float) – Scaling factor for orthogonal initialization of hidden layers.
actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.
critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.
activation (Callable) – Activation function applied between hidden layers.
action_space – Bijector used to constrain the policy's probability distribution.
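When both torsos use the same architecture as the shared variant, separate torsos exactly double the torso parameters. The arithmetic below counts parameters for both layouts, assuming every linear layer has a bias, log_std holds one entry per action dimension, and the bijector has no trainable parameters; the function names are hypothetical.

```python
def dense_params(sizes):
    """Weights plus biases of consecutive dense layers with the given widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


def shared_actor_critic_params(obs_dim, action_dim, architecture):
    h = architecture[-1]
    torso = dense_params([obs_dim, *architecture])
    heads = dense_params([h, action_dim]) + dense_params([h, 1])
    return torso + heads + action_dim  # + action_dim for log_std


def actor_critic_params(obs_dim, action_dim, actor_architecture, critic_architecture):
    actor = dense_params([obs_dim, *actor_architecture, action_dim])
    critic = dense_params([obs_dim, *critic_architecture, 1])
    return actor + critic + action_dim  # + action_dim for log_std


print(shared_actor_critic_params(10, 3, (64, 64)))      # 5127
print(actor_critic_params(10, 3, (64, 64), (64, 64)))  # 9991
```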
- actor_torso
Feedforward network for the actor.
- Type:
nnx.Sequential
- critic_torso
Feedforward network for the critic.
- Type:
nnx.Sequential
- actor
Linear layer mapping actor features to action means.
- Type:
nnx.Linear
- critic
Linear layer mapping critic features to a scalar value estimate.
- Type:
nnx.Linear
- log_std
Learnable log standard deviation for the Gaussian action distribution.
- Type:
nnx.Param
- bij
Bijector for constraining the action space.
- Type:
distrax.Bijector