jaxdem.rl.actionSpaces.freeSpace

Implementation of the identity bijector for free space.

Classes

FreeSpace(*args, **kwargs)

Identity constraint (no transform).

class jaxdem.rl.actionSpaces.freeSpace.FreeSpace(*args, **kwargs)

Bases: Bijector, ActionSpace

Identity constraint (no transform).

Mapping

\[y = f(x) = x, \qquad x = f^{-1}(y) = y.\]

Jacobian

\[J_f(x) = I,\qquad \log\lvert\det J_f(x)\rvert = 0, \qquad \log\lvert\det J_{f^{-1}}(y)\rvert = 0.\]
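
The Jacobian statement above can be checked numerically. The following is a minimal sketch in plain JAX; the `identity` function is a stand-in for the mapping \(f(x) = x\) and is not the `FreeSpace` class itself.

```python
import jax
import jax.numpy as jnp


def identity(x):
    # Stand-in for the mapping y = f(x) = x (an assumption for illustration,
    # not the jaxdem implementation).
    return x


x = jnp.array([0.5, -1.2, 3.0])

# The Jacobian of the identity on a 3-vector is the 3x3 identity matrix.
J = jax.jacfwd(identity)(x)

# slogdet returns (sign, log|det|); for the identity matrix log|det| is 0.
sign, log_abs_det = jnp.linalg.slogdet(J)
```

Because the Jacobian does not depend on `x`, the same result holds at every input, which is what makes the log-determinant a constant.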
Parameters:
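  • event_ndims_in (int) – number of dimensions of a single event as seen by the bijector input (defaults to 0 for a scalar transform).

  • event_ndims_out (Optional[int]) – standard Distrax/TFP bijector argument: number of dimensions of a single event at the bijector output.

  • is_constant_jacobian (bool) – standard Distrax/TFP bijector flag: whether the Jacobian is the same for every input.

  • is_constant_log_det (bool) – standard Distrax/TFP bijector flag: whether the log-determinant of the Jacobian is the same for every input.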

Note

This bijector is scalar (event_ndims_in = 0). For vector actions, it needs to be wrapped with distrax.Block(bijector, ndims=1). Let the model do that for you!
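
To make the note concrete, here is a rough sketch in plain JAX of what wrapping a scalar bijector in `distrax.Block(bijector, ndims=1)` amounts to: the per-element log-determinants are summed over the last axis, giving one value per action vector. Both helper functions are hypothetical stand-ins and are not part of jaxdem or Distrax.

```python
import jax.numpy as jnp


def scalar_identity_forward_and_log_det(x):
    # Hypothetical stand-in for a scalar (event_ndims_in = 0) identity
    # bijector: it returns one log-det entry per element of x.
    return x, jnp.zeros_like(x)


def blocked_forward_and_log_det(x, ndims=1):
    # Mimics the effect of Block(..., ndims=1): treat the last `ndims` axes
    # as a single event and sum the per-element log-dets over them.
    y, log_det = scalar_identity_forward_and_log_det(x)
    axes = tuple(range(-ndims, 0))
    return y, jnp.sum(log_det, axis=axes)


actions = jnp.ones((4, 3))  # batch of 4 action vectors of dimension 3
y, log_det = blocked_forward_and_log_det(actions)
# y keeps the shape (4, 3); log_det has one entry per action vector: (4,)
```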

property kws: dict[str, Any]

forward_and_log_det(x: Array | ndarray | bool | number) → tuple[Array | ndarray | bool | number, Array]

Computes y = f(x) and log|det J(f)(x)|.

inverse_and_log_det(y: Array | ndarray | bool | number) → tuple[Array | ndarray | bool | number, Array]

Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.
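
As an illustration of how the two methods fit together, the sketch below uses standalone functions with the same names (an assumption for illustration, not the jaxdem methods) to show a round trip and the change-of-variables formula for a log-density. It also assumes the log-determinant has the same shape as the input, which is the usual convention for a scalar bijector.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def forward_and_log_det(x):
    # Identity forward pass: y = x, with log|det J_f(x)| = 0 per element.
    return x, jnp.zeros_like(x)


def inverse_and_log_det(y):
    # Identity inverse pass: x = y, with log|det J_{f^{-1}}(y)| = 0 per element.
    return y, jnp.zeros_like(y)


x = jnp.array([-0.3, 0.0, 1.7])
y, fwd_log_det = forward_and_log_det(x)
x_back, inv_log_det = inverse_and_log_det(y)

# Change of variables for a standard normal base density:
# log p_Y(y) = log p_X(f^{-1}(y)) + log|det J_{f^{-1}}(y)|
log_p_y = norm.logpdf(x_back, loc=0.0, scale=1.0) + inv_log_det
```

Since both log-determinants are zero, the transformed density equals the base density, so an unconstrained action keeps the log-probability assigned by the underlying distribution.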

log_det_expectation(mean: Array, std: Array) → Array

Compute \(\mathbb{E}_{X}[\log|\det J_f(X)|]\) where \(X \sim \mathcal{N}(\text{mean}, \text{diag}(\text{std}^2))\).

Subclasses should override this to enable Transformed.entropy().
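
For context, the standard change-of-variables identity for differential entropy shows why this expectation is the quantity an entropy computation needs. For a bijection \(f\) and \(X \sim \mathcal{N}(\text{mean}, \text{diag}(\text{std}^2))\):

\[H[f(X)] = H[X] + \mathbb{E}_{X}\left[\log\lvert\det J_f(X)\rvert\right], \qquad H[X] = \sum_i \tfrac{1}{2}\log\left(2\pi e\,\text{std}_i^2\right).\]

For the identity mapping the expectation is zero, so \(H[f(X)] = H[X]\). How Transformed.entropy() uses this value internally is not stated here; the identity above is only the general result it would rely on.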

same_as(other: Bijector) → bool

Returns True if this bijector is guaranteed to be the same as other.