jaxdem.rl.actionSpaces.boxSpace

Implementation of a bijector for a box (bounded, continuous) action space.

Classes

BoxSpace(*args, **kwargs)

Elementwise box constraint implemented with a scaled tanh.

class jaxdem.rl.actionSpaces.boxSpace.BoxSpace(*args, **kwargs)

Bases: Bijector, ActionSpace

Elementwise box constraint implemented with a scaled tanh.

Mapping (componentwise)

\[y_i \;=\; c_i + h_i\,\tanh\!\left(\frac{x_i}{w}\right), \qquad c_i=\tfrac{1}{2}(x_{\min,i}+x_{\max,i}), \quad h_i=\tfrac{1-\varepsilon}{2}(x_{\max,i}-x_{\min,i}),\]

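with width parameter \(w>0\) and a small offset \(\varepsilon>0\) that shrinks the half-range \(h_i\) so that outputs stay strictly inside \((x_{\min,i}, x_{\max,i})\), for numerical safety.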
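The mapping can be sketched in plain JAX as follows. ``box_forward`` is a hypothetical helper written for illustration, not the library's implementation, and the bounds, width and eps values are arbitrary assumptions.

```python
import jax.numpy as jnp

def box_forward(x, x_min, x_max, width, eps):
    """Hypothetical helper: y_i = c_i + h_i * tanh(x_i / w), applied elementwise."""
    c = 0.5 * (x_min + x_max)                  # centre c_i of each interval
    h = 0.5 * (1.0 - eps) * (x_max - x_min)    # half-range h_i, shrunk by (1 - eps)
    return c + h * jnp.tanh(x / width)

x_min = jnp.array([-1.0, 0.0])
x_max = jnp.array([1.0, 10.0])

# x = 0 lands on the centre of each interval.
y_centre = box_forward(jnp.zeros(2), x_min, x_max, width=1.0, eps=1e-3)

# Inputs away from zero still land strictly inside (x_min, x_max).
y_far = box_forward(jnp.array([3.0, -3.0]), x_min, x_max, width=1.0, eps=1e-3)
```

Because tanh is bounded by 1 in magnitude, the output can never reach \(c_i \pm h_i\), and \(h_i\) is itself slightly smaller than the true half-range.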

Jacobian (componentwise)

For each component,

\[\frac{\partial y_i}{\partial x_i} = \frac{h_i}{w}\,\operatorname{sech}^2\!\left(\frac{x_i}{w}\right), \qquad \log\left| \frac{\partial y_i}{\partial x_i} \right| = \log h_i - \log w + \log\operatorname{sech}^2\!\left(\frac{x_i}{w}\right).\]

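The last term is evaluated with the stable identity \(\log(\operatorname{sech}^2 z) = 2\,[\log 2 - z - \operatorname{softplus}(-2z)]\) at \(z = x_i/w\); unlike direct evaluation, it does not underflow to \(\log 0\) for large \(|z|\).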
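The benefit of the identity is easy to see numerically. In this sketch, ``log_sech2_naive`` and ``log_sech2_stable`` are hypothetical helpers (not part of jaxdem): both agree for moderate \(z\), but only the stable form stays finite once \(\cosh z\) overflows.

```python
import jax
import jax.numpy as jnp

def log_sech2_naive(z):
    # Direct evaluation: cosh overflows for large |z|, so sech^2 becomes 0 and log gives -inf.
    return jnp.log(1.0 / jnp.cosh(z) ** 2)

def log_sech2_stable(z):
    # log(sech^2 z) = 2 * (log 2 - z - softplus(-2 z)); softplus is evaluated stably by JAX.
    return 2.0 * (jnp.log(2.0) - z - jax.nn.softplus(-2.0 * z))

z_moderate = jnp.array([-3.0, -0.5, 0.0, 0.5, 3.0])
agree = jnp.allclose(log_sech2_stable(z_moderate), log_sech2_naive(z_moderate), atol=1e-5)

z_large = jnp.array(1000.0)
naive_large = log_sech2_naive(z_large)     # -inf
stable_large = log_sech2_stable(z_large)   # finite, approximately 2 * (log 2 - 1000)
```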

Parameters:
  • x_min (jax.Array) – Elementwise lower bounds \(x_{\min}\) of the output (action) range.

  • x_max (jax.Array) – Elementwise upper bounds \(x_{\max}\) of the output (action) range. Must satisfy x_max > x_min elementwise.

  • width (float) – Width parameter \(w>0\) controlling the slope; the slope at \(x_i=0\) is \(h_i/w\), so larger values give a flatter transform.

  • eps (float) – Small offset \(\varepsilon\) that shrinks the half-range by a factor \(1-\varepsilon\), keeping outputs away from the bounds so that the arctanh in the inverse does not diverge.

  • event_ndims_in (int) – Dimensionality of a single event seen by the bijector (defaults to 0 for a scalar transform).

  • event_ndims_out (Optional[int]) – Standard Distrax/TFP bijector flag.

  • is_constant_jacobian (bool) – Standard Distrax/TFP bijector flag.

  • is_constant_log_det (bool) – Standard Distrax/TFP bijector flag.
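The roles of width and eps can be checked numerically. This sketch uses a hypothetical ``box_forward`` helper (not the library's code) with arbitrary assumed bounds: ``jax.grad`` gives the slope \(h/w\) at the origin, and a saturated input shows the margin \(\varepsilon\,(x_{\max}-x_{\min})/2\) that eps leaves at each bound.

```python
import jax
import jax.numpy as jnp

def box_forward(x, x_min, x_max, width, eps):
    # Hypothetical helper: scaled tanh y = c + h * tanh(x / w).
    c = 0.5 * (x_min + x_max)
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    return c + h * jnp.tanh(x / width)

x_min, x_max, eps = -2.0, 2.0, 1e-3   # here h = 0.5 * 0.999 * 4 = 1.998

# width: the slope at x = 0 is h / w, so doubling w halves it.
slope_w1 = jax.grad(box_forward)(0.0, x_min, x_max, 1.0, eps)   # 1.998
slope_w2 = jax.grad(box_forward)(0.0, x_min, x_max, 2.0, eps)   # 0.999

# eps: a saturated tanh still stops eps * (x_max - x_min) / 2 = 0.002 short of each bound.
y_sat = box_forward(jnp.array([-50.0, 50.0]), x_min, x_max, 1.0, eps)
```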

Note

This bijector is scalar (event_ndims_in = 0). For vector actions it must be wrapped with ``distrax.Block(bijector, ndims=1)``; let the model do this wrapping for you rather than doing it manually.
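The reduction that ``distrax.Block(bijector, ndims=1)`` performs can be illustrated without Distrax: the scalar log-dets are summed over the last axis, giving one value per action vector. The helpers below are hypothetical and written in plain JAX with assumed bounds; the check with ``jnp.linalg.slogdet`` confirms that the sum equals the log-determinant of the (diagonal) Jacobian of the vector map.

```python
import jax
import jax.numpy as jnp

def box_forward(x, x_min, x_max, width, eps):
    # Hypothetical helper: elementwise scaled tanh.
    c = 0.5 * (x_min + x_max)
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    return c + h * jnp.tanh(x / width)

def scalar_log_det(x, x_min, x_max, width, eps):
    # Elementwise log|dy/dx| (event_ndims_in = 0), using the stable log(sech^2) identity.
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    z = x / width
    return jnp.log(h) - jnp.log(width) + 2.0 * (jnp.log(2.0) - z - jax.nn.softplus(-2.0 * z))

def block_log_det(x, x_min, x_max, width, eps, ndims=1):
    # Treat the last `ndims` axes as one event and sum the scalar log-dets over them.
    ld = scalar_log_det(x, x_min, x_max, width, eps)
    return jnp.sum(ld, axis=tuple(range(-ndims, 0)))

x_min = jnp.array([-1.0, 0.0, -5.0])
x_max = jnp.array([1.0, 4.0, 5.0])
batch = jnp.array([[0.3, -1.2, 2.0],
                   [0.0, 0.5, -0.1]])          # shape (batch=2, action_dim=3)
ld = block_log_det(batch, x_min, x_max, 1.5, 1e-3)   # shape (2,): one value per action

# Cross-check for the first action vector against the full 3x3 Jacobian.
J = jax.jacfwd(box_forward)(batch[0], x_min, x_max, 1.5, 1e-3)
_, logabsdet = jnp.linalg.slogdet(J)
```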

property kws: Dict
static sec2_log(x)
forward_log_det_jacobian(x: Array | ndarray | bool | number) → Array

Computes log|det J(f)(x)| elementwise: log|dy/dx| = log(h) - log(w) + log(sech^2(x/w)), where h is the half-range and log(sech^2 z) is evaluated stably as 2*(log(2) - z - softplus(-2z)).

forward_and_log_det(x: Array | ndarray | bool | number) → Tuple[Array, Array]

Computes y = f(x) and log|det J(f)(x)|.
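Computing y and the log-determinant together lets both share the scaled input \(z = x/w\). A minimal sketch, assuming the mapping and stable identity given for the class above; ``box_forward_and_log_det`` is a hypothetical helper, not the library's method, and the autodiff cross-check uses moderate inputs where float32 differentiation of tanh is still accurate.

```python
import jax
import jax.numpy as jnp

def box_forward_and_log_det(x, x_min, x_max, width, eps):
    c = 0.5 * (x_min + x_max)
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    z = x / width                                   # shared by both outputs
    y = c + h * jnp.tanh(z)
    log_det = jnp.log(h) - jnp.log(width) + 2.0 * (jnp.log(2.0) - z - jax.nn.softplus(-2.0 * z))
    return y, log_det

x = jnp.array([-1.0, -0.3, 0.0, 0.8])
x_min = jnp.full(4, -2.0)
x_max = jnp.full(4, 3.0)
y, log_det = box_forward_and_log_det(x, x_min, x_max, 0.5, 1e-3)

# Cross-check the analytic log-det against autodiff, one component at a time.
dy_dx = jax.vmap(
    jax.grad(lambda xi, lo, hi: box_forward_and_log_det(xi, lo, hi, 0.5, 1e-3)[0])
)(x, x_min, x_max)
```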

inverse_and_log_det(y: Array | ndarray | bool | number) → Tuple[Array, Array]

Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.
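Inverting the mapping gives \(x_i = w\,\operatorname{artanh}(u_i)\) with \(u_i = (y_i - c_i)/h_i\). Since \(\operatorname{sech}^2 z = 1 - \tanh^2 z\), the inverse log-det is \(\log w - \log h_i - \log(1 - u_i^2)\), the negative of the forward log-det at \(x_i\). The sketch below uses hypothetical helpers and assumed values to check the round trip; it is not the library's implementation.

```python
import jax
import jax.numpy as jnp

def _centre_and_half(x_min, x_max, eps):
    # Hypothetical helper returning c and the eps-shrunk half-range h.
    return 0.5 * (x_min + x_max), 0.5 * (1.0 - eps) * (x_max - x_min)

def box_forward_and_log_det(x, x_min, x_max, width, eps):
    c, h = _centre_and_half(x_min, x_max, eps)
    z = x / width
    y = c + h * jnp.tanh(z)
    log_det = jnp.log(h) - jnp.log(width) + 2.0 * (jnp.log(2.0) - z - jax.nn.softplus(-2.0 * z))
    return y, log_det

def box_inverse_and_log_det(y, x_min, x_max, width, eps):
    c, h = _centre_and_half(x_min, x_max, eps)
    u = (y - c) / h                       # lies in (-1, 1) for y inside the shrunk box
    x = width * jnp.arctanh(u)
    # log|dx/dy| = log w - log h - log(1 - u^2), i.e. minus the forward log-det at x
    log_det = jnp.log(width) - jnp.log(h) - jnp.log1p(-u ** 2)
    return x, log_det

x_min = jnp.array([-1.0, 0.0, -3.0])
x_max = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([-1.0, 0.0, 0.7])
y, fwd_ld = box_forward_and_log_det(x, x_min, x_max, 0.8, 1e-3)
x_rec, inv_ld = box_inverse_and_log_det(y, x_min, x_max, 0.8, 1e-3)
```

The eps shrinkage is what keeps \(|u_i| < 1\) for any y inside the original bounds' interior near \(c_i \pm h_i\), so the arctanh stays finite.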

same_as(other: Bijector) → bool

Returns True if this bijector is guaranteed to be the same as other.

classmethod registry_name() → str
property type_name: str