jaxdem.rl.models.MLP#

Implementation of reinforcement learning models based on simple MLPs.

Classes

ActorCritic(*args, **kwargs)

An actor-critic model with separate networks for the actor and critic.

SharedActorCritic(*args, **kwargs)

A shared-parameter dense actor-critic model.

class jaxdem.rl.models.MLP.ActorCritic(*args: Any, **kwargs: Any)[source]#

Bases: Model, Module

An actor-critic model with separate networks for the actor and critic.

Unlike SharedActorCritic, this model uses two independent feedforward networks:

  • Actor torso: processes observations into features for the policy.

  • Critic torso: processes observations into features for the value function.

Parameters:
  • observation_space (int) – Shape of the observation space (excluding batch dimension).

  • action_space (ActionSpace) – Shape of the action space (for continuous) or number of discrete actions.

  • key (nnx.Rngs) – Random number generator(s) for parameter initialization.

  • actor_architecture (Sequence[int]) – Sizes of the hidden layers in the actor torso.

  • critic_architecture (Sequence[int]) – Sizes of the hidden layers in the critic torso.

  • in_scale (float) – Scaling factor for orthogonal initialization of hidden layers.

  • actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.

  • critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.

  • activation (Callable) – Activation function applied between hidden layers.

  • bij (distrax.Bijector) – Bijector to constrain the policy probability distribution (continuous only).

  • discrete (bool) – If True, use a categorical distribution for discrete actions. If False (default), use a Gaussian distribution for continuous actions.
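The construction above can be sketched in plain JAX: two independent MLPs with orthogonally initialized weights, one producing policy features and one producing a value estimate. This is a minimal illustration of the documented pattern, not the class's actual implementation; the layer sizes, `in_scale` value, and `tanh` activation are assumptions.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes, scale):
    """Initialize an MLP as a list of (W, b), with orthogonal weight init."""
    orth = jax.nn.initializers.orthogonal(scale)
    params = []
    for din, dout in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append((orth(sub, (din, dout)), jnp.zeros(dout)))
    return params

def forward(params, x, activation=jnp.tanh):
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:   # no activation after the output layer
            x = activation(x)
    return x

key = jax.random.PRNGKey(0)
k_actor, k_critic = jax.random.split(key)
obs_dim, act_dim = 8, 2

# Two independent torsos, mirroring actor_architecture / critic_architecture
actor = init_mlp(k_actor, (obs_dim, 64, 64, act_dim), scale=jnp.sqrt(2.0))
critic = init_mlp(k_critic, (obs_dim, 64, 64, 1), scale=jnp.sqrt(2.0))

obs = jnp.ones((4, obs_dim))   # batch of 4 observations
mu = forward(actor, obs)       # policy means (or logits), shape (4, 2)
value = forward(critic, obs)   # state-value estimates, shape (4, 1)
```

Because the actor and critic share no parameters, a gradient step on the value loss cannot perturb the policy features, at the cost of roughly doubling the parameter count relative to SharedActorCritic.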

actor_torso#

Feedforward network for the actor.

Type:

nnx.Sequential

critic#

Feedforward network for the critic.

Type:

nnx.Sequential

actor_mu#

Linear layer mapping actor_torso’s features to the policy distribution means (continuous) or logits (discrete).

Type:

nnx.Linear

actor_sigma#

Linear layer mapping actor-torso features to the policy distribution standard deviations when actor_sigma_head is true; otherwise an independent learned parameter. Used only for continuous actions.

Type:

nnx.Sequential
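The two sigma variants described above can be sketched as follows. Option A computes a state-dependent standard deviation from the torso features; option B keeps one learned log-std per action dimension, shared across all states. The shapes and the 0.01 weight scale are illustrative assumptions, not values taken from the library.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
act_dim, feat_dim = 2, 64
features = jnp.ones((4, feat_dim))   # actor-torso output for a batch of 4

# Option A: state-dependent sigma head (a linear layer on the features)
w = jax.random.normal(key, (feat_dim, act_dim)) * 0.01
log_sigma_head = features @ w        # shape (4, 2), varies with the state

# Option B: state-independent parameter (one log-std per action dimension)
log_sigma_param = jnp.zeros(act_dim)  # shape (2,), broadcast across the batch

# Parameterizing the log keeps sigma strictly positive after exponentiation
sigma_a = jnp.exp(log_sigma_head)
sigma_b = jnp.exp(log_sigma_param)
```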

bij#

Bijector for constraining the action space (continuous only).

Type:

distrax.Bijector
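A common choice of bijector for bounded continuous actions is a tanh squash, which maps an unconstrained Gaussian sample into (-1, 1). The sketch below shows the idea in plain JAX, including the change-of-variables log-determinant a bijector contributes to the log-probability; whether jaxdem uses tanh specifically is an assumption.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mu, sigma = jnp.zeros(2), jnp.ones(2)
eps = jax.random.normal(key, (4, 2))
raw_action = mu + sigma * eps        # unconstrained Gaussian sample

# Forward pass of a tanh bijector: squash into (-1, 1)
action = jnp.tanh(raw_action)

# Change of variables: log p(a) = log p(raw) - sum log |d tanh / d raw|,
# where d tanh(x)/dx = 1 - tanh(x)^2
log_det = jnp.sum(jnp.log1p(-jnp.tanh(raw_action) ** 2), axis=-1)
```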

discrete#

Whether this model uses discrete actions.

Type:

bool

property metadata: dict[str, Any][source]#
class jaxdem.rl.models.MLP.SharedActorCritic(*args: Any, **kwargs: Any)[source]#

Bases: Model

A shared-parameter dense actor-critic model.

This model uses a common feedforward network (the “shared torso”) to process observations, then branches into two separate linear heads:

  • Actor head: outputs the mean of a Gaussian action distribution (continuous) or logits for a categorical distribution (discrete).

  • Critic head: outputs a scalar value estimate of the state.

Parameters:
  • observation_space (int) – Shape of the observation space (excluding batch dimension).

  • action_space (ActionSpace) – Shape of the action space (for continuous) or number of discrete actions.

  • key (nnx.Rngs) – Random number generator(s) for parameter initialization.

  • architecture (Sequence[int]) – Sizes of the hidden layers in the shared network.

  • in_scale (float) – Scaling factor for orthogonal initialization of the shared network layers.

  • actor_scale (float) – Scaling factor for orthogonal initialization of the actor head.

  • critic_scale (float) – Scaling factor for orthogonal initialization of the critic head.

  • activation (Callable) – JIT-compatible activation function applied between hidden layers.

  • bij (distrax.Bijector) – Bijector to constrain the policy probability distribution (continuous only).

  • discrete (bool) – If True, use a categorical distribution for discrete actions. If False (default), use a Gaussian distribution for continuous actions.
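The shared-torso pattern can be sketched in plain JAX: one hidden network computes features once, and two small linear heads branch off it. A single hidden layer and the head scales below (small `actor_scale`, unit `critic_scale`) are illustrative assumptions in the spirit of the parameters documented above.

```python
import jax
import jax.numpy as jnp

def orth(key, shape, scale):
    """Orthogonal weight initialization, as used for all layers here."""
    return jax.nn.initializers.orthogonal(scale)(key, shape)

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
obs_dim, hidden, act_dim = 8, 64, 2

# Shared torso (one hidden layer for brevity; `architecture` would stack more)
w_torso = orth(k1, (obs_dim, hidden), scale=jnp.sqrt(2.0))
# Two heads branching off the shared features
w_actor = orth(k2, (hidden, act_dim), scale=0.01)  # small actor_scale
w_critic = orth(k3, (hidden, 1), scale=1.0)        # critic_scale

obs = jnp.ones((4, obs_dim))
features = jnp.tanh(obs @ w_torso)   # computed once, used by both heads
mu = features @ w_actor              # policy means or logits, shape (4, 2)
value = features @ w_critic          # state values, shape (4, 1)
```

Sharing the torso halves the feature-extraction cost and lets the value loss shape the policy's representation, but it also couples the two objectives, which is exactly the trade-off ActorCritic avoids with its separate networks.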

network#

The shared feedforward network (torso).

Type:

nnx.Sequential

actor_mu#

Linear layer mapping shared features to the policy distribution means (continuous) or logits (discrete).

Type:

nnx.Linear

actor_sigma#

Linear layer mapping shared features to the policy distribution standard deviations when actor_sigma_head is true; otherwise an independent learned parameter. Used only for continuous actions.

Type:

nnx.Sequential

critic#

Linear layer mapping shared features to the value estimate.

Type:

nnx.Linear

bij#

Bijector for constraining the action space (continuous only).

Type:

distrax.Bijector

discrete#

Whether this model uses discrete actions.

Type:

bool

property metadata: dict[str, Any][source]#