Source code for jaxdem.rl.actionSpaces

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Interface for defining bijectors used to constraint the policy probability distribution.
"""

from __future__ import annotations
from typing import Dict

from ...factory import Factory


[docs] class ActionSpace(Factory): """ Registry/namespace for action-space **constraints** implemented as `distrax.Bijector`s. These bijectors are intended to be wrapped around a base policy distribution (e.g., `MultivariateNormalDiag`) via `distrax.Transformed`, so that sampling and log-probabilities are correctly adjusted using the bijector’s `forward_and_log_det` / `inverse_and_log_det` methods. See Distrax/TFP bijector interface for details on shape semantics and `event_ndims_in/out`. Example ------- To define a custom action space, inherit from :class:`distrax.Bijector` and :class:`ActionSpace` and implement its abstract methods: >>> @ActionSpace.register("myCustomActionSpace") >>> class MyCustomActionSpace(distrax.Bijector, ActionSpace): ... """ __slots__ = () @property def kws(self) -> Dict: return dict()
from .freeSpace import FreeSpace from .boxSpace import BoxSpace from .maxNormSpace import MaxNormSpace __all__ = ["ActionSpace", "FreeSpace", "BoxSpace", "MaxNormSpace"]