
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Implementation of reinforcement learning models based on simple MLPs.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp

from typing import Tuple, Callable, Sequence, Dict, cast
import math

from flax import nnx
import distrax

from . import Model
from ..actionSpaces import ActionSpace
from ...utils import encode_callable


@Model.register("SharedActorCritic")
class SharedActorCritic(Model):
    """
    A shared-parameter dense actor-critic model.

    This model uses a common feedforward network (the "shared torso") to
    process observations, and then branches into two separate linear heads:

    - **Actor head**: outputs the mean of a Gaussian action distribution.
    - **Critic head**: outputs a scalar value estimate of the state.

    Parameters
    ----------
    observation_space_size : int
        Size of the flattened observation space (excluding the batch dimension).
    action_space_size : int
        Size of the flattened action space.
    key : nnx.Rngs
        Random number generator(s) for parameter initialization.
    architecture : Sequence[int]
        Sizes of the hidden layers in the shared network.
    in_scale : float
        Scaling factor for orthogonal initialization of the shared network layers.
    actor_scale : float
        Scaling factor for orthogonal initialization of the actor head.
    critic_scale : float
        Scaling factor for orthogonal initialization of the critic head.
    activation : Callable
        Activation function applied between hidden layers.
    action_space : ActionSpace
        Bijector used to constrain the policy probability distribution.

    Attributes
    ----------
    network : nnx.Sequential
        The shared feedforward network (torso).
    actor : nnx.Linear
        Linear layer mapping shared features to action means.
    critic : nnx.Linear
        Linear layer mapping shared features to a scalar value estimate.
    log_std : nnx.Param
        Learnable log standard deviation for the Gaussian action distribution.
    bij : distrax.Bijector
        Bijector constraining the action space.
    """

    __slots__ = ()

    def __init__(
        self,
        *,
        observation_space_size: int,
        action_space_size: int,
        key: nnx.Rngs,
        architecture: Sequence[int] = [32, 32],
        in_scale: float = math.sqrt(2),
        actor_scale: float = 1.0,
        critic_scale: float = 0.01,
        activation: Callable = nnx.gelu,
        action_space: ActionSpace | None = None,
    ):
        self.observation_space_size = int(observation_space_size)
        self.action_space_size = int(action_space_size)
        self.architecture = [int(x) for x in architecture]
        in_scale = float(in_scale)
        actor_scale = float(actor_scale)
        critic_scale = float(critic_scale)
        self.activation = activation

        layers = []
        input_dim = self.observation_space_size
        out_dim = action_space_size
        for output_dim in self.architecture:
            layers.append(
                nnx.Linear(
                    in_features=input_dim,
                    out_features=output_dim,
                    kernel_init=nnx.initializers.orthogonal(in_scale),
                    bias_init=nnx.initializers.constant(0.0),
                    rngs=key,
                )
            )
            layers.append(self.activation)
            input_dim = output_dim
        self.network = nnx.Sequential(*layers)

        self.actor = nnx.Linear(
            in_features=input_dim,
            out_features=out_dim,
            kernel_init=nnx.initializers.orthogonal(actor_scale),
            bias_init=nnx.initializers.constant(0.0),
            rngs=key,
        )
        self.critic = nnx.Linear(
            in_features=input_dim,
            out_features=1,
            kernel_init=nnx.initializers.orthogonal(critic_scale),
            bias_init=nnx.initializers.constant(0.0),
            rngs=key,
        )
        self._log_std = nnx.Param(jnp.zeros((1, out_dim)))

        if action_space is None:
            action_space = ActionSpace.create("Free")

        # Wrap scalar bijectors so they act on the full action vector.
        bij = cast(distrax.Bijector, action_space)
        if getattr(bij, "event_ndims_in", 0) == 0:
            bij = distrax.Block(bij, ndims=1)
        self.bij = bij

    @property
    def metadata(self) -> Dict:
        return dict(
            observation_space_size=self.observation_space_size,
            action_space_size=self.action_space_size,
            architecture=self.architecture,
            activation=encode_callable(self.activation),
            action_space_type=self.bij.type_name,
            action_space_kws=self.bij.kws,
        )

    @property
    def log_std(self) -> nnx.Param:
        return self._log_std

    def __call__(
        self, x: jax.Array, sequence: bool = True
    ) -> Tuple[distrax.Distribution, jax.Array]:
        """
        Forward pass of the shared actor-critic model.

        Parameters
        ----------
        x : jax.Array
            Batch of observations with shape ``(batch, observation_space_size)``.
        sequence : bool, optional
            Unused by this model; kept for interface compatibility.

        Returns
        -------
        tuple[distrax.Distribution, jax.Array]
            - A transformed ``distrax.MultivariateNormalDiag`` distribution over actions.
            - A value estimate tensor of shape ``(batch, 1)``.
        """
        x = self.network(x)
        pi = distrax.MultivariateNormalDiag(self.actor(x), jnp.exp(self.log_std.value))
        pi = distrax.Transformed(pi, self.bij)
        return pi, self.critic(x)
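
# Usage sketch (not part of the JaxDEM API): a minimal example of building a
# SharedActorCritic and running a forward pass. The observation/action sizes
# and batch shape below are illustrative assumptions, not values taken from
# JaxDEM itself.
if __name__ == "__main__":
    rngs = nnx.Rngs(0)
    model = SharedActorCritic(
        observation_space_size=8,  # assumed flattened observation size
        action_space_size=2,       # assumed action dimension
        key=rngs,
    )
    obs = jnp.zeros((4, 8))  # dummy batch of 4 observations
    pi, value = model(obs)
    # Sample actions from the (transformed) Gaussian policy.
    actions = pi.sample(seed=jax.random.PRNGKey(0))
    print(actions.shape, value.shape)  # (4, 2) and (4, 1)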
@Model.register("ActorCritic")
class ActorCritic(Model, nnx.Module):
    """
    An actor-critic model with separate networks for the actor and critic.

    Unlike `SharedActorCritic`, this model uses two independent feedforward
    networks:

    - **Actor torso**: processes observations into features for the policy.
    - **Critic torso**: processes observations into features for the value function.

    Parameters
    ----------
    observation_space_size : int
        Size of the flattened observation space (excluding the batch dimension).
    action_space_size : int
        Size of the flattened action space.
    key : nnx.Rngs
        Random number generator(s) for parameter initialization.
    actor_architecture : Sequence[int]
        Sizes of the hidden layers in the actor torso.
    critic_architecture : Sequence[int]
        Sizes of the hidden layers in the critic torso.
    in_scale : float
        Scaling factor for orthogonal initialization of hidden layers.
    actor_scale : float
        Scaling factor for orthogonal initialization of the actor head.
    critic_scale : float
        Scaling factor for orthogonal initialization of the critic head.
    activation : Callable
        Activation function applied between hidden layers.
    action_space : distrax.Bijector | ActionSpace
        Bijector used to constrain the policy probability distribution.

    Attributes
    ----------
    actor_torso : nnx.Sequential
        Feedforward network for the actor.
    critic_torso : nnx.Sequential
        Feedforward network for the critic.
    actor : nnx.Linear
        Linear layer mapping actor features to action means.
    critic : nnx.Linear
        Linear layer mapping critic features to a scalar value estimate.
    log_std : nnx.Param
        Learnable log standard deviation for the Gaussian action distribution.
    bij : distrax.Bijector
        Bijector constraining the action space.
    """

    __slots__ = ()

    def __init__(
        self,
        *,
        observation_space_size: int,
        action_space_size: int,
        key: nnx.Rngs,
        actor_architecture: Sequence[int] = [32, 32],
        critic_architecture: Sequence[int] = [32, 32],
        in_scale: float = math.sqrt(2),
        actor_scale: float = 1.0,
        critic_scale: float = 0.01,
        activation: Callable = nnx.gelu,
        action_space: distrax.Bijector | ActionSpace | None = None,
    ):
        self.observation_space_size = int(observation_space_size)
        self.action_space_size = int(action_space_size)
        self.actor_architecture = [int(x) for x in actor_architecture]
        self.critic_architecture = [int(x) for x in critic_architecture]
        in_scale = float(in_scale)
        actor_scale = float(actor_scale)
        critic_scale = float(critic_scale)
        self.activation = activation

        input_dim = observation_space_size
        out_dim = action_space_size

        # Build actor torso
        actor_layers = []
        actor_in = input_dim
        for output_dim in actor_architecture:
            actor_layers.append(
                nnx.Linear(
                    in_features=actor_in,
                    out_features=output_dim,
                    kernel_init=nnx.initializers.orthogonal(in_scale),
                    bias_init=nnx.initializers.constant(0.0),
                    rngs=key,
                )
            )
            actor_layers.append(activation)
            actor_in = output_dim
        self.actor_torso = nnx.Sequential(*actor_layers)

        # Build critic torso
        critic_layers = []
        critic_in = input_dim
        for output_dim in critic_architecture:
            critic_layers.append(
                nnx.Linear(
                    in_features=critic_in,
                    out_features=output_dim,
                    kernel_init=nnx.initializers.orthogonal(in_scale),
                    bias_init=nnx.initializers.constant(0.0),
                    rngs=key,
                )
            )
            critic_layers.append(activation)
            critic_in = output_dim
        self.critic_torso = nnx.Sequential(*critic_layers)

        # Actor head
        self.actor = nnx.Linear(
            in_features=actor_in,
            out_features=out_dim,
            kernel_init=nnx.initializers.orthogonal(actor_scale),
            bias_init=nnx.initializers.constant(0.0),
            rngs=key,
        )

        # Critic head
        self.critic = nnx.Linear(
            in_features=critic_in,
            out_features=1,
            kernel_init=nnx.initializers.orthogonal(critic_scale),
            bias_init=nnx.initializers.constant(0.0),
            rngs=key,
        )

        # Global log std for Gaussian policy
        self._log_std = nnx.Param(jnp.zeros((1, out_dim)))

        if action_space is None:
            action_space = ActionSpace.create("Free")

        # Wrap scalar bijectors so they act on the full action vector.
        bij = cast(distrax.Bijector, action_space)
        if getattr(bij, "event_ndims_in", 0) == 0:
            bij = distrax.Block(bij, ndims=1)
        self.bij = bij

    @property
    def metadata(self) -> Dict:
        return dict(
            observation_space_size=self.observation_space_size,
            action_space_size=self.action_space_size,
            actor_architecture=self.actor_architecture,
            critic_architecture=self.critic_architecture,
            activation=encode_callable(self.activation),
            action_space_type=self.bij.type_name,
            action_space_kws=self.bij.kws,
        )

    @property
    def log_std(self) -> nnx.Param:
        return self._log_std

    def __call__(
        self, x: jax.Array, sequence: bool = True
    ) -> Tuple[distrax.Distribution, jax.Array]:
        """
        Forward pass of the actor-critic model with separate torsos.

        Parameters
        ----------
        x : jax.Array
            Batch of observations with shape ``(batch, observation_space_size)``.
        sequence : bool, optional
            Unused by this model; kept for interface compatibility.

        Returns
        -------
        tuple[distrax.Distribution, jax.Array]
            - A transformed ``distrax.MultivariateNormalDiag`` distribution over actions.
            - A value estimate tensor of shape ``(batch, 1)``.
        """
        actor_features = self.actor_torso(x)
        critic_features = self.critic_torso(x)
        pi = distrax.MultivariateNormalDiag(
            self.actor(actor_features),
            jnp.exp(self.log_std.value),
        )
        pi = distrax.Transformed(pi, self.bij)
        return pi, self.critic(critic_features)
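
# Usage sketch (not part of the JaxDEM API): the same idea for the
# separate-torso ActorCritic, additionally scoring sampled actions under the
# policy. The sizes and architectures below are illustrative assumptions.
if __name__ == "__main__":
    rngs = nnx.Rngs(1)
    model = ActorCritic(
        observation_space_size=8,     # assumed flattened observation size
        action_space_size=2,          # assumed action dimension
        key=rngs,
        actor_architecture=[64, 64],  # assumed hidden sizes
        critic_architecture=[64, 64],
    )
    obs = jnp.ones((4, 8))  # dummy batch of 4 observations
    pi, value = model(obs)
    # Sample actions and evaluate their log-probabilities in one call.
    actions, log_prob = pi.sample_and_log_prob(seed=jax.random.PRNGKey(1))
    print(actions.shape, log_prob.shape, value.shape)  # (4, 2), (4,), (4, 1)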
__all__ = ["SharedActorCritic", "ActorCritic"]