jaxdem.rl.models.DeepONet
Classes
DeepOnetActorCritic: A DeepOnet-based Actor-Critic model with a Dynamic Weighted Combiner.
- class jaxdem.rl.models.DeepONet.DeepOnetActorCritic(*args: Any, **kwargs: Any)
Bases: Model, Module
A DeepOnet-based Actor-Critic model with a Dynamic Weighted Combiner.
Architecture:
1. Trunk (T): MLP encoding Goal + Velocity.
2. Branch (B): MLP encoding Lidar/Sensor data.
3. Weighted Combiner: T gates B features dynamically.
4. Actor/Critic Heads.
- Parameters:
observation_space_size (int) – Total size of the observation space.
action_space_size (int) – Size of the action space.
key (nnx.Rngs) – Random number generator.
trunk_architecture (Sequence[int]) – Hidden layers for the trunk (Goal/Vel).
branch_architecture (Sequence[int]) – Hidden layers for the branch (Lidar features).
combiner_architecture (Sequence[int]) – Hidden layers for the merging network.
critic_architecture (Sequence[int]) – Hidden layers for the critic network (after combiner).
basis_dim (int) – Output size of Trunk and Branch before combination.
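To make the data flow through the four architecture stages more concrete, here is a minimal pure-JAX sketch. It is not the jaxdem implementation. The helper names (`init_mlp`, `mlp`, `forward`), the layout of the observation (goal and velocity entries first, lidar readings after), the tanh activations, and the sigmoid gate used for "T gates B" are all assumptions made for illustration. Only the parameter names mirror the list above.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Return a list of (W, b) pairs for an MLP with the given layer sizes."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp(params, x):
    """Apply tanh on hidden layers and keep the final layer linear."""
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b


# Hypothetical sizes: 4 goal/velocity entries followed by 16 lidar rays.
trunk_in, branch_in = 4, 16
observation_space_size = trunk_in + branch_in
action_space_size = 2
basis_dim = 8
trunk_architecture = [32, 32]
branch_architecture = [32, 32]
combiner_architecture = [32]
critic_architecture = [32]

k_t, k_b, k_c, k_a, k_v = jax.random.split(jax.random.PRNGKey(0), 5)
params = {
    "trunk": init_mlp(k_t, [trunk_in, *trunk_architecture, basis_dim]),
    "branch": init_mlp(k_b, [branch_in, *branch_architecture, basis_dim]),
    "combiner": init_mlp(k_c, [basis_dim, *combiner_architecture]),
    "actor": init_mlp(k_a, [combiner_architecture[-1], action_space_size]),
    "critic": init_mlp(k_v, [combiner_architecture[-1], *critic_architecture, 1]),
}


def forward(params, obs):
    # Assumed observation layout: [goal + velocity | lidar].
    goal_vel, lidar = obs[..., :trunk_in], obs[..., trunk_in:]
    t = mlp(params["trunk"], goal_vel)   # 1. trunk features, size basis_dim
    b = mlp(params["branch"], lidar)     # 2. branch features, size basis_dim
    gate = jax.nn.sigmoid(t)             # 3. assumed gating: T scales B elementwise
    merged = jnp.tanh(mlp(params["combiner"], gate * b))
    action_mean = mlp(params["actor"], merged)      # 4. actor head
    value = mlp(params["critic"], merged)[..., 0]   # 4. critic head (scalar value)
    return action_mean, value


obs = jnp.ones((5, observation_space_size))  # batch of 5 dummy observations
action_mean, value = forward(params, obs)
```

In this sketch `basis_dim` is the shared width at which trunk and branch outputs meet, which is why both MLPs end in a layer of that size. The actual class may combine the two streams differently.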