jaxdem.rl.models.MLP#

Implementation of reinforcement learning models based on simple MLPs.

Classes

ActorCritic(*args, **kwargs)

An actor-critic model with separate networks for the actor and critic.

SharedActorCritic(*args, **kwargs)

A shared-parameter dense actor-critic model.

class jaxdem.rl.models.MLP.SharedActorCritic(*args: Any, **kwargs: Any)[source]#

Bases: Model

A shared-parameter dense actor-critic model.

This model uses a common feedforward network (the “shared torso”) to process observations, and then branches into two separate linear heads:

  • Actor head: outputs the mean of a Gaussian action distribution.

  • Critic head: outputs a scalar value estimate of the state.

Parameters:
  • observation_space (Sequence[int]) – Shape of the observation space (excluding batch dimension).

  • action_space (ActionSpace) – Shape of the action space.

  • key (nnx.Rngs) – Random number generator(s) for parameter initialization.

  • architecture (Sequence[int]) – Sizes of the hidden layers in the shared network.

  • in_scale (float) – Scaling factor for orthogonal initialization of the shared network layers.

  • actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.

  • critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.

  • activation (Callable) – Activation function applied between hidden layers.

  • action_space – Bijector used to constrain the policy's probability distribution.
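The construction below is a minimal sketch based on the parameter list above: the keyword names mirror the documented parameters, the scale parameters are left at their (assumed) defaults, and env is a hypothetical stand-in for a jaxdem.rl environment exposing an ActionSpace.

    import jax
    from flax import nnx

    from jaxdem.rl.models.MLP import SharedActorCritic

    # `env` is a hypothetical environment instance providing the ActionSpace.
    model = SharedActorCritic(
        observation_space=(8,),         # per-agent observation shape, no batch dimension
        action_space=env.action_space,  # ActionSpace instance (assumed to come from the environment)
        key=nnx.Rngs(0),                # random number generators for parameter initialization
        architecture=(64, 64),          # two hidden layers in the shared torso
        activation=jax.nn.tanh,         # activation between hidden layers
    )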

network#

The shared feedforward network (torso).

Type:

nnx.Sequential

actor#

Linear layer mapping shared features to action means.

Type:

nnx.Linear

critic#

Linear layer mapping shared features to a scalar value estimate.

Type:

nnx.Linear

log_std[source]#

Learnable log standard deviation for the Gaussian action distribution.

Type:

nnx.Param

bij#

Bijector for constraining the action space.

Type:

distrax.Bijector

property metadata: Dict[source]#
property log_std: Param[source]#
classmethod registry_name() → str[source]#
property type_name: str[source]#
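For orientation, the sketch below shows how the pieces documented above (an action mean from the actor head, the learnable log_std, and a bijector) typically combine into a constrained Gaussian policy with distrax. It is illustrative only: the model's actual forward pass is defined by the class, and the Tanh bijector here is an assumption standing in for whatever bijector the action space provides.

    import distrax
    import jax
    import jax.numpy as jnp

    mean = jnp.zeros((4,))                        # would come from the actor head
    log_std = jnp.zeros((4,))                     # the learnable log_std parameter documented above
    base = distrax.MultivariateNormalDiag(mean, jnp.exp(log_std))
    bij = distrax.Block(distrax.Tanh(), ndims=1)  # example bijector; the real one comes from the action space
    policy = distrax.Transformed(base, bij)       # constrained Gaussian action distribution
    action, log_prob = policy.sample_and_log_prob(seed=jax.random.PRNGKey(0))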
class jaxdem.rl.models.MLP.ActorCritic(*args: Any, **kwargs: Any)[source]#

Bases: Model, Module

An actor-critic model with separate networks for the actor and critic.

Unlike SharedActorCritic, this model uses two independent feedforward networks:

  • Actor torso: processes observations into features for the policy.

  • Critic torso: processes observations into features for the value function.

Parameters:
  • observation_space (Sequence[int]) – Shape of the observation space (excluding batch dimension).

  • action_space (ActionSpace) – Shape of the action space.

  • key (nnx.Rngs) – Random number generator(s) for parameter initialization.

  • actor_architecture (Sequence[int]) – Sizes of the hidden layers in the actor torso.

  • critic_architecture (Sequence[int]) – Sizes of the hidden layers in the critic torso.

  • in_scale (float) – Scaling factor for orthogonal initialization of hidden layers.

  • actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.

  • critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.

  • activation (Callable) – Activation function applied between hidden layers.

  • action_space – Bijector used to constrain the policy's probability distribution.
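A minimal construction sketch, analogous to the SharedActorCritic example above; keyword names follow the documented parameters, the scale parameters are left at their (assumed) defaults, and env is again a hypothetical environment providing the ActionSpace.

    import jax
    from flax import nnx

    from jaxdem.rl.models.MLP import ActorCritic

    # `env` is a hypothetical environment instance providing the ActionSpace.
    model = ActorCritic(
        observation_space=(8,),          # per-agent observation shape, no batch dimension
        action_space=env.action_space,   # ActionSpace instance (assumed to come from the environment)
        key=nnx.Rngs(0),                 # random number generators for parameter initialization
        actor_architecture=(64, 64),     # hidden layers of the actor torso
        critic_architecture=(128, 128),  # hidden layers of the critic torso
        activation=jax.nn.relu,          # activation between hidden layers
    )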

actor_torso#

Feedforward network for the actor.

Type:

nnx.Sequential

critic_torso#

Feedforward network for the critic.

Type:

nnx.Sequential

actor#

Linear layer mapping actor features to action means.

Type:

nnx.Linear

critic#

Linear layer mapping critic features to a scalar value estimate.

Type:

nnx.Linear

log_std[source]#

Learnable log standard deviation for the Gaussian action distribution.

Type:

nnx.Param

bij#

Bijector for constraining the action space.

Type:

distrax.Bijector

property metadata: Dict[source]#
property log_std: Param[source]#
classmethod registry_name() → str[source]#
property type_name: str[source]#