Source code for jaxdem.rl.actionSpaces.freeSpace
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Implementation of identity bijector for free space.
"""
import jax
import jax.numpy as jnp
from typing import Tuple, Optional, Dict
import distrax
from distrax._src.bijectors.bijector import Array
from . import ActionSpace
[docs]
@ActionSpace.register("Free")
class FreeSpace(distrax.Bijector, ActionSpace):
    r"""
    Identity constraint (no transform).

    **Mapping**

    .. math::

        y = f(x) = x, \qquad x = f^{-1}(y) = y.

    **Jacobian**

    .. math::

        J_f(x) = I,\qquad \log\lvert\det J_f(x)\rvert = 0, \qquad
        \log\lvert\det J_{f^{-1}}(y)\rvert = 0.

    Parameters
    ----------
    event_ndims_in : int
        Dimensionality of a *single event* seen by the bijector (defaults to 0
        for a scalar transform).
    event_ndims_out : Optional[int]
        Standard Distrax/TFP bijector flag.
    is_constant_jacobian : bool
        Standard Distrax/TFP bijector flag.
    is_constant_log_det : Optional[bool]
        Standard Distrax/TFP bijector flag.

    Note
    ----
    By default this bijector is **scalar** (``event_ndims_in = 0``). For vector
    actions, it needs to be wrapped with ``distrax.Block(bijector, ndims=1)``.
    Let the model do that for you!
    """

    __slots__ = ()

    def __init__(
        self,
        event_ndims_in: int = 0,
        event_ndims_out: Optional[int] = None,
        is_constant_jacobian: bool = True,
        is_constant_log_det: Optional[bool] = True,
    ):
        """Forward all bijector flags unchanged to ``distrax.Bijector``."""
        super().__init__(
            event_ndims_in=event_ndims_in,
            event_ndims_out=event_ndims_out,
            is_constant_jacobian=is_constant_jacobian,
            is_constant_log_det=is_constant_log_det,
        )

    @property
    def kws(self) -> Dict:
        """Constructor keyword arguments that reproduce this bijector."""
        return dict(
            event_ndims_in=self.event_ndims_in,
            event_ndims_out=self.event_ndims_out,
            is_constant_jacobian=self.is_constant_jacobian,
            is_constant_log_det=self.is_constant_log_det,
        )

    @staticmethod
    def _zero_log_det(z: Array, event_ndims: int) -> jax.Array:
        """Return zeros with the batch shape of ``z`` (trailing event axes dropped).

        Raises
        ------
        ValueError
            If ``z`` has fewer dimensions than ``event_ndims``.
        """
        ndim = jnp.ndim(z)
        if ndim < event_ndims:
            raise ValueError(
                f"FreeSpace expects inputs with at least {event_ndims} "
                f"dimension(s), got an input of rank {ndim}."
            )
        # Slice with ``ndim - k`` rather than ``-k`` so that k == 0 keeps every
        # axis (``shape[:-0]`` would be empty).
        batch_shape = jnp.shape(z)[: ndim - event_ndims]
        # ``zeros_like`` keeps the dtype of ``z``, matching the scalar case.
        return jnp.zeros_like(z, shape=batch_shape)

    def forward_and_log_det(self, x: Array) -> Tuple[Array, jax.Array]:
        """Return ``x`` unchanged and a zero log-determinant.

        The log-determinant has the batch shape of ``x``: ``x.shape`` with the
        trailing ``event_ndims_in`` axes removed. For the default scalar
        bijector this is exactly ``x.shape``.
        """
        return x, self._zero_log_det(x, self.event_ndims_in)

    def inverse_and_log_det(self, y: Array) -> Tuple[Array, jax.Array]:
        """Return ``y`` unchanged and a zero inverse log-determinant.

        The log-determinant has the batch shape of ``y``: ``y.shape`` with the
        trailing ``event_ndims_out`` axes removed.
        """
        return y, self._zero_log_det(y, self.event_ndims_out)

    def same_as(self, other: distrax.Bijector) -> bool:
        """Return True if ``other`` is a ``FreeSpace`` with the same event ranks."""
        # Exact type check (not isinstance): subclasses are not treated as the
        # same bijector. Event ranks are compared because they determine the
        # shape of the log-determinant.
        return (
            type(other) is FreeSpace  # pylint: disable=unidiomatic-typecheck
            and other.event_ndims_in == self.event_ndims_in
            and other.event_ndims_out == self.event_ndims_out
        )
# Public API of this module: only the identity action space is exported.
__all__ = [
    "FreeSpace",
]