jaxdem.rl.actionSpaces

Interface for defining bijectors used to constrain the policy probability distribution.

Classes

ActionSpace()

Registry/namespace for action-space constraints implemented as distrax.Bijector objects.

Transformed(*args, **kwargs)

`distrax.Transformed` with analytical entropy support.

class jaxdem.rl.actionSpaces.ActionSpace

Bases: Factory

Registry/namespace for action-space constraints implemented as distrax.Bijector objects.

These bijectors are intended to wrap a base policy distribution (e.g., MultivariateNormalDiag) via distrax.Transformed, so that sampling and log-probabilities are corrected by the change-of-variables formula through the bijector's forward_and_log_det and inverse_and_log_det methods. See the Distrax/TFP bijector interface for details on shape semantics and on event_ndims_in / event_ndims_out.
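The change-of-variables adjustment that distrax.Transformed applies can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not jaxdem or Distrax code: the bijector is a bare tanh, the base is a scalar Gaussian, and the function names (tanh_forward_and_log_det, tanh_inverse_and_log_det, sample_and_log_prob, log_prob) are hypothetical. Sampling pushes x forward and subtracts the forward log-determinant; evaluating a given action pulls it back through the inverse and adds the inverse log-determinant. Both routes give the same log-density.

```python
import jax
import jax.numpy as jnp


def log_normal_pdf(x, mean, std):
    # Log-density of N(mean, std^2), evaluated elementwise.
    return -0.5 * ((x - mean) / std) ** 2 - jnp.log(std) - 0.5 * jnp.log(2.0 * jnp.pi)


def tanh_forward_and_log_det(x):
    # Hypothetical scalar bijector y = tanh(x); log|dy/dx| = log(sech^2 x), computed stably.
    y = jnp.tanh(x)
    log_det = 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))
    return y, log_det


def tanh_inverse_and_log_det(y):
    # x = arctanh(y); log|dx/dy| is minus the forward log-det evaluated at x.
    x = jnp.arctanh(y)
    _, fwd_log_det = tanh_forward_and_log_det(x)
    return x, -fwd_log_det


def sample_and_log_prob(key, mean, std, shape):
    # Push base samples forward: log p_Y(y) = log p_X(x) - log|dy/dx|.
    x = mean + std * jax.random.normal(key, shape)
    y, log_det = tanh_forward_and_log_det(x)
    return y, log_normal_pdf(x, mean, std) - log_det


def log_prob(y, mean, std):
    # Pull y back: log p_Y(y) = log p_X(f^{-1}(y)) + log|d f^{-1}/dy|.
    x, inv_log_det = tanh_inverse_and_log_det(y)
    return log_normal_pdf(x, mean, std) + inv_log_det
```

Distrax performs this bookkeeping internally; the sketch only makes explicit why each action-space bijector has to supply both forward_and_log_det and inverse_and_log_det.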

Example:

To define a custom action space, inherit from distrax.Bijector and ActionSpace and implement their abstract methods:

>>> import distrax
>>> from jaxdem.rl.actionSpaces import ActionSpace
>>> @ActionSpace.register("myCustomActionSpace")
... class MyCustomActionSpace(distrax.Bijector, ActionSpace):
...     ...  # implement forward_and_log_det, inverse_and_log_det, log_det_expectation

property kws: dict[str, Any]
log_det_expectation(mean: Array, std: Array) → Array

Compute \(\mathbb{E}_{X}[\log|\det J_f(X)|]\) where \(X \sim \mathcal{N}(\text{mean}, \text{diag}(\text{std}^2))\).

Subclasses should override this to enable Transformed.entropy().

class jaxdem.rl.actionSpaces.BoxSpace(*args, **kwargs)

Bases: Bijector, ActionSpace

Elementwise box constraint implemented with a scaled tanh.

Mapping (componentwise)

\[y_i \;=\; c_i + h_i\,\tanh\!\left(\frac{x_i}{w}\right), \qquad c_i=\tfrac{1}{2}(x_{\min,i}+x_{\max,i}), \quad h_i=\tfrac{1-\varepsilon}{2}(x_{\max,i}-x_{\min,i}),\]

with width parameter \(w>0\) and a small \(\varepsilon>0\) that keeps the output strictly inside the bounds for numerical safety.

Jacobian (componentwise)

For each component,

\[\frac{\partial y_i}{\partial x_i} = \frac{h_i}{w}\,\operatorname{sech}^2\!\left(\frac{x_i}{w}\right), \qquad \log\left| \frac{\partial y_i}{\partial x_i} \right| = \log h_i - \log w + \log\!\big(\operatorname{sech}^2(x_i/w)\big).\]

The last term is evaluated with the numerically stable identity \(\log(\operatorname{sech}^2 z)=2\,[\log 2 - z - \operatorname{softplus}(-2z)]\).
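A minimal JAX sketch of the componentwise box map and its log-Jacobian, assuming the formulas above. The helper names (log_sech2, box_params, box_forward_and_log_det, box_inverse) and the defaults width=1.0, eps=1e-6 are illustrative assumptions, not the BoxSpace constructor or its defaults. The diagonal Jacobian from jax.jacfwd checks the closed form, and summing the per-component terms reproduces the full log-determinant of a vector action, which is what wrapping in distrax.Block(bijector, ndims=1) is expected to report.

```python
import jax
import jax.numpy as jnp


def log_sech2(z):
    # Stable log(sech^2 z) = 2 * (log 2 - z - softplus(-2 z)); finite even for large |z|.
    return 2.0 * (jnp.log(2.0) - z - jax.nn.softplus(-2.0 * z))


def box_params(x_min, x_max, eps):
    # Center c_i and shrunken half-range h_i = (1 - eps) / 2 * (x_max - x_min).
    c = 0.5 * (x_min + x_max)
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    return c, h


def box_forward_and_log_det(x, x_min, x_max, width=1.0, eps=1e-6):
    # y_i = c_i + h_i tanh(x_i / w); log|dy_i/dx_i| = log h_i - log w + log sech^2(x_i / w).
    c, h = box_params(x_min, x_max, eps)
    z = x / width
    y = c + h * jnp.tanh(z)
    log_det = jnp.log(h) - jnp.log(width) + log_sech2(z)
    return y, log_det


def box_inverse(y, x_min, x_max, width=1.0, eps=1e-6):
    # x_i = w * arctanh((y_i - c_i) / h_i); finite for any y produced by the forward map.
    c, h = box_params(x_min, x_max, eps)
    return width * jnp.arctanh((y - c) / h)
```

The stable identity matters in float32: evaluating jnp.log(1 / jnp.cosh(z) ** 2) directly overflows cosh(z)**2 for z around 60 and returns -inf, whereas log_sech2 stays finite.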

Parameters:
  • x_min (jax.Array) – Elementwise lower bounds of the distribution.

  • x_max (jax.Array) – Elementwise upper bounds of the distribution. Must satisfy x_max > x_min elementwise.

  • width (float) – Width \(w\) of the tanh; larger values flatten the map (the slope at the center is \(h_i/w\)).

  • eps (float) – Small margin \(\varepsilon\) that shrinks the half-range by the factor \((1-\varepsilon)\), keeping outputs strictly inside the bounds so that the arctanh in the inverse does not diverge.

  • event_ndims_in (int) – Dimensionality of a single event seen by the bijector (defaults to 0 for a scalar transform).

  • event_ndims_out (Optional[int]) – Standard Distrax/TFP bijector flag.

  • is_constant_jacobian (bool) – Standard Distrax/TFP bijector flag.

  • is_constant_log_det (bool) – Standard Distrax/TFP bijector flag.

Note

This bijector is scalar (event_ndims_in = 0). For vector actions it must be wrapped as distrax.Block(bijector, ndims=1), which sums the per-component log-determinants over the last axis. The model does this for you.

property kws: dict[str, Any]
static sec2_log(x: Array) → Array
forward_log_det_jacobian(x: Array | ndarray | bool | number) → Array

Computes log|det J(f)(x)| elementwise: log|dy/dx| = log h - log w + log(sech^2(x/w)), where h is the (shrunken) half-range and w the width. The term log(sech^2 z) is evaluated stably as 2*(log(2) - z - softplus(-2z)).

forward_and_log_det(x: Array | ndarray | bool | number) → tuple[Array, Array]

Computes y = f(x) and log|det J(f)(x)|.

inverse_and_log_det(y: Array | ndarray | bool | number) → tuple[Array, Array]

Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.

same_as(other: Bijector) → bool

Returns True if this bijector is guaranteed to be the same as other.

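log_det_expectation(mean: Array, std: Array) → Array

Computes \(\mathbb{E}_X[\sum_i \log|\partial y_i/\partial x_i|]\) with 1-D Gauss–Hermite quadrature applied to each component separately, which is valid because the log-Jacobian is a sum of per-component terms.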
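The quadrature can be sketched as follows, assuming n-point Gauss–Hermite nodes from NumPy's np.polynomial.hermite.hermgauss (physicists' convention, weight \(e^{-t^2}\)) and the substitution x = mean + √2·std·t. The function names and the default n=32 are hypothetical choices for illustration. Because the box log-Jacobian is separable, only d independent 1-D quadratures are needed.

```python
import numpy as np
import jax
import jax.numpy as jnp


def log_sech2(z):
    # Stable log(sech^2 z) = 2 * (log 2 - z - softplus(-2 z)).
    return 2.0 * (jnp.log(2.0) - z - jax.nn.softplus(-2.0 * z))


def gauss_hermite_expectation(g, mean, std, n=32):
    # E[g(X)] for X ~ N(mean, std^2), elementwise, with n-point Gauss-Hermite quadrature.
    # hermgauss returns nodes t_k and weights w_k for the weight exp(-t^2), so
    # E[g(X)] ~= (1 / sqrt(pi)) * sum_k w_k * g(mean + sqrt(2) * std * t_k).
    t, w = np.polynomial.hermite.hermgauss(n)
    mean = jnp.asarray(mean)[..., None]  # trailing axis indexes the quadrature nodes
    std = jnp.asarray(std)[..., None]
    x = mean + jnp.sqrt(2.0) * std * jnp.asarray(t)
    return jnp.sum(jnp.asarray(w) * g(x), axis=-1) / jnp.sqrt(jnp.pi)


def box_log_det_expectation(mean, std, x_min, x_max, width=1.0, eps=1e-6, n=32):
    # E[sum_i log|dy_i/dx_i|]: constant part log h_i - log w plus one 1-D quadrature per component.
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    per_component = jnp.log(h) - jnp.log(width) + gauss_hermite_expectation(
        lambda x: log_sech2(x / width), mean, std, n
    )
    return jnp.sum(per_component, axis=-1)
```

An n-point rule integrates polynomials of degree up to 2n-1 exactly, which gives a convenient sanity check: for instance, E[X²] = mean² + std².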

class jaxdem.rl.actionSpaces.FreeSpace(*args, **kwargs)

Bases: Bijector, ActionSpace

Identity constraint (no transform).

Mapping

\[y = f(x) = x, \qquad x = f^{-1}(y) = y.\]

Jacobian

\[J_f(x) = I,\qquad \log\lvert\det J_f(x)\rvert = 0, \qquad \log\lvert\det J_{f^{-1}}(y)\rvert = 0.\]

Parameters:
  • event_ndims_in (int) – Dimensionality of a single event seen by the bijector (defaults to 0 for a scalar transform).

  • event_ndims_out (Optional[int]) – Standard Distrax/TFP bijector flag.

  • is_constant_jacobian (bool) – Standard Distrax/TFP bijector flag.

  • is_constant_log_det (bool) – Standard Distrax/TFP bijector flag.

Note

This bijector is scalar (event_ndims_in = 0). For vector actions it must be wrapped as distrax.Block(bijector, ndims=1), which sums the per-component log-determinants over the last axis. The model does this for you.

property kws: dict[str, Any]
forward_and_log_det(x: Array | ndarray | bool | number) → tuple[Array | ndarray | bool | number, Array]

Computes y = f(x) and log|det J(f)(x)|.

inverse_and_log_det(y: Array | ndarray | bool | number) → tuple[Array | ndarray | bool | number, Array]

Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.

log_det_expectation(mean: Array, std: Array) → Array

Compute \(\mathbb{E}_{X}[\log|\det J_f(X)|]\) where \(X \sim \mathcal{N}(\text{mean}, \text{diag}(\text{std}^2))\). For the identity map \(\log|\det J_f| = 0\) everywhere, so this expectation is zero.

same_as(other: Bijector) → bool

Returns True if this bijector is guaranteed to be the same as other.

class jaxdem.rl.actionSpaces.MaxNormSpace(*args, **kwargs)

Bases: Bijector, ActionSpace

Radial max-norm constraint for vector actions: scales the radius with a tanh squashing while preserving direction.

Mapping (vector case, \(x \in \mathbb{R}^d\))

\[r = \lVert x \rVert_2,\qquad \hat{u} = \begin{cases} x/r, & r>0,\\ 0, & r=0, \end{cases} \qquad y = s \tanh(r)\,\hat{u}, \qquad s = (1-\varepsilon)\,\texttt{max\_norm}.\]

Equivalently, \(y = b(r)\,x\) with \(b(r)= s\,\tanh(r)/r\) for \(r>0\).

Jacobian determinant

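For an isotropic radial map \(f(x)=b(r)\,x\) with \(x \in \mathbb{R}^d\), the Jacobian is \(J_f(x) = b(r)\,I + b'(r)\,x x^{\top}/r\). Its eigenvalues are \(b\) (multiplicity \(d-1\)) on the tangent subspace and \(b + r\,b'(r)\) in the radial direction, hence

\[\bigl|\det J_f(x)\bigr| = b(r)^{\,d-1}\,\bigl(b(r)+r\,b'(r)\bigr) = s^d \left(\frac{\tanh r}{r}\right)^{\!d-1} \operatorname{sech}^2 r,\]

using \(b + r\,b' = \frac{\mathrm{d}}{\mathrm{d}r}\bigl(r\,b(r)\bigr) = s\,\operatorname{sech}^2 r\).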

Therefore

\[\log\lvert\det J_f(x)\rvert = d\log s + (d-1)\bigl(\log\tanh r - \log r\bigr) + \log(\operatorname{sech}^2 r).\]

The last term is evaluated with the numerically stable identity \(\log(\operatorname{sech}^2 z)=2\,[\log 2 - z - \operatorname{softplus}(-2z)]\).

Near \(r\approx 0\), the second-order expansion (from \(\log(\tanh r/r)\approx -r^2/3\) and \(\log\operatorname{sech}^2 r \approx -r^2\))

\[\log\lvert\det J_f(x)\rvert \approx d\log s - \tfrac{d+2}{3}\, r^2\]

is used instead, to avoid division by \(r\).
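The closed form can be checked numerically against jnp.linalg.slogdet of the Jacobian obtained with jax.jacfwd. This is a sketch assuming the mapping above with max_norm=1.0 and eps=1e-6 as illustrative defaults (not necessarily MaxNormSpace's); the function names are hypothetical. Expanding the closed form to second order, with \(\log(\tanh r/r)\approx -r^2/3\) and \(\log\operatorname{sech}^2 r\approx -r^2\), gives \(d\log s - \tfrac{d+2}{3}r^2\), which the last helper implements.

```python
import jax
import jax.numpy as jnp


def maxnorm_forward(x, max_norm=1.0, eps=1e-6):
    # y = s * tanh(r) * x / r with r = ||x||_2 and s = (1 - eps) * max_norm (assumes r > 0).
    s = (1.0 - eps) * max_norm
    r = jnp.linalg.norm(x)
    return s * jnp.tanh(r) * x / r


def maxnorm_log_det(x, max_norm=1.0, eps=1e-6):
    # Closed form: d log s + (d - 1)(log tanh r - log r) + log sech^2 r.
    d = x.shape[-1]
    s = (1.0 - eps) * max_norm
    r = jnp.linalg.norm(x)
    log_sech2 = 2.0 * (jnp.log(2.0) - r - jax.nn.softplus(-2.0 * r))
    return d * jnp.log(s) + (d - 1) * (jnp.log(jnp.tanh(r)) - jnp.log(r)) + log_sech2


def maxnorm_log_det_small_r(x, max_norm=1.0, eps=1e-6):
    # Second-order expansion near r = 0: d log s - (d + 2) / 3 * r^2 (no division by r).
    d = x.shape[-1]
    s = (1.0 - eps) * max_norm
    return d * jnp.log(s) - (d + 2) / 3.0 * jnp.sum(x ** 2)
```

The Jacobian check exercises the eigenvalue argument directly: autodiff knows nothing about the tangent/radial decomposition, yet its log-determinant matches the closed form.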

Parameters:
  • max_norm (float) – Target radius after squashing (default 1.0). The map actually uses \(s=(1-\varepsilon)\,\texttt{max\_norm}\) to avoid reaching the exact boundary.

  • eps (float) – Numerical safety margin used near \(r=0\) and \(r\to\infty\).

  • event_ndims_in (int) – Dimensionality of a single event seen by the bijector; this bijector operates on whole vectors, so it is 1 (see the note below).

  • event_ndims_out (Optional[int]) – Standard Distrax/TFP bijector flag.

  • is_constant_jacobian (bool) – Standard Distrax/TFP bijector flag.

  • is_constant_log_det (bool) – Standard Distrax/TFP bijector flag.

Note

This bijector is vector-valued with event_ndims_in = 1 (i.e., it operates on length-\(d\) action vectors as a single event). Do not wrap it in distrax.Block unless you intend to apply it independently to multiple last-axis blocks.

property kws: dict[str, Any]
static sec2_log(r: Array) → Array
forward_log_det_jacobian(x: Array | ndarray | bool | number) → Array

Computes log|det J(f)(x)|.

forward_and_log_det(x: Array | ndarray | bool | number) → tuple[Array, Array]

Computes y = f(x) and log|det J(f)(x)|.

inverse_and_log_det(y: Array | ndarray | bool | number) → tuple[Array, Array]

Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.

same_as(other: Bijector) → bool

Returns True if this bijector is guaranteed to be the same as other.

log_det_expectation(mean: Array, std: Array) → Array

Computes \(\mathbb{E}_X[\log|\det J_f(X)|]\) using tensor-product Gauss–Hermite quadrature in \(d\) dimensions (\(n^d\) nodes for \(n\) nodes per axis).
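A tensor-product version of the quadrature, sketched under the same assumptions as the 1-D case: the 1-D nodes and weights are combined over all \(n^d\) grid points, and the integrand is the closed-form log-determinant above with a hypothetical small-r fallback (threshold r_small=1e-3) so that log r is never evaluated at r = 0. Names, the threshold, and n=16 are illustrative choices, not MaxNormSpace's actual settings.

```python
import numpy as np
import jax
import jax.numpy as jnp


def maxnorm_log_det(x, max_norm=1.0, eps=1e-6, r_small=1e-3):
    # Batched over leading axes; uses d log s - (d + 2) / 3 * r^2 when r < r_small.
    d = x.shape[-1]
    s = (1.0 - eps) * max_norm
    r = jnp.linalg.norm(x, axis=-1)
    small = r < r_small
    r_safe = jnp.where(small, 1.0, r)  # keeps log(r) finite on the branch that is discarded
    log_sech2 = 2.0 * (jnp.log(2.0) - r_safe - jax.nn.softplus(-2.0 * r_safe))
    exact = d * jnp.log(s) + (d - 1) * (jnp.log(jnp.tanh(r_safe)) - jnp.log(r_safe)) + log_sech2
    approx = d * jnp.log(s) - (d + 2) / 3.0 * r ** 2
    return jnp.where(small, approx, exact)


def tensor_gauss_hermite_expectation(g, mean, std, n=16):
    # E[g(X)] for X ~ N(mean, diag(std^2)) in d dimensions with an n^d-point product grid.
    d = mean.shape[-1]
    t, w = np.polynomial.hermite.hermgauss(n)
    nodes = np.stack(np.meshgrid(*([t] * d), indexing="ij"), axis=-1).reshape(-1, d)
    weight_grid = np.stack(np.meshgrid(*([w] * d), indexing="ij"), axis=-1).reshape(-1, d)
    weights = np.prod(weight_grid, axis=-1)  # product of 1-D weights, shape (n^d,)
    x = mean + jnp.sqrt(2.0) * std * jnp.asarray(nodes)  # shape (n^d, d)
    return jnp.sum(jnp.asarray(weights) * g(x)) / jnp.pi ** (d / 2)
```

Because the grid has \(n^d\) points, the cost is small for low-dimensional actions but grows exponentially with \(d\).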

class jaxdem.rl.actionSpaces.Transformed(*args, **kwargs)

Bases: Transformed

`distrax.Transformed` with analytical entropy support.

For \(Y = f(X)\) where \(X \sim \text{base}\),

\[H(Y) = H(X) + \mathbb{E}_X[\log|\det J_f(X)|].\]

The expectation is computed by the bijector's log_det_expectation() method via Gauss–Hermite quadrature, which with \(n\) nodes is exact for polynomial integrands of degree up to \(2n-1\) and highly accurate for smooth bijectors such as the scaled tanh.
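For a scalar Gaussian base, the entropy formula can be sketched directly: \(H(X) = \tfrac{1}{2}\log(2\pi e\,\text{std}^2)\), and the log-determinant expectation comes from Gauss–Hermite quadrature. The names below are hypothetical and the snippet ignores batching; it is not the Transformed.entropy implementation. Two quick checks follow from the formula: an affine map f(x) = 3x + 1 shifts the entropy by exactly log 3, and a tanh squashing (with |f'| ≤ 1) can only lower it.

```python
import numpy as np
import jax
import jax.numpy as jnp


def gaussian_entropy(std):
    # Differential entropy of N(mean, std^2) in nats: 0.5 * (1 + log(2 pi std^2)).
    return 0.5 * (1.0 + jnp.log(2.0 * jnp.pi * std ** 2))


def gauss_hermite_expectation(g, mean, std, n=32):
    # E[g(X)] for scalar X ~ N(mean, std^2) via n-point Gauss-Hermite quadrature.
    t, w = np.polynomial.hermite.hermgauss(n)
    x = mean + jnp.sqrt(2.0) * std * jnp.asarray(t)
    return jnp.sum(jnp.asarray(w) * g(x)) / jnp.sqrt(jnp.pi)


def transformed_entropy(log_abs_derivative, mean, std, n=32):
    # H(Y) = H(X) + E_X[log|f'(X)|] for Y = f(X), X ~ N(mean, std^2), f a scalar bijection.
    return gaussian_entropy(std) + gauss_hermite_expectation(log_abs_derivative, mean, std, n)


def tanh_log_abs_derivative(x):
    # log|d tanh(x)/dx| = log sech^2 x, via the stable softplus identity.
    return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))
```

Unlike a sample-based estimate, this value is deterministic and differentiable in mean and std, which is convenient when entropy enters a policy-gradient loss.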

entropy(input_hint: Array | ndarray | bool | number | None = None) → Array

Calculates the Shannon entropy (in nats).

Only works for bijectors with constant Jacobian determinant.

Parameters:

input_hint – an example sample from the base distribution, used to compute the constant forward log-determinant. If not specified, it is computed using a zero array of the shape and dtype of a sample from the base distribution.

Returns:

the entropy of the distribution.

Raises:

NotImplementedError – if bijector’s Jacobian determinant is not known to be constant.

Modules

boxSpace

Implementation of bijector for box space.

freeSpace

Implementation of identity bijector for free space.

maxNormSpace

Implementation of the bijector for the max-norm space.