jaxdem.rl.models.DeepONet

Classes

DeepOnetActorCritic(*args, **kwargs)

A DeepONet-based actor-critic model with a dynamic weighted combiner.

class jaxdem.rl.models.DeepONet.DeepOnetActorCritic(*args: Any, **kwargs: Any)

Bases: Model, Module

A DeepONet-based actor-critic model with a dynamic weighted combiner.

Architecture:
  1. Trunk (T): MLP encoding the goal and velocity.
  2. Branch (B): MLP encoding the lidar/sensor data.
  3. Weighted combiner: the trunk output T dynamically gates the branch features B.
  4. Actor and critic heads.
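The four architecture stages above can be illustrated with a minimal pure `jax.numpy` sketch. It is not the jaxdem implementation: the observation layout (the first `trunk_in` entries holding goal and velocity, the rest holding lidar readings), the sigmoid gate, the tanh activations, and the hidden widths are assumptions made for illustration, and `init_mlp`, `mlp`, `init_params` and `forward` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Return (W, b) pairs for a dense MLP with layer widths `sizes`."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def mlp(params, x):
    """tanh on hidden layers, linear output layer."""
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def init_params(key, trunk_in, branch_in, action_size, basis_dim=16, hidden=32):
    kt, kb, kc, ka, kv = jax.random.split(key, 5)
    return {
        "trunk": init_mlp(kt, [trunk_in, hidden, basis_dim]),    # 1. goal + velocity
        "branch": init_mlp(kb, [branch_in, hidden, basis_dim]),  # 2. lidar
        "combiner": init_mlp(kc, [basis_dim, hidden]),           # 3. merge network
        "actor": init_mlp(ka, [hidden, action_size]),            # 4a. policy head
        "critic": init_mlp(kv, [hidden, hidden, 1]),             # 4b. value head
    }


def forward(params, obs, trunk_in):
    # Assumed layout (hypothetical): obs = [goal, velocity, lidar...]
    t = mlp(params["trunk"], obs[..., :trunk_in])
    b = mlp(params["branch"], obs[..., trunk_in:])
    # Assumed gate: sigmoid(T) rescales each branch feature elementwise.
    gated = jax.nn.sigmoid(t) * b
    h = jnp.tanh(mlp(params["combiner"], gated))
    action_mean = mlp(params["actor"], h)
    value = mlp(params["critic"], h)[..., 0]
    return action_mean, value


params = init_params(jax.random.PRNGKey(0), trunk_in=4, branch_in=8, action_size=2)
obs = jnp.ones((5, 12))  # batch of 5: 4 goal/velocity entries + 8 lidar entries
action_mean, value = forward(params, obs, trunk_in=4)
```

Because trunk and branch both end at width `basis_dim`, the elementwise gate is shape-compatible, and the weighting changes with every observation through T. A classical DeepONet would instead reduce the two feature vectors with a dot product; the gate keeps all `basis_dim` features for the downstream combiner.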

Parameters:
  • observation_space_size (int) – Total size of the observation space.

  • action_space_size (int) – Size of the action space.

  • key (nnx.Rngs) – Random number generator.

  • trunk_architecture (Sequence[int]) – Hidden-layer widths of the trunk MLP (goal/velocity input).

  • branch_architecture (Sequence[int]) – Hidden-layer widths of the branch MLP (lidar features).

  • combiner_architecture (Sequence[int]) – Hidden-layer widths of the network that merges trunk and branch features.

  • critic_architecture (Sequence[int]) – Hidden-layer widths of the critic network applied after the combiner.

  • basis_dim (int) – Output width shared by the trunk and the branch before they are combined.
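How these architecture arguments could translate into dense-layer shapes can be sketched in plain Python. This is an assumed reading of the parameter list, not jaxdem's code: `trunk_input_size` is a hypothetical argument standing in for however the model splits the observation into goal/velocity and lidar parts, the combiner's last listed width is taken as its output width, and the actor head is assumed to be a single linear layer.

```python
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, int]


def layer_shapes(in_dim: int, hidden: Sequence[int], out_dim: int) -> List[Shape]:
    """(in, out) weight shapes of an MLP with the given hidden widths."""
    sizes = [in_dim, *hidden, out_dim]
    return list(zip(sizes[:-1], sizes[1:]))


def count_params(shapes: List[Shape]) -> int:
    """Weights plus biases for a stack of dense layers."""
    return sum(i * o + o for i, o in shapes)


def model_shapes(
    observation_space_size: int,
    action_space_size: int,
    trunk_input_size: int,  # hypothetical: number of goal + velocity entries
    trunk_architecture: Sequence[int],
    branch_architecture: Sequence[int],
    combiner_architecture: Sequence[int],
    critic_architecture: Sequence[int],
    basis_dim: int,
) -> Dict[str, List[Shape]]:
    branch_input_size = observation_space_size - trunk_input_size  # lidar entries
    # Assumption: the combiner's last listed width is its output width.
    feat = combiner_architecture[-1] if combiner_architecture else basis_dim
    combiner = (
        layer_shapes(basis_dim, combiner_architecture[:-1], feat)
        if combiner_architecture
        else []
    )
    return {
        # Trunk and branch both end at basis_dim so they can be combined elementwise.
        "trunk": layer_shapes(trunk_input_size, trunk_architecture, basis_dim),
        "branch": layer_shapes(branch_input_size, branch_architecture, basis_dim),
        "combiner": combiner,
        "actor": layer_shapes(feat, [], action_space_size),  # assumed single layer
        "critic": layer_shapes(feat, critic_architecture, 1),  # scalar value
    }


shapes = model_shapes(
    observation_space_size=12,
    action_space_size=2,
    trunk_input_size=4,
    trunk_architecture=(32,),
    branch_architecture=(64, 32),
    combiner_architecture=(64, 64),
    critic_architecture=(64,),
    basis_dim=16,
)
```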

property metadata: Dict

Dictionary of all initialization parameters needed to reconstruct the model.
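The reconstruction pattern that `metadata` supports can be sketched with a hypothetical stand-in class. `ToyModel` and its constructor are illustrative only, not jaxdem's API; the sketch assumes `metadata` maps constructor argument names to plain values, so that `cls(**metadata)` rebuilds a model with the same configuration. Trained weights (and, in this toy, the RNG) would have to be handled separately.

```python
import json
from typing import Any, Dict, Sequence


class ToyModel:
    """Hypothetical stand-in that records its constructor arguments."""

    def __init__(
        self,
        observation_space_size: int,
        action_space_size: int,
        trunk_architecture: Sequence[int] = (32,),
        basis_dim: int = 16,
    ):
        self.observation_space_size = observation_space_size
        self.action_space_size = action_space_size
        self.trunk_architecture = tuple(trunk_architecture)  # JSON lists -> tuple
        self.basis_dim = basis_dim

    @property
    def metadata(self) -> Dict[str, Any]:
        # Every constructor argument, keyed by its parameter name.
        return {
            "observation_space_size": self.observation_space_size,
            "action_space_size": self.action_space_size,
            "trunk_architecture": self.trunk_architecture,
            "basis_dim": self.basis_dim,
        }


model = ToyModel(12, 2, trunk_architecture=(64, 64), basis_dim=8)
saved = json.dumps(model.metadata)        # e.g. stored next to a checkpoint
rebuilt = ToyModel(**json.loads(saved))   # same configuration as `model`
```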