jaxdem.rl.actionSpaces

Interface for defining bijectors used to constrain the policy probability distribution.

Classes

ActionSpace()

Registry/namespace for action-space constraints implemented as `distrax.Bijector`s.

class jaxdem.rl.actionSpaces.ActionSpace

Bases: Factory

Registry/namespace for action-space constraints implemented as `distrax.Bijector`s.

These bijectors are intended to be wrapped around a base policy distribution (e.g., MultivariateNormalDiag) via distrax.Transformed, so that samples and log-probabilities are correctly adjusted using the bijector's forward_and_log_det / inverse_and_log_det methods. See the Distrax/TFP bijector interface for details on shape semantics and event_ndims_in/out.
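The log-probability adjustment that distrax.Transformed performs can be sketched by hand in plain JAX. The snippet below is only an illustration, not the jaxdem implementation: it assumes a standard-normal base distribution and a plain tanh transform (both chosen for the example, as are the function names) and applies the change of variables log p_Y(y) = log p_X(x) - log|det J_f(x)|.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def forward_and_log_det(x):
    """Illustrative scalar bijector (assumed for this sketch): y = tanh(x)."""
    y = jnp.tanh(x)
    # log(sech^2 x) written in the stable form 2 * (log 2 - x - softplus(-2x)).
    log_det = 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))
    return y, log_det


def transformed_log_prob(x):
    """Log-density of y = f(x) when x is drawn from a standard normal."""
    y, log_det = forward_and_log_det(x)
    # Change of variables: log p_Y(y) = log p_X(x) - log|det J_f(x)|.
    return y, norm.logpdf(x) - log_det


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4,))  # samples from the base distribution
y, log_prob_y = transformed_log_prob(x)
```

Sampling in the unconstrained space and pushing the sample through the bijector keeps the action inside the constraint set, while the log-det term keeps the reported log-probability consistent with the transformed density.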

Example

To define a custom action space, inherit from distrax.Bijector and ActionSpace and implement its abstract methods:

>>> import distrax
>>> from jaxdem.rl.actionSpaces import ActionSpace
>>> @ActionSpace.register("myCustomActionSpace")
... class MyCustomActionSpace(distrax.Bijector, ActionSpace):
...     ...
property kws: Dict[str, Any]
class jaxdem.rl.actionSpaces.FreeSpace(*args, **kwargs)

Bases: Bijector, ActionSpace

Identity constraint (no transform).

Mapping

\[y = f(x) = x, \qquad x = f^{-1}(y) = y.\]

Jacobian

\[J_f(x) = I,\qquad \log\lvert\det J_f(x)\rvert = 0, \qquad \log\lvert\det J_{f^{-1}}(y)\rvert = 0.\]
Parameters:
  • event_ndims_in (int) – dimensionality of a single event seen by the bijector (defaults to 0 for a scalar transform).

  • event_ndims_out (Optional[int]) – standard Distrax/TFP bijector flags.

  • is_constant_jacobian (bool) – standard Distrax/TFP bijector flags.

  • is_constant_log_det (bool) – standard Distrax/TFP bijector flags.

Note

This bijector is scalar (event_ndims_in = 0). For vector actions, it needs to be wrapped with `distrax.Block(bijector, ndims=1)`. Let the model do that for you!

property kws: Dict[str, Any]
forward_and_log_det(x: Array | ndarray | bool | number) → Tuple[Array | ndarray | bool | number, Array]

Computes y = f(x) and log|det J(f)(x)|.

inverse_and_log_det(y: Array | ndarray | bool | number) → Tuple[Array | ndarray | bool | number, Array]

Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.

same_as(other: Bijector) → bool

Returns True if this bijector is guaranteed to be the same as other.

class jaxdem.rl.actionSpaces.BoxSpace(*args, **kwargs)

Bases: Bijector, ActionSpace

Elementwise box constraint implemented with a scaled tanh.

Mapping (componentwise)

\[y_i \;=\; c_i + h_i\,\tanh\!\left(\frac{x_i}{w}\right), \qquad c_i=\tfrac{1}{2}(x_{\min,i}+x_{\max,i}), \quad h_i=\tfrac{1-\varepsilon}{2}(x_{\max,i}-x_{\min,i}),\]

with width parameter \(w>0\) and a small \(\varepsilon>0\) for numerical safety.

Jacobian (componentwise)

For each component,

\[\frac{\partial y_i}{\partial x_i} = \frac{h_i}{w}\,\operatorname{sech}^2\!\left(\frac{x_i}{w}\right), \qquad \log\left| \frac{\partial y_i}{\partial x_i} \right| = \log h_i - \log w + \log\!\Bigl(\operatorname{sech}^2\!\bigl(\tfrac{x_i}{w}\bigr)\Bigr).\]

For good numerical behavior, the last term is evaluated with the stable identity \(\log(\operatorname{sech}^2 z)=2\,[\log 2 - z - \operatorname{softplus}(-2z)]\).
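The mapping and its log-derivative can be written out directly in plain JAX. This is a minimal sketch of the formulas above, not the library's code: the function names and the default values of `width` and `eps` are assumptions made for the example. The last input is chosen large on purpose, so that the naive `log(1 - tanh(z)**2)` would underflow while the softplus form stays finite.

```python
import jax
import jax.numpy as jnp


def log_sech2(z):
    """Stable log(sech^2 z) = 2 * (log 2 - z - softplus(-2z))."""
    return 2.0 * (jnp.log(2.0) - z - jax.nn.softplus(-2.0 * z))


def box_forward(x, x_min, x_max, width=1.0, eps=1e-6):
    """y = c + h * tanh(x / w), with c and h as defined above."""
    c = 0.5 * (x_min + x_max)
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    return c + h * jnp.tanh(x / width)


def box_forward_log_det(x, x_min, x_max, width=1.0, eps=1e-6):
    """log|dy/dx| = log h - log w + log(sech^2(x / w))."""
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    return jnp.log(h) - jnp.log(width) + log_sech2(x / width)


x_min, x_max, width = -2.0, 3.0, 1.5
x = jnp.array([-4.0, 0.0, 0.5, 30.0])
y = box_forward(x, x_min, x_max, width)
log_det = box_forward_log_det(x, x_min, x_max, width)

# Cross-check against automatic differentiation of the forward map.
grad_fn = jax.vmap(jax.grad(box_forward), in_axes=(0, None, None, None))
dy_dx = grad_fn(x, x_min, x_max, width)
```

For moderate inputs, `jnp.log(dy_dx)` agrees with `log_det`; for the saturated input the autodiff derivative rounds to zero in single precision, which is the case the stable identity is meant to handle.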

Parameters:
  • x_min (jax.Array) – Elementwise lower bounds of the distribution.

  • x_max (jax.Array) – Elementwise upper bounds of the distribution. Must satisfy x_max > x_min elementwise.

  • width (float) – Width parameter \(w>0\) that divides the input before the tanh; the slope at the center of the box is \(h_i/w\), so larger values give a gentler squashing.

  • eps (float) – Small offset to avoid arctanh divergence close to bounds.

  • event_ndims_in (int) – dimensionality of a single event seen by the bijector (defaults to 0 for a scalar transform).

  • event_ndims_out (Optional[int]) – standard Distrax/TFP bijector flags.

  • is_constant_jacobian (bool) – standard Distrax/TFP bijector flags.

  • is_constant_log_det (bool) – standard Distrax/TFP bijector flags.

Note

This bijector is scalar (event_ndims_in = 0). For vector actions, it needs to be wrapped with `distrax.Block(bijector, ndims=1)`. Let the model do that for you!
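Wrapping a scalar bijector in `distrax.Block(bijector, ndims=1)` treats the last axis as a single event and sums the elementwise log-determinants over it. The sketch below reproduces that bookkeeping in plain JAX rather than calling Distrax; the helper names are hypothetical, and the box transform is simplified (width 1, no eps) to keep the example short.

```python
import jax
import jax.numpy as jnp


def scalar_forward(x, lo, hi):
    """Elementwise box squashing (width = 1, eps omitted for brevity)."""
    return 0.5 * (lo + hi) + 0.5 * (hi - lo) * jnp.tanh(x)


def scalar_log_det(x, lo, hi):
    """Elementwise log|dy/dx| of scalar_forward."""
    log_sech2 = 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))
    return jnp.log(0.5 * (hi - lo)) + log_sech2


lo = jnp.array([-1.0, 0.0, -5.0])
hi = jnp.array([1.0, 2.0, 5.0])
x = jnp.array([[0.3, -1.2, 2.0],
               [0.0, 0.7, -0.4]])  # batch of 2 actions, each of dimension 3

per_component = scalar_log_det(x, lo, hi)  # shape (2, 3)
per_action = per_component.sum(axis=-1)    # shape (2,), one value per action

# Reference: log|det| of the full 3x3 Jacobian of the first action.
jac = jax.jacfwd(scalar_forward)(x[0], lo, hi)
_, ref = jnp.linalg.slogdet(jac)
```

Because the transform acts independently on each component, the Jacobian is diagonal and its log-determinant equals the sum of the per-component terms, so `per_action[0]` matches `ref`.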

property kws: Dict[str, Any]
static sec2_log(x: Array) → Array
forward_log_det_jacobian(x: Array | ndarray | bool | number) → Array

Computes log|det J(f)(x)|. Elementwise, log|dy/dx| = log(h) - log(w) + log(sech^2(x/w)), where the last term uses the stable form log(sech^2 z) = 2*(log(2) - z - softplus(-2z)).

forward_and_log_det(x: Array | ndarray | bool | number) → Tuple[Array, Array]

Computes y = f(x) and log|det J(f)(x)|.

inverse_and_log_det(y: Array | ndarray | bool | number) → Tuple[Array, Array]

Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.
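Inverting the box mapping gives x = w * arctanh((y - c) / h), and the inverse log-determinant is the negative of the forward one. The following is a sketch under assumptions rather than the library's code: the function name and default values are hypothetical, and `y` is assumed to lie strictly inside the box so that the arctanh argument stays in (-1, 1).

```python
import jax.numpy as jnp


def box_inverse_and_log_det(y, x_min, x_max, width=1.0, eps=1e-6):
    """x = w * arctanh((y - c) / h) and log|dx/dy| for the box mapping."""
    c = 0.5 * (x_min + x_max)
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    u = (y - c) / h  # equals tanh(x / w), so |u| < 1 inside the box
    x = width * jnp.arctanh(u)
    # dx/dy = (w / h) / (1 - u^2); log1p(-u^2) evaluates log(1 - u^2).
    log_det = jnp.log(width) - jnp.log(h) - jnp.log1p(-u * u)
    return x, log_det


x_min, x_max, width = -2.0, 3.0, 1.5
y = jnp.array([-1.9, 0.5, 2.0, 2.9])
x, inv_log_det = box_inverse_and_log_det(y, x_min, x_max, width)

# Round trip through the forward map y = c + h * tanh(x / w).
c = 0.5 * (x_min + x_max)
h = 0.5 * (1.0 - 1e-6) * (x_max - x_min)
t = jnp.tanh(x / width)
y_back = c + h * t
fwd_log_det = jnp.log(h / width) + jnp.log1p(-t * t)
```

Here `y_back` recovers `y` and `inv_log_det` equals `-fwd_log_det` up to floating-point error. Shrinking `h` by the factor (1 - eps) keeps every forward output strictly inside the bounds, which is where arctanh remains finite.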

same_as(other: Bijector) → bool

Returns True if this bijector is guaranteed to be the same as other.

class jaxdem.rl.actionSpaces.MaxNormSpace(*args, **kwargs)

Bases: Bijector, ActionSpace

Radial max-norm constraint for vector actions: scales the radius with a tanh squashing while preserving direction.

Mapping (vector case, \(x \in \mathbb{R}^d\))

\[\begin{split}r = \lVert \vec{x} \rVert_2,\qquad \hat{u} = \begin{cases} \frac{\vec{x}}{r}, & r>0,\\[4pt] 0, & r=0, \end{cases} \qquad y = s \tanh(r)\, \hat{u}, \quad s = (1-\varepsilon)\, \texttt{max\_norm}.\end{split}\]

Equivalently, \(y = b(r)\,x\) with \(b(r)= s\,\tanh(r)/r\) for \(r>0\).
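A minimal forward pass for this radial mapping might look as follows. It is a sketch in plain JAX, not the library's implementation: the function name, the default `eps`, and the `jnp.where` guard used to handle the origin are assumptions of the example.

```python
import jax.numpy as jnp


def max_norm_forward(x, max_norm=1.0, eps=1e-6):
    """y = s * tanh(r) * x / r with r = ||x||_2 and s = (1 - eps) * max_norm."""
    s = (1.0 - eps) * max_norm
    r = jnp.linalg.norm(x, axis=-1, keepdims=True)
    # Guard the division so that x = 0 maps to y = 0 (u_hat = 0 when r = 0).
    safe_r = jnp.where(r > 0.0, r, 1.0)
    return s * jnp.tanh(r) * x / safe_r


x = jnp.array([[3.0, 4.0],   # r = 5, far from the origin
               [0.1, 0.0],   # small radius
               [0.0, 0.0]])  # origin
y = max_norm_forward(x, max_norm=2.0)
norms = jnp.linalg.norm(y, axis=-1)
```

Every output norm stays below `max_norm`, the origin maps to the origin, and each nonzero output points in the same direction as its input.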

Jacobian determinant

For an isotropic radial map \(f(x)=b(r)\,x\) with \(x \in \mathbb{R}^d\), the Jacobian eigenvalues are \(b(r)\) (multiplicity \(d-1\)) on the tangent subspace and \(b(r) + r\,b'(r)\) along the radial direction. Since \(r\,b(r) = s\tanh r\), the radial eigenvalue is \(b(r)+r\,b'(r) = s\,\operatorname{sech}^2 r\), hence

\[\bigl|\det J_f(x)\bigr| = b(r)^{\,d-1}\,\bigl(b(r)+r\,b'(r)\bigr) = s^d \left(\frac{\tanh r}{r}\right)^{\!d-1} \operatorname{sech}^2 r.\]

Therefore

\[\log\lvert\det J_f(x)\rvert = d\log s + (d-1)\bigl(\log\tanh r - \log r\bigr) + \log(\operatorname{sech}^2 r).\]

For good numerical behavior, the last term is evaluated with the stable identity \(\log(\operatorname{sech}^2 z)=2\,[\log 2 - z - \operatorname{softplus}(-2z)]\).
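The closed-form log-determinant can be checked against automatic differentiation. The snippet is a sketch under assumptions: the function names, the value of `s`, and the test vector are chosen for illustration, and the input is assumed to have r > 0 so that the division by r is safe.

```python
import jax
import jax.numpy as jnp


def max_norm_forward(x, s):
    """y = s * tanh(r) * x / r for a single vector x with r > 0."""
    r = jnp.linalg.norm(x)
    return s * jnp.tanh(r) * x / r


def max_norm_log_det(x, s):
    """d*log s + (d-1)*(log tanh r - log r) + log(sech^2 r)."""
    d = x.shape[-1]
    r = jnp.linalg.norm(x)
    log_sech2 = 2.0 * (jnp.log(2.0) - r - jax.nn.softplus(-2.0 * r))
    radial_ratio = jnp.log(jnp.tanh(r)) - jnp.log(r)
    return d * jnp.log(s) + (d - 1) * radial_ratio + log_sech2


s = 0.9
x = jnp.array([0.4, -1.1, 0.7])

closed_form = max_norm_log_det(x, s)
# Reference value from the full d x d Jacobian obtained by autodiff.
_, autodiff = jnp.linalg.slogdet(jax.jacfwd(max_norm_forward)(x, s))
```

The two values agree to floating-point precision, while the closed form needs only the radius instead of a full d x d Jacobian.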

Near \(r\approx 0\), we use the second-order expansion

\[\log\lvert\det J_f(x)\rvert \approx d\log s - \tfrac{d+2}{3}\, r^2\]

to avoid division by \(r\).
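The small-radius behavior can be checked numerically. Expanding the closed form term by term gives log(tanh r / r) ≈ -r^2/3 and log(sech^2 r) ≈ -r^2, so the quadratic coefficient works out to -((d - 1)/3 + 1) = -(d + 2)/3. The sketch below compares that expansion with the closed form at a few small radii; the function names and the values of `d` and `s` are assumptions of the example, not library code.

```python
import jax
import jax.numpy as jnp


def exact_log_det(r, d, s):
    """Closed-form log|det J_f| as a function of the radius r > 0."""
    log_sech2 = 2.0 * (jnp.log(2.0) - r - jax.nn.softplus(-2.0 * r))
    radial_ratio = jnp.log(jnp.tanh(r)) - jnp.log(r)
    return d * jnp.log(s) + (d - 1) * radial_ratio + log_sech2


def small_r_log_det(r, d, s):
    """Second-order Taylor expansion of exact_log_det around r = 0."""
    # Quadratic coefficient: -(d - 1)/3 from the tanh(r)/r term, -1 from sech^2.
    return d * jnp.log(s) - (d + 2) / 3.0 * r**2


d, s = 3, 1.0
r = jnp.array([0.02, 0.05, 0.1])
gap = jnp.abs(exact_log_det(r, d, s) - small_r_log_det(r, d, s))
```

The gap shrinks like r^4, so switching to the expansion below a small radius threshold costs very little accuracy while avoiding the division by r.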

Parameters:
  • max_norm (float) – target radius after squashing (default 1.0). The effective radius is \(s=(1-\varepsilon)\,\texttt{max\_norm}\), which avoids the exact boundary.

  • eps (float) – numerical safety margin used near \(r=0\) and as \(r\to\infty\).

  • event_ndims_in (int) – dimensionality of a single event seen by the bijector (defaults to 0 for a scalar transform).

  • event_ndims_out (Optional[int]) – standard Distrax/TFP bijector flags.

  • is_constant_jacobian (bool) – standard Distrax/TFP bijector flags.

  • is_constant_log_det (bool) – standard Distrax/TFP bijector flags.

Note

This bijector is vector-valued with event_ndims_in = 1 (i.e., it operates on length-\(d\) action vectors as a single event). Do not wrap it in Block unless you intend to apply it independently to multiple last-axis blocks.

property kws: Dict[str, Any]
static sec2_log(r: Array) → Array
forward_log_det_jacobian(x: Array | ndarray | bool | number) → Array

Computes log|det J(f)(x)|.

forward_and_log_det(x: Array | ndarray | bool | number) → Tuple[Array, Array]

Computes y = f(x) and log|det J(f)(x)|.

inverse_and_log_det(y: Array | ndarray | bool | number) → Tuple[Array, Array]

Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.

same_as(other: Bijector) → bool

Returns True if this bijector is guaranteed to be the same as other.

Modules

boxSpace

Implementation of bijector for box space.

freeSpace

Implementation of identity bijector for free space.

maxNormSpace

Implementation of bijector for max-norm space.