jaxdem.rl.actionSpaces
Interface for defining bijectors used to constrain the policy probability distribution.
Classes
ActionSpace – Registry/namespace for action-space constraints implemented as `distrax.Bijector`s.
- class jaxdem.rl.actionSpaces.ActionSpace
Bases: Factory
Registry/namespace for action-space constraints implemented as `distrax.Bijector`s.
These bijectors are intended to be wrapped around a base policy distribution (e.g., MultivariateNormalDiag) via distrax.Transformed, so that sampling and log-probabilities are correctly adjusted using the bijector’s forward_and_log_det / inverse_and_log_det methods. See the Distrax/TFP bijector interface for details on shape semantics and event_ndims_in/event_ndims_out.
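A minimal NumPy sketch of the log-probability correction described above, assuming a standard normal base distribution and a plain `tanh` bijector. The helper names `base_log_prob`, `tanh_inverse_and_log_det`, and `transformed_log_prob` are hypothetical and are not jaxdem or Distrax API. It applies the change-of-variables rule \(\log p_Y(y) = \log p_X(x) + \log\lvert\det J_{f^{-1}}(y)\rvert\) with \(x = f^{-1}(y)\), then checks that the resulting density integrates to about 1 on \((-1, 1)\). NumPy is used so the sketch runs without JAX; the same functions exist in `jax.numpy`.

```python
import numpy as np


def base_log_prob(x):
    # Log-density of a standard normal base distribution (scalar event).
    return -0.5 * x**2 - 0.5 * np.log(2.0 * np.pi)


def tanh_inverse_and_log_det(y):
    # x = f^{-1}(y) = arctanh(y) and log|d f^{-1}/dy| = -log(1 - y^2).
    x = np.arctanh(y)
    return x, -np.log1p(-(y**2))


def transformed_log_prob(y):
    # Change of variables: log p_Y(y) = log p_X(f^{-1}(y)) + log|d f^{-1}/dy|.
    x, inv_log_det = tanh_inverse_and_log_det(y)
    return base_log_prob(x) + inv_log_det


# Midpoint-rule integral of the transformed density over (-1, 1).
edges = np.linspace(-1.0, 1.0, 400_001)
mid = 0.5 * (edges[1:] + edges[:-1])
mass = np.sum(np.exp(transformed_log_prob(mid)) * np.diff(edges))
print(mass)  # close to 1.0
```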
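Example
To define a custom action space, inherit from both distrax.Bijector and ActionSpace, register the class, and implement the bijector's abstract methods:
>>> import distrax
>>> from jaxdem.rl.actionSpaces import ActionSpace
>>> @ActionSpace.register("myCustomActionSpace")
... class MyCustomActionSpace(distrax.Bijector, ActionSpace):
...     ...  # e.g. forward_and_log_det and inverse_and_log_det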
- class jaxdem.rl.actionSpaces.FreeSpace(*args, **kwargs)
Bases: Bijector, ActionSpace
Identity constraint (no transform).
Mapping
\[y = f(x) = x, \qquad x = f^{-1}(y) = y.\]
Jacobian
\[J_f(x) = I, \qquad \log\lvert\det J_f(x)\rvert = 0, \qquad \log\lvert\det J_{f^{-1}}(y)\rvert = 0.\]
- Parameters:
- event_ndims_in (int) – number of event dimensions of a single input seen by the bijector (defaults to 0 for a scalar transform).
- event_ndims_out (Optional[int]) – standard Distrax/TFP bijector argument.
- is_constant_jacobian (bool) – standard Distrax/TFP bijector flag.
- is_constant_log_det (bool) – standard Distrax/TFP bijector flag.
Note
This bijector is scalar (event_ndims_in = 0). For vector actions, it needs to be wrapped with `distrax.Block(bijector, ndims=1)`. Let the model do that for you!
- forward_and_log_det(x: Array | ndarray | bool | number) → Tuple[Array | ndarray | bool | number, Array]
Computes y = f(x) and log|det J(f)(x)|.
- inverse_and_log_det(y: Array | ndarray | bool | number) → Tuple[Array | ndarray | bool | number, Array]
Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.
- class jaxdem.rl.actionSpaces.BoxSpace(*args, **kwargs)
Bases: Bijector, ActionSpace
Elementwise box constraint implemented with a scaled tanh.
Mapping (componentwise)
\[y_i \;=\; c_i + h_i\,\tanh\!\left(\frac{x_i}{w}\right), \qquad c_i=\tfrac{1}{2}(x_{\min,i}+x_{\max,i}), \quad h_i=\tfrac{1-\varepsilon}{2}(x_{\max,i}-x_{\min,i}),\]
with a width parameter \(w>0\) and a small \(\varepsilon>0\) for numerical safety.
Jacobian (componentwise)
For each component,
\[\frac{\partial y_i}{\partial x_i} = \frac{h_i}{w}\operatorname{sech}^2\!\left(\frac{x_i}{w}\right), \qquad \log\left|\frac{\partial y_i}{\partial x_i}\right| = \log h_i - \log w + \log\operatorname{sech}^2\!\left(\frac{x_i}{w}\right).\]
The last term is evaluated with the numerically stable identity \(\log\operatorname{sech}^2 z = 2\,[\log 2 - z - \operatorname{softplus}(-2z)]\).
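A NumPy sketch comparing direct evaluation of \(\log\operatorname{sech}^2 z\) with the stable identity used here (with \(z = x_i/w\)). The function names are hypothetical, and softplus is written as `np.logaddexp(0, t)`, which computes \(\log(1 + e^t)\) without overflow. The direct form breaks once \(\cosh z\) exceeds the float64 range (around \(|z| > 710\)); the identity stays finite.

```python
import numpy as np


def log_sech2_naive(z):
    # Direct evaluation: cosh(z) overflows to inf for large |z| in float64.
    return -2.0 * np.log(np.cosh(z))


def log_sech2_stable(z):
    # log(sech^2 z) = 2 * [log 2 - z - softplus(-2z)],
    # with softplus(t) = log(1 + e^t) computed as logaddexp(0, t).
    return 2.0 * (np.log(2.0) - z - np.logaddexp(0.0, -2.0 * z))


z = np.array([-3.0, 0.0, 0.5, 3.0, 800.0])
with np.errstate(over="ignore"):
    naive = log_sech2_naive(z)
stable = log_sech2_stable(z)
print(naive)   # last entry is -inf because cosh(800) overflows
print(stable)  # last entry is finite, about 2 * (log 2 - 800)
```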
- Parameters:
- x_min (jax.Array) – Elementwise lower bounds of the distribution.
- x_max (jax.Array) – Elementwise upper bounds of the distribution. Must satisfy x_max > x_min elementwise.
- width (float) – Width \(w\) of the tanh; larger values give a gentler slope.
- eps (float) – Small margin that keeps outputs strictly inside the bounds, so the inverse (arctanh) does not diverge close to them.
- event_ndims_in (int) – Number of event dimensions of a single input seen by the bijector (defaults to 0 for a scalar transform).
- event_ndims_out (Optional[int]) – Standard Distrax/TFP bijector argument.
- is_constant_jacobian (bool) – Standard Distrax/TFP bijector flag.
- is_constant_log_det (bool) – Standard Distrax/TFP bijector flag.
Note
This bijector is scalar (event_ndims_in = 0). For vector actions, it needs to be wrapped with `distrax.Block(bijector, ndims=1)`. Let the model do that for you!
- forward_and_log_det(x: Array | ndarray | bool | number) → Tuple[Array, Array]
Computes y = f(x) and log|det J(f)(x)|.
- forward_log_det_jacobian(x: Array | ndarray | bool | number) → Array
Computes log|det J(f)(x)| componentwise as log|dy/dx| = log h - log w + log(sech^2(x/w)), using the stable form log(sech^2 z) = 2*(log(2) - z - softplus(-2z)).
- inverse_and_log_det(y: Array | ndarray | bool | number) → Tuple[Array, Array]
Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.
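A self-contained NumPy sketch of the BoxSpace mapping, its inverse, and the log-determinants, written directly from the formulas above. The function names are hypothetical, not the jaxdem implementation, and the sketch assumes `width` and `eps` enter exactly as in the mapping. For vector actions it sums the per-component log-determinants over the last axis, which is what wrapping a scalar bijector with `distrax.Block(bijector, ndims=1)` amounts to.

```python
import numpy as np


def box_forward_and_log_det(x, x_min, x_max, width=1.0, eps=1e-6):
    # y_i = c_i + h_i * tanh(x_i / w), applied to each component independently.
    c = 0.5 * (x_min + x_max)
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    z = x / width
    y = c + h * np.tanh(z)
    # log|dy_i/dx_i| = log h_i - log w + log sech^2(z_i), using the stable form.
    log_sech2 = 2.0 * (np.log(2.0) - z - np.logaddexp(0.0, -2.0 * z))
    log_det = np.log(h) - np.log(width) + log_sech2
    # Summing over the last axis treats each action vector as one event.
    return y, np.sum(log_det, axis=-1)


def box_inverse_and_log_det(y, x_min, x_max, width=1.0, eps=1e-6):
    # x_i = w * arctanh((y_i - c_i) / h_i).
    c = 0.5 * (x_min + x_max)
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    u = (y - c) / h  # lies in (-1, 1) for any y produced by the forward map
    x = width * np.arctanh(u)
    # log|dx_i/dy_i| = log w - log h_i - log(1 - u_i^2).
    log_det = np.log(width) - np.log(h) - np.log1p(-(u**2))
    return x, np.sum(log_det, axis=-1)


# A batch of two 2-D actions with different bounds per component.
x_min = np.array([-1.0, 0.0])
x_max = np.array([1.0, 5.0])
x = np.array([[0.3, -2.0], [4.0, 0.1]])
y, fwd_log_det = box_forward_and_log_det(x, x_min, x_max, width=2.0)
x_back, inv_log_det = box_inverse_and_log_det(y, x_min, x_max, width=2.0)
print(y)
print(fwd_log_det + inv_log_det)  # ~0: the inverse log-det cancels the forward one
```

Because the map is componentwise, its Jacobian is diagonal and the log-determinant is simply the sum of the per-component terms.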
- class jaxdem.rl.actionSpaces.MaxNormSpace(*args, **kwargs)
Bases: Bijector, ActionSpace
Radial max-norm constraint for vector actions: squashes the radius with tanh while preserving the direction.
Mapping (vector case, \(x \in \mathbb{R}^d\))
\[\begin{split}r = \lVert x \rVert_2,\qquad \hat{u} = \begin{cases} \frac{x}{r}, & r>0,\\[4pt] 0, & r=0, \end{cases} \qquad y = s \tanh(r)\,\hat{u}, \quad s = (1-\varepsilon)\,\texttt{max\_norm}.\end{split}\]
Equivalently, \(y = b(r)\,x\) with \(b(r)= s\,\tanh(r)/r\) for \(r>0\).
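A NumPy sketch of this mapping and its inverse, assuming \(s = (1-\varepsilon)\,\texttt{max\_norm}\) as stated above. The function names are hypothetical, not the jaxdem implementation. At \(r = 0\) it uses the limit \(b(0) = s\) so that no division by zero occurs, and the inverse recovers the radius from \(\lVert y\rVert = s\tanh r\).

```python
import numpy as np


def max_norm_forward(x, max_norm=1.0, eps=1e-6):
    # y = b(r) * x with r = ||x||_2, b(r) = s * tanh(r) / r, s = (1 - eps) * max_norm.
    s = (1.0 - eps) * max_norm
    r = np.linalg.norm(x, axis=-1, keepdims=True)
    safe_r = np.where(r > 0, r, 1.0)  # avoid 0/0; the r == 0 branch uses the limit b(0) = s
    b = np.where(r > 0, s * np.tanh(r) / safe_r, s)
    return b * x


def max_norm_inverse(y, max_norm=1.0, eps=1e-6):
    # ||y|| = s * tanh(r), so r = arctanh(||y|| / s) and x = (r / ||y||) * y.
    s = (1.0 - eps) * max_norm
    rho = np.linalg.norm(y, axis=-1, keepdims=True)
    safe_rho = np.where(rho > 0, rho, 1.0)
    r = np.arctanh(rho / s)
    scale = np.where(rho > 0, r / safe_rho, 1.0 / s)  # limit of r / rho as rho -> 0
    return scale * y


x = np.array([[3.0, 4.0], [0.0, 0.0], [0.1, -0.2]])
y = max_norm_forward(x, max_norm=2.0)
print(np.linalg.norm(y, axis=-1))  # every norm is strictly below 2.0
x_back = max_norm_inverse(y, max_norm=2.0)
```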
Jacobian determinant
For an isotropic radial map \(f(x)=b(r)\,x\) with \(x \in \mathbb{R}^d\), the Jacobian has eigenvalue \(b(r)\) with multiplicity \(d-1\) on the subspace orthogonal to \(x\) and eigenvalue \(b(r)+r\,b'(r)\) along the radial direction. Since \(b(r)+r\,b'(r) = \frac{\mathrm{d}}{\mathrm{d}r}\bigl(r\,b(r)\bigr) = \frac{\mathrm{d}}{\mathrm{d}r}\bigl(s\tanh r\bigr) = s\operatorname{sech}^2 r\), this gives
\[\bigl|\det J_f(x)\bigr| = b(r)^{\,d-1}\,\bigl(b(r)+r\,b'(r)\bigr) = s^d \left(\frac{\tanh r}{r}\right)^{\!d-1} \operatorname{sech}^2 r.\]
Therefore
\[\log\lvert\det J_f(x)\rvert = d\log s + (d-1)\bigl(\log\tanh r - \log r\bigr) + \log\operatorname{sech}^2 r.\]
The last term is evaluated with the numerically stable identity \(\log\operatorname{sech}^2 z = 2\,[\log 2 - z - \operatorname{softplus}(-2z)]\).
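To sanity-check the determinant formula, one can compare it with the log-determinant of a central-difference Jacobian. This NumPy sketch uses hypothetical function names, covers only \(r > 0\) (it does not implement any small-\(r\) branch), and is not the jaxdem implementation.

```python
import numpy as np


def max_norm_log_det(x, max_norm=1.0, eps=1e-6):
    # d log s + (d - 1)(log tanh r - log r) + log sech^2 r; exact formula, valid for r > 0.
    s = (1.0 - eps) * max_norm
    d = x.shape[-1]
    r = np.linalg.norm(x, axis=-1)
    log_sech2 = 2.0 * (np.log(2.0) - r - np.logaddexp(0.0, -2.0 * r))
    return d * np.log(s) + (d - 1) * (np.log(np.tanh(r)) - np.log(r)) + log_sech2


def max_norm_forward(x, s):
    # y = s * tanh(r) * x / r for a single nonzero vector x.
    r = np.linalg.norm(x)
    return s * np.tanh(r) * x / r


def finite_difference_log_det(x, s, h=1e-6):
    # Central-difference Jacobian; column j holds dy/dx_j.
    d = x.size
    jac = np.empty((d, d))
    for j in range(d):
        step = np.zeros(d)
        step[j] = h
        jac[:, j] = (max_norm_forward(x + step, s) - max_norm_forward(x - step, s)) / (2.0 * h)
    _, log_abs_det = np.linalg.slogdet(jac)
    return log_abs_det


max_norm, eps = 1.5, 1e-6
s = (1.0 - eps) * max_norm
x = np.array([0.7, -1.2, 0.4])
print(max_norm_log_det(x, max_norm, eps), finite_difference_log_det(x, s))
```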
Near \(r\approx 0\), where \(\log(\tanh r / r) \approx -r^2/3\) and \(\log\operatorname{sech}^2 r \approx -r^2\), we use the second-order expansion
\[\log\lvert\det J_f(x)\rvert \approx d\log s - \tfrac{d+2}{3}\,r^2\]
to avoid division by \(r\).
- Parameters:
- max_norm (float) – Target radius after squashing (default 1.0). The effective radius is \(s=(1-\varepsilon)\,\texttt{max\_norm}\), which keeps outputs strictly inside the boundary.
- eps (float) – Numerical safety margin used near \(r=0\) and as \(r\to\infty\).
- event_ndims_in (int) – Number of event dimensions of a single input seen by the bijector; this bijector operates with event_ndims_in = 1 (see the note below).
- event_ndims_out (Optional[int]) – Standard Distrax/TFP bijector argument.
- is_constant_jacobian (bool) – Standard Distrax/TFP bijector flag.
- is_constant_log_det (bool) – Standard Distrax/TFP bijector flag.
Note
This bijector is vector-valued with event_ndims_in = 1 (i.e., it operates on length-\(d\) action vectors as a single event). Do not wrap it in Block unless you intend to apply it independently to multiple last-axis blocks.
- forward_and_log_det(x: Array | ndarray | bool | number) → Tuple[Array, Array]
Computes y = f(x) and log|det J(f)(x)|.
- forward_log_det_jacobian(x: Array | ndarray | bool | number) → Array
Computes log|det J(f)(x)|.
- inverse_and_log_det(y: Array | ndarray | bool | number) → Tuple[Array, Array]
Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.
Modules
- Implementation of the bijector for the box space.
- Implementation of the identity bijector for the free space.
- Implementation of the bijector for the max-norm space.