jaxdem.rl.models.MLP#
Implementation of reinforcement learning models based on simple MLPs.
Classes

ActorCritic | An actor-critic model with separate networks for the actor and critic.
SharedActorCritic | A shared-parameter dense actor-critic model.
- class jaxdem.rl.models.MLP.ActorCritic(*args: Any, **kwargs: Any)[source]#
Bases: Model, Module

An actor-critic model with separate networks for the actor and critic.
Unlike SharedActorCritic, this model uses two independent feedforward networks:

- Actor torso: processes observations into features for the policy.
- Critic torso: processes observations into features for the value function.
- Parameters:
observation_space (int) – Size of the observation space (excluding batch dimension).
action_space (ActionSpace) – Shape of the action space (for continuous) or number of discrete actions.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
actor_architecture (Sequence[int]) – Sizes of the hidden layers in the actor torso.
critic_architecture (Sequence[int]) – Sizes of the hidden layers in the critic torso.
in_scale (float) – Scaling factor for orthogonal initialization of hidden layers.
actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.
critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.
activation (Callable) – Activation function applied between hidden layers.
action_space – Bijector to constrain the policy probability distribution (continuous only).
discrete (bool) – If True, use a categorical distribution for discrete actions. If False (default), use a Gaussian distribution for continuous actions.
- actor_torso#
Feedforward network for the actor.
- Type:
nnx.Sequential
- critic#
Feedforward network for the critic.
- Type:
nnx.Sequential
- actor_mu#
Linear layer mapping actor_torso’s features to the policy distribution means (continuous) or logits (discrete).
- Type:
nnx.Linear
- actor_sigma#
Linear layer mapping actor torso features to the policy distribution standard deviations when actor_sigma_head is True; otherwise the standard deviations are an independent learned parameter. Used only for continuous actions.
- Type:
nnx.Sequential
- bij#
Bijector for constraining the action space (continuous only).
- Type:
distrax.Bijector
- discrete#
Whether this model uses discrete actions.
- Type:
bool
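As a rough illustration of the data flow described above, a minimal plain-NumPy sketch (not the jaxdem implementation; all sizes, weight scales, and variable names here are invented for the example, and tanh stands in for the configurable activation):

```python
import numpy as np

rng = np.random.default_rng(0)
obs_dim, act_dim, hidden = 4, 2, (64, 64)  # made-up sizes for the sketch

def dense_stack(sizes, rng):
    # One (W, b) pair per layer of a feedforward torso.
    return [(rng.normal(scale=0.1, size=(i, o)), np.zeros(o))
            for i, o in zip(sizes[:-1], sizes[1:])]

def forward(layers, x):
    for W, b in layers:
        x = np.tanh(x @ W + b)  # tanh stands in for `activation`
    return x

# Two independent torsos, mirroring actor_torso and the critic network.
actor_torso = dense_stack((obs_dim,) + hidden, rng)
critic_torso = dense_stack((obs_dim,) + hidden, rng)

# Heads: one maps actor features to distribution means, one to a scalar value.
W_mu, b_mu = rng.normal(scale=0.01, size=(hidden[-1], act_dim)), np.zeros(act_dim)
W_v, b_v = rng.normal(scale=1.0, size=(hidden[-1], 1)), np.zeros(1)
log_sigma = np.zeros(act_dim)  # "independent parameter" case: no sigma head

obs = rng.normal(size=(8, obs_dim))            # batch of 8 observations
mu = forward(actor_torso, obs) @ W_mu + b_mu   # policy means, shape (8, 2)
value = forward(critic_torso, obs) @ W_v + b_v # state values, shape (8, 1)
sigma = np.exp(log_sigma)                      # stddevs, shape (2,)
```

Note that each observation is processed twice, once per torso; this doubles the forward cost relative to the shared-torso variant but lets the policy and value function learn unrelated features.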
- class jaxdem.rl.models.MLP.SharedActorCritic(*args: Any, **kwargs: Any)[source]#
Bases: Model

A shared-parameter dense actor-critic model.
This model uses a common feedforward network (the “shared torso”) to process observations, and then branches into two separate linear heads:

- Actor head: outputs the mean of a Gaussian action distribution (continuous) or logits for a categorical distribution (discrete).
- Critic head: outputs a scalar value estimate of the state.
- Parameters:
observation_space (int) – Size of the observation space (excluding batch dimension).
action_space (ActionSpace) – Shape of the action space (for continuous) or number of discrete actions.
key (nnx.Rngs) – Random number generator(s) for parameter initialization.
architecture (Sequence[int]) – Sizes of the hidden layers in the shared network.
in_scale (float) – Scaling factor for orthogonal initialization of the shared network layers.
actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.
critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.
activation (Callable) – JIT-compatible activation function applied between hidden layers.
action_space – Bijector to constrain the policy probability distribution (continuous only).
discrete (bool) – If True, use a categorical distribution for discrete actions. If False (default), use a Gaussian distribution for continuous actions.
The shared feedforward network (torso).
- Type:
nnx.Sequential
Linear layer mapping shared features to the policy distribution means (continuous) or logits (discrete).
- Type:
nnx.Linear
Linear layer mapping shared features to the policy distribution standard deviations when actor_sigma_head is True; otherwise the standard deviations are an independent learned parameter. Used only for continuous actions.
- Type:
nnx.Sequential
Linear layer mapping shared features to the value estimate.
- Type:
nnx.Linear
Bijector for constraining the action space (continuous only).
- Type:
distrax.Bijector
Whether this model uses discrete actions.
- Type:
bool
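For contrast with the separate-torso variant, a minimal plain-NumPy sketch of the shared-torso data flow in the discrete case (again not the jaxdem implementation; sizes and names are invented for the example):

```python
import numpy as np

rng = np.random.default_rng(1)
obs_dim, n_actions, hidden = 4, 3, (32, 32)  # made-up sizes for the sketch

sizes = (obs_dim,) + hidden
shared = [(rng.normal(scale=0.1, size=(i, o)), np.zeros(o))
          for i, o in zip(sizes[:-1], sizes[1:])]

def torso(x):
    # Single shared pass: both heads read the same features.
    for W, b in shared:
        x = np.tanh(x @ W + b)  # tanh stands in for `activation`
    return x

W_logits = rng.normal(scale=0.01, size=(hidden[-1], n_actions))  # actor head
W_value = rng.normal(scale=1.0, size=(hidden[-1], 1))            # critic head

obs = rng.normal(size=(5, obs_dim))
feats = torso(obs)                 # computed once, shared by both heads
logits = feats @ W_logits          # categorical logits (discrete=True case)
probs = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
value = feats @ W_value            # scalar value estimate per observation
```

The shared torso amortizes feature extraction across both heads, which is cheaper per step; the trade-off is that gradients from the value loss and the policy loss now compete for the same parameters.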