jaxdem.rl.actionSpaces.boxSpace
Implementation of the bijector for the box action space.
Classes
| BoxSpace | Elementwise box constraint implemented with a scaled tanh. |
- class jaxdem.rl.actionSpaces.boxSpace.BoxSpace(*args, **kwargs)
Bases: Bijector, ActionSpace
Elementwise box constraint implemented with a scaled tanh.
Mapping (componentwise)
\[y_i \;=\; c_i + h_i\,\tanh\!\left(\frac{x_i}{w}\right), \qquad c_i=\tfrac{1}{2}(x_{\min,i}+x_{\max,i}), \quad h_i=\tfrac{1-\varepsilon}{2}(x_{\max,i}-x_{\min,i}),\]with width parameter \(w>0\) and a small \(\varepsilon>0\) for numerical safety. Because \(h_i\) is slightly smaller than the half-range, \(y_i\) stays strictly inside \((x_{\min,i}, x_{\max,i})\).
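The mapping above can be sketched in a few lines of JAX. This is a minimal illustration and not the library's implementation: the function name `box_forward` and the default values `width=1.0` and `eps=1e-6` are assumptions made for the example.

```python
import jax.numpy as jnp


def box_forward(x, x_min, x_max, width=1.0, eps=1e-6):
    """Hypothetical helper: y = c + h * tanh(x / w), componentwise."""
    c = 0.5 * (x_min + x_max)                 # box centre c_i
    h = 0.5 * (1.0 - eps) * (x_max - x_min)   # slightly shrunk half-range h_i
    return c + h * jnp.tanh(x / width)


x_min = jnp.array([-1.0, 0.0])
x_max = jnp.array([1.0, 10.0])
x = jnp.array([-50.0, 0.0])   # a saturated input and a centred input

y = box_forward(x, x_min, x_max)
```

Here the second component maps to the box centre, and the first stays just above its lower bound even though the tanh is saturated.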
Jacobian (componentwise)
For each component,
\[\frac{\partial y_i}{\partial x_i} = \frac{h_i}{w}\,\operatorname{sech}^2\!\left(\frac{x_i}{w}\right), \qquad \log\left| \frac{\partial y_i}{\partial x_i} \right| = \log h_i - \log w + \log\!\left(\operatorname{sech}^2\!\left(\frac{x_i}{w}\right)\right).\]The last term is evaluated with the numerically stable identity \(\log(\operatorname{sech}^2 z)=2\,[\log 2 - z - \operatorname{softplus}(-2z)]\).
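As a sanity check on the formula above, the stable log-derivative can be compared against automatic differentiation. The sketch below is illustrative only: `log_sech2`, `box_forward_log_det`, and the default `eps=1e-6` are hypothetical names and values, not taken from the library.

```python
import jax
import jax.numpy as jnp


def log_sech2(z):
    # log(sech^2 z) = 2 * (log 2 - z - softplus(-2z)); cosh is never formed explicitly.
    return 2.0 * (jnp.log(2.0) - z - jax.nn.softplus(-2.0 * z))


def box_forward_log_det(x, x_min, x_max, width=1.0, eps=1e-6):
    """Hypothetical helper: log|dy/dx| = log h - log w + log(sech^2(x / w))."""
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    return jnp.log(h) - jnp.log(width) + log_sech2(x / width)


# Scalar example so that jax.grad applies directly.
x_min, x_max, width, eps = -2.0, 3.0, 1.5, 1e-6


def f(x):
    c = 0.5 * (x_min + x_max)
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    return c + h * jnp.tanh(x / width)


x0 = 0.7
autodiff = jnp.log(jnp.abs(jax.grad(f)(x0)))            # reference value from autodiff
stable = box_forward_log_det(x0, x_min, x_max, width)   # closed form with stable identity
far_out = box_forward_log_det(100.0, x_min, x_max, width)  # deep in the saturated tail
```

The two values should agree to within floating-point tolerance, and the closed form remains finite for inputs deep in the saturated region.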
- Parameters:
- x_min (jax.Array) – Elementwise lower bounds of the distribution.
- x_max (jax.Array) – Elementwise upper bounds of the distribution. Must satisfy x_max > x_min elementwise.
- width (float) – Width parameter \(w>0\) of the tanh. Larger values give a gentler slope; the slope at \(x_i=0\) is \(h_i/w\).
- eps (float) – Small offset to avoid arctanh divergence close to the bounds.
- event_ndims_in (int) – Dimensionality of a single event seen by the bijector (defaults to 0 for a scalar transform).
- event_ndims_out (Optional[int]) – Standard Distrax/TFP bijector flag.
- is_constant_jacobian (bool) – Standard Distrax/TFP bijector flag.
- is_constant_log_det (bool) – Standard Distrax/TFP bijector flag.
Note
This bijector is scalar (event_ndims_in = 0). For vector actions, wrap it with ``distrax.Block(bijector, ndims=1)``. Let the model do that for you!
- forward_log_det_jacobian(x: Array | ndarray | bool | number) → Array
Computes log|det J(f)(x)|.
With z = x / width: log|dy/dx| = log(half) - log(width) + log(sech^2 z), where half is the scaled half-range h.
Stable form: log(sech^2 z) = 2*(log(2) - z - softplus(-2z))
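Since the bijector is scalar, this method returns one log-derivative per component. When it is wrapped for vector actions with `distrax.Block(bijector, ndims=1)`, the last axis is treated as the event and the per-component values are summed over it. The plain-JAX sketch below imitates that reduction without importing Distrax or jaxdem; `elementwise_log_det` and the example shapes are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def elementwise_log_det(x, h, width):
    """Hypothetical helper: per-component log|dy/dx| with the stable sech^2 form."""
    z = x / width
    return jnp.log(h) - jnp.log(width) + 2.0 * (
        jnp.log(2.0) - z - jax.nn.softplus(-2.0 * z)
    )


h = jnp.array([1.0, 5.0, 0.5])   # one half-range per action component
x = jnp.zeros((4, 3))            # batch of 4 unconstrained actions of dimension 3

per_component = elementwise_log_det(x, h, 1.0)   # shape (4, 3): scalar-bijector view
per_action = per_component.sum(axis=-1)          # shape (4,): reduction over the event axis
```

At x = 0 the sech term vanishes, so each per-action value reduces to the sum of log h_i.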
- forward_and_log_det(x: Array | ndarray | bool | number) → Tuple[Array, Array]
Computes y = f(x) and log|det J(f)(x)|.
- inverse_and_log_det(y: Array | ndarray | bool | number) → Tuple[Array, Array]
Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.
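Inverting the class-level mapping gives x = w * arctanh((y - c) / h), with log|dx/dy| = log w - log h - log(1 - u^2) for u = (y - c) / h. The following sketch works under those formulas only; it is not the library's code, the name `box_inverse_and_log_det` and the default `eps=1e-6` are assumptions, and it does no clipping, so values of y outside the shrunken box produce NaN.

```python
import jax.numpy as jnp


def box_inverse_and_log_det(y, x_min, x_max, width=1.0, eps=1e-6):
    """Hypothetical helper: x = w * arctanh((y - c) / h) and log|dx/dy|."""
    c = 0.5 * (x_min + x_max)
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    u = (y - c) / h                      # u = tanh(x / w), must lie in (-1, 1)
    x = width * jnp.arctanh(u)
    # dx/dy = w / (h * (1 - u^2)), so take logs term by term.
    log_det = jnp.log(width) - jnp.log(h) - jnp.log1p(-u * u)
    return x, log_det


x_min = jnp.array([-1.0, 0.0])
x_max = jnp.array([1.0, 10.0])
width, eps = 2.0, 1e-6

# Push a point through the forward map, then recover it.
x = jnp.array([0.5, -1.2])
c = 0.5 * (x_min + x_max)
h = 0.5 * (1.0 - eps) * (x_max - x_min)
y = c + h * jnp.tanh(x / width)

x_rec, inv_log_det = box_inverse_and_log_det(y, x_min, x_max, width=width, eps=eps)
```

The recovered `x_rec` should match `x` up to floating-point error, and `inv_log_det` is the negative of the forward log-derivative evaluated at `x`.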