Source code for jaxdem.rl.models
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Interface for defining reinforcement learning models.
"""
from __future__ import annotations
import jax
from typing import Tuple, Dict
from abc import ABC, abstractmethod
from flax import nnx
import distrax
from ...factory import Factory
[docs]
class Model(Factory, nnx.Module, ABC):
    """
    The base interface for defining reinforcement learning models. Acts as a name space.

    Models map observations to an action distribution and a value estimate.

    Example
    -------
    To define a custom model, inherit from :class:`Model` and implement its
    abstract methods:

    >>> @Model.register("myCustomModel")
    ... class MyCustomModel(Model):
    ...     ...
    """

    __slots__ = ()

    @property
    def log_std(self) -> nnx.Param:
        """
        Default log-standard-deviation parameter.

        Returns
        -------
        nnx.Param
            A newly constructed ``nnx.Param`` wrapping ``0``; a new object is
            built on every access.
        """
        default_log_std = nnx.Param(0)
        return default_log_std

    @property
    def metadata(self) -> Dict:
        """
        Default model metadata.

        Returns
        -------
        Dict
            An empty dictionary; a new one is built on every access.
        """
        return dict()

    def reset(self, shape: Tuple, mask: jax.Array | None = None):
        """
        Reset the persistent LSTM carry.

        The implementation on this base class does nothing and returns
        ``None``.

        Parameters
        ----------
        shape : tuple[int, ...]
            Leading dims for the carry, e.g. (num_envs, num_agents).
        mask : optional bool array
            True where to reset entries. Shape (num_envs)
        """
        return None

    @abstractmethod
    def __call__(
        self, x: jax.Array, sequence: bool = True
    ) -> Tuple[distrax.Distribution, jax.Array]:
        """
        Forward pass of the model.

        Subclasses must implement this method; the base version only raises.

        Parameters
        ----------
        x : jax.Array
            Batch of observations.
        sequence : bool, optional
            Defaults to ``True``. Not interpreted by this base class.
            NOTE(review): presumably tells recurrent models whether ``x``
            carries a time axis -- confirm against the concrete models.

        Returns
        -------
        tuple[Distribution, jax.Array]
            - A `distrax.MultivariateNormalDiag` distribution over actions.
            - A value estimate tensor of shape ``(batch, 1)``.

        Raises
        ------
        NotImplementedError
            Always, when the base implementation is invoked.
        """
        raise NotImplementedError
# NOTE(review): these imports sit below the ``Model`` definition, presumably
# so that ``Model`` already exists when the submodules import it -- confirm
# before moving them to the top of the file.
from .MLP import SharedActorCritic, ActorCritic
from .LSTM import LSTMActorCritic
# Names exported by ``from <this package> import *``.
__all__ = ["Model", "SharedActorCritic", "ActorCritic", "LSTMActorCritic"]