Source code for jaxdem.rl.models.DeepONet

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM

from __future__ import annotations

import jax
import jax.numpy as jnp
from typing import Tuple, Callable, Sequence, Dict, cast
import math
from functools import partial

from flax import nnx
import distrax

from . import Model
from ..actionSpaces import ActionSpace
from ...utils import encode_callable


@Model.register("DeepOnetActorCritic")
class DeepOnetActorCritic(Model, nnx.Module):
    """
    A DeepOnet-based Actor-Critic model with a Dynamic Weighted Combiner.

    Architecture:

    1. **Trunk (T)**: MLP encoding Goal + Velocity (the first 4 observation
       entries).
    2. **Branch (B)**: MLP encoding Lidar/Sensor data (the remaining
       observation entries).
    3. **Weighted Combiner**: a sigmoid gate computed from T scales B
       element-wise; ``[T, gated B]`` is then fed to the combiner MLP.
    4. **Actor/Critic Heads**: both read the combiner output.

    Parameters
    ----------
    observation_space_size : int
        Total size of the observation space. Must be greater than 4.
    action_space_size : int
        Size of the action space.
    key : nnx.Rngs
        Random number generator.
    trunk_architecture : Sequence[int]
        Hidden layers for the trunk (Goal/Vel).
    branch_architecture : Sequence[int]
        Hidden layers for the branch (Lidar features).
    combiner_architecture : Sequence[int]
        Hidden layers for the merging network.
    critic_architecture : Sequence[int]
        Hidden layers for the critic network (after combiner).
    basis_dim : int
        Output size of Trunk and Branch before combination.
    activation : Callable
        Activation applied after every hidden layer and after the trunk and
        branch projections.
    in_scale : float
        Gain of the orthogonal initializer used for the hidden layers and
        for the gating layer.
    actor_scale : float
        Gain of the orthogonal initializer of the actor mean head.
    critic_scale : float
        Gain of the orthogonal initializer of the final critic layer. It is
        also used for the kernel of the state-dependent sigma head.
    actor_sigma_head : bool
        If ``True`` the standard deviation is produced by a linear + softplus
        head on the features; otherwise it is ``exp`` of a learned,
        state-independent log-std parameter.
    action_space : distrax.Bijector | ActionSpace | None
        Bijector applied to the Gaussian policy. Defaults to
        ``ActionSpace.create("Free")``.

    Raises
    ------
    ValueError
        If ``observation_space_size`` leaves no entries for the branch.
    """

    __slots__ = ()

    # Number of leading observation entries routed to the trunk; everything
    # after them goes to the branch. Shared by __init__ and __call__ so the
    # layer sizes and the input split cannot drift apart.
    _TRUNK_INPUT_SIZE = 4

    def __init__(
        self,
        *,
        observation_space_size: int,
        action_space_size: int,
        key: nnx.Rngs,
        trunk_architecture: Sequence[int] = (64, 64),
        branch_architecture: Sequence[int] = (64, 64),
        combiner_architecture: Sequence[int] = (64, 64),
        critic_architecture: Sequence[int] = (64, 64),
        basis_dim: int = 64,
        activation: Callable = nnx.gelu,
        in_scale: float = math.sqrt(2),
        actor_scale: float = 1.0,
        critic_scale: float = 0.01,
        actor_sigma_head: bool = False,
        action_space: distrax.Bijector | ActionSpace | None = None,
    ):
        self.observation_space_size = int(observation_space_size)
        self.action_space_size = int(action_space_size)
        self.trunk_architecture = [int(x) for x in trunk_architecture]
        self.branch_architecture = [int(x) for x in branch_architecture]
        self.combiner_architecture = [int(x) for x in combiner_architecture]
        self.critic_architecture = [int(x) for x in critic_architecture]
        self.basis_dim = int(basis_dim)
        self.activation = activation
        self.in_scale = float(in_scale)
        self.actor_scale = float(actor_scale)
        self.critic_scale = float(critic_scale)
        self.actor_sigma_head = actor_sigma_head

        # Validate the split before building anything so a bad configuration
        # fails without constructing layers or consuming random numbers.
        branch_in_len = self.observation_space_size - self._TRUNK_INPUT_SIZE
        if branch_in_len <= 0:
            raise ValueError("Observation space too small for split (need > 4).")

        # NOTE: layers are created in a fixed order (trunk, branch, gate,
        # combiner, actor heads, critic) because each one draws from `key`.

        # --- 1. Trunk (Goal + Velocity -> 4 inputs) ---
        trunk_layers, trunk_out = self._hidden_stack(
            self._TRUNK_INPUT_SIZE, self.trunk_architecture, key
        )
        # Projection to basis_dim (default nnx.Linear initializers)
        trunk_layers.append(
            nnx.Linear(in_features=trunk_out, out_features=self.basis_dim, rngs=key)
        )
        # pyrefly: ignore [bad-argument-type]
        trunk_layers.append(self.activation)
        self.trunk = nnx.Sequential(*trunk_layers)

        # --- 2. Branch (Lidar -> obs-4 inputs) ---
        branch_layers, branch_out = self._hidden_stack(
            branch_in_len, self.branch_architecture, key
        )
        # Projection to basis_dim (default nnx.Linear initializers)
        branch_layers.append(
            nnx.Linear(in_features=branch_out, out_features=self.basis_dim, rngs=key)
        )
        # pyrefly: ignore [bad-argument-type]
        branch_layers.append(self.activation)
        self.branch_mlp = nnx.Sequential(*branch_layers)

        # --- 3. Dynamic Gating Network (The Weighted Mixer) ---
        # T (Trunk output) is used to generate the scaling factor G
        self.gating_network = nnx.Sequential(
            nnx.Linear(
                in_features=self.basis_dim,
                out_features=self.basis_dim,
                kernel_init=nnx.initializers.orthogonal(self.in_scale),
                bias_init=nnx.initializers.constant(0.0),
                rngs=key,
            ),
            jax.nn.sigmoid,  # Sigmoid ensures the gate values are between 0 and 1
        )

        # --- 4. Combiner MLP (Takes T + Gated B) ---
        # Input is Trunk basis + Gated Branch basis
        combiner_layers, feature_dim = self._hidden_stack(
            self.basis_dim * 2, self.combiner_architecture, key
        )
        self.combiner = nnx.Sequential(*combiner_layers)

        # --- 5. Actor Heads (Attached to Combiner Output) ---
        self.actor_mu = nnx.Linear(
            in_features=feature_dim,
            out_features=self.action_space_size,
            kernel_init=nnx.initializers.orthogonal(self.actor_scale),
            bias_init=nnx.initializers.constant(0.0),
            rngs=key,
        )

        # Both sigma parameterizations are always created; `actor_sigma`
        # selects which one is used based on `actor_sigma_head`.
        self._log_std = nnx.Param(jnp.zeros((1, self.action_space_size)))
        self._actor_sigma = nnx.Sequential(
            nnx.Linear(
                in_features=feature_dim,
                out_features=self.action_space_size,
                kernel_init=nnx.initializers.orthogonal(self.critic_scale),
                bias_init=nnx.initializers.constant(-1.0),
                rngs=key,
            ),
            jax.nn.softplus,
        )

        # --- 6. Critic MLP + Head ---
        critic_layers, critic_out = self._hidden_stack(
            feature_dim, self.critic_architecture, key
        )
        # Final projection to scalar value
        critic_layers.append(
            nnx.Linear(
                in_features=critic_out,
                out_features=1,
                kernel_init=nnx.initializers.orthogonal(self.critic_scale),
                bias_init=nnx.initializers.constant(0.0),
                rngs=key,
            )
        )
        self.critic = nnx.Sequential(*critic_layers)

        # --- Action Space ---
        if action_space is None:
            action_space = ActionSpace.create("Free")

        bij = cast(distrax.Bijector, action_space)
        # A scalar bijector is lifted to act on the whole action vector.
        if getattr(bij, "event_ndims_in", 0) == 0:
            bij = distrax.Block(bij, ndims=1)
        self.bij = bij

    def _hidden_stack(
        self, in_features: int, widths: Sequence[int], key: nnx.Rngs
    ) -> Tuple[list, int]:
        """Build ``Linear -> activation`` pairs for each width in ``widths``.

        Every linear layer uses an orthogonal kernel initializer with gain
        ``in_scale`` and a zero bias.

        Returns
        -------
        Tuple[list, int]
            The list of layers/callables and the output width of the stack
            (``in_features`` unchanged when ``widths`` is empty).
        """
        layers: list = []
        for width in widths:
            layers.append(
                nnx.Linear(
                    in_features=in_features,
                    out_features=width,
                    kernel_init=nnx.initializers.orthogonal(self.in_scale),
                    bias_init=nnx.initializers.constant(0.0),
                    rngs=key,
                )
            )
            # pyrefly: ignore [bad-argument-type]
            layers.append(self.activation)
            in_features = width
        return layers, in_features

    def actor_sigma(self, x: jax.Array) -> jax.Array:
        """Return the standard deviation of the Gaussian policy.

        This is a method rather than a closure stored on the instance: a
        stored closure keeps referencing the instance it was created on, so
        a module rebuilt from its graph definition and state would read the
        parameters of the original object instead of its own.

        Parameters
        ----------
        x : jax.Array
            Combiner features. Only used when ``actor_sigma_head`` is true.

        Returns
        -------
        jax.Array
            Softplus output of the sigma head, or ``exp`` of the learned
            log-std with shape ``(1, action_space_size)``.
        """
        if self.actor_sigma_head:
            return self._actor_sigma(x)
        return jnp.exp(self._log_std.value)

    @property
    def metadata(self) -> Dict:
        """Includes all initialization parameters for model reconstruction."""
        # `__init__` may have wrapped the action space in `distrax.Block`,
        # which does not carry `type_name` / `kws`; read them from the
        # wrapped bijector in that case.
        space = self.bij
        if isinstance(space, distrax.Block):
            space = space.bijector

        return dict(
            observation_space_size=self.observation_space_size,
            action_space_size=self.action_space_size,
            trunk_architecture=self.trunk_architecture,
            branch_architecture=self.branch_architecture,
            combiner_architecture=self.combiner_architecture,
            critic_architecture=self.critic_architecture,
            basis_dim=self.basis_dim,
            activation=encode_callable(self.activation),
            in_scale=self.in_scale,
            actor_scale=self.actor_scale,
            critic_scale=self.critic_scale,
            actor_sigma_head=self.actor_sigma_head,
            # pyrefly: ignore [missing-attribute]
            action_space_type=space.type_name,
            # pyrefly: ignore [missing-attribute]
            action_space_kws=space.kws,
        )

    @partial(jax.named_call, name="DeepOnetActorCritic.__call__")
    def __call__(
        self, x: jax.Array, sequence: bool = False
    ) -> Tuple[distrax.Distribution, jax.Array]:
        """Compute the policy distribution and the value estimate.

        Parameters
        ----------
        x : jax.Array
            Observations; the last axis is split into the first 4 entries
            (trunk) and the remaining entries (branch).
        sequence : bool
            Not used by this model.

        Returns
        -------
        Tuple[distrax.Distribution, jax.Array]
            The transformed Gaussian policy and the critic output, whose
            last axis has size 1.
        """
        # 1. Split Inputs
        split = self._TRUNK_INPUT_SIZE
        trunk_input = x[..., :split]
        branch_input = x[..., split:]  # Lidar/Sensor data

        # 2. Process Trunk
        t_out = self.trunk(trunk_input)  # Goal/Velocity features

        # 3. Process Branch (MLP)
        # Directly pass the flat branch input to the MLP
        b_out = self.branch_mlp(branch_input)

        # 4. Dynamic Weighted Mixer
        # Generate Gating weights G from the Trunk output
        gate_weights = self.gating_network(t_out)

        # Apply Gating: B_gated = B_out * G
        b_gated = b_out * gate_weights

        # Combine the original Trunk features (T_out) with the gated Branch
        # features (B_gated)
        combined = jnp.concatenate([t_out, b_gated], axis=-1)
        features = self.combiner(combined)

        # 5. Actor
        mu = self.actor_mu(features)
        sigma = self.actor_sigma(features)

        pi = distrax.MultivariateNormalDiag(mu, sigma)
        pi = distrax.Transformed(pi, self.bij)

        # 6. Critic (MLP -> Scalar)
        val = self.critic(features)

        # pyrefly: ignore [bad-return]
        return pi, val