jaxdem.rl.actionSpaces

Interface for defining bijectors used to constrain the policy probability distribution.

Classes

ActionSpace()

Registry/namespace for action-space constraints implemented as `distrax.Bijector`s.

class jaxdem.rl.actionSpaces.ActionSpace

Bases: Factory

Registry/namespace for action-space constraints implemented as `distrax.Bijector`s.

These bijectors are intended to be wrapped around a base policy distribution (e.g., MultivariateNormalDiag) via distrax.Transformed, so that sampling and log-probabilities are correctly adjusted using the bijector’s forward_and_log_det / inverse_and_log_det methods. See the Distrax/TFP bijector interface for details on shape semantics and event_ndims_in/out.
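The adjustment that a transformed distribution applies can be sketched by hand. The block below is a minimal illustration in plain JAX, not the library's code and not a call into Distrax: it uses a scalar tanh bijector and a standard normal base distribution (both chosen here as assumptions) and applies the change-of-variables rule log p_Y(y) = log p_X(x) - log|det J_f(x)|.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def tanh_forward_and_log_det(x):
    """Scalar tanh bijector: returns y = tanh(x) and log|dy/dx|."""
    y = jnp.tanh(x)
    # log(1 - tanh(x)^2) = log(sech^2 x), written in a numerically stable form
    log_det = 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))
    return y, log_det


def transformed_log_prob(x, loc, scale):
    """Log-density of y = tanh(x) when x ~ Normal(loc, scale)."""
    y, log_det = tanh_forward_and_log_det(x)
    # Change of variables: subtract the forward log-determinant
    return y, norm.logpdf(x, loc, scale) - log_det


x = jnp.array([-1.5, 0.0, 0.7])          # samples from the base distribution
y, logp_y = transformed_log_prob(x, loc=0.0, scale=1.0)
```

Sampling follows the same pattern: draw x from the base distribution, push it through the forward map, and report the corrected log-probability for y.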

Example

To define a custom action space, inherit from distrax.Bijector and ActionSpace, register the class under a name, and implement the abstract methods:

>>> import distrax
>>> from jaxdem.rl.actionSpaces import ActionSpace
>>> @ActionSpace.register("myCustomActionSpace")
... class MyCustomActionSpace(distrax.Bijector, ActionSpace):
...     ...
property kws: Dict
class jaxdem.rl.actionSpaces.FreeSpace(*args, **kwargs)

Bases: Bijector, ActionSpace

Identity constraint (no transform).

Mapping

\[y = f(x) = x, \qquad x = f^{-1}(y) = y.\]

Jacobian

\[J_f(x) = I,\qquad \log\lvert\det J_f(x)\rvert = 0, \qquad \log\lvert\det J_{f^{-1}}(y)\rvert = 0.\]
Parameters:
  • event_ndims_in (int) – dimensionality of a single event seen by the bijector (defaults to 0 for a scalar transform).

  • event_ndims_out (Optional[int]) – standard Distrax/TFP bijector flag.

  • is_constant_jacobian (bool) – standard Distrax/TFP bijector flag.

  • is_constant_log_det (bool) – standard Distrax/TFP bijector flag.

Note

This bijector is scalar (event_ndims_in = 0). For vector actions, it needs to be wrapped with `distrax.Block(bijector, ndims=1)`. Let the model do that for you!

forward_and_log_det(x: Array | ndarray | bool | number) → Tuple[Array | ndarray | bool | number, Array]

Computes y = f(x) and log|det J(f)(x)|.

inverse_and_log_det(y: Array | ndarray | bool | number) → Tuple[Array | ndarray | bool | number, Array]

Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.

property kws: Dict
classmethod registry_name() → str
same_as(other: Bijector) → bool

Returns True if this bijector is guaranteed to be the same as other.

property type_name: str
class jaxdem.rl.actionSpaces.BoxSpace(*args, **kwargs)

Bases: Bijector, ActionSpace

Elementwise box constraint implemented with a scaled tanh.

Mapping (componentwise)

\[y_i \;=\; c_i + h_i\,\tanh\!\left(\frac{x_i}{w}\right), \qquad c_i=\tfrac{1}{2}(x_{\min,i}+x_{\max,i}), \quad h_i=\tfrac{1-\varepsilon}{2}(x_{\max,i}-x_{\min,i}),\]

with width parameter \(w>0\) and a small \(\varepsilon>0\) for numerical safety.

Jacobian (componentwise)

For each component,

\[\frac{\partial y_i}{\partial x_i} = \frac{h_i}{w} \operatorname{sech}^2\!\left(\frac{x_i}{w}\right), \qquad \log\left| \frac{\partial y_i}{\partial x_i} \right| = \log h_i - \log w + \log\!\Big(\operatorname{sech}^2\!\big(\tfrac{x_i}{w}\big)\Big).\]

The last term is evaluated with the stable identity \(\log(\operatorname{sech}^2 z)=2\,[\log 2 - z - \operatorname{softplus}(-2z)]\) for good numerical behavior.
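To see why the stable form matters, the following sketch compares it with a direct evaluation through cosh. It is an illustrative check in plain JAX, not the library's `sec2_log` implementation; the function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def log_sech2_naive(z):
    # Direct evaluation: cosh overflows for large |z|, so this returns -inf
    return -2.0 * jnp.log(jnp.cosh(z))


def log_sech2_stable(z):
    # log(sech^2 z) = 2 * (log 2 - z - softplus(-2 z))
    return 2.0 * (jnp.log(2.0) - z - jax.nn.softplus(-2.0 * z))


z = jnp.array([-1000.0, -3.0, 0.0, 3.0, 1000.0])
naive = log_sech2_naive(z)      # infinite at the two extreme inputs
stable = log_sech2_stable(z)    # finite everywhere
```

For moderate inputs the two forms agree; for large |z| only the stable form stays finite, which keeps log-probabilities usable when the pre-squash action is far from zero.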

Parameters:
  • x_min (jax.Array) – Elementwise lower bounds of the distribution.

  • x_max (jax.Array) – Elementwise upper bounds of the distribution. Must satisfy x_max > x_min elementwise.

  • width (float) – width parameter \(w>0\); controls the slope of the tanh (the slope at \(x_i=0\) is \(h_i/w\)).

  • eps (float) – Small offset to avoid arctanh divergence close to bounds.

  • event_ndims_in (int) – dimensionality of a single event seen by the bijector (defaults to 0 for a scalar transform).

  • event_ndims_out (Optional[int]) – standard Distrax/TFP bijector flag.

  • is_constant_jacobian (bool) – standard Distrax/TFP bijector flag.

  • is_constant_log_det (bool) – standard Distrax/TFP bijector flag.

Note

This bijector is scalar (event_ndims_in = 0). For vector actions, it needs to be wrapped with `distrax.Block(bijector, ndims=1)`. Let the model do that for you!
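Conceptually, blocking a scalar bijector with ndims=1 means the last axis is treated as a single event, so the per-component log-determinants are summed over that axis. The sketch below mimics this in plain JAX without calling Distrax; the helper names and the half-range values are hypothetical.

```python
import jax
import jax.numpy as jnp


def scalar_log_det(x, half_range, width):
    """Per-component log|dy/dx| of y = c + h * tanh(x / w)."""
    z = x / width
    log_sech2 = 2.0 * (jnp.log(2.0) - z - jax.nn.softplus(-2.0 * z))
    return jnp.log(half_range) - jnp.log(width) + log_sech2


def blocked_log_det(x, half_range, width):
    """Treat the last axis as one event: sum component log-dets over it."""
    return jnp.sum(scalar_log_det(x, half_range, width), axis=-1)


x = jnp.zeros((4, 3))                         # batch of 4 actions, 3 components
h = jnp.array([1.0, 2.0, 0.5])                # hypothetical half-ranges
per_component = scalar_log_det(x, h, 1.0)     # shape (4, 3)
per_action = blocked_log_det(x, h, 1.0)       # shape (4,)
```

The summed value is the one that pairs with a vector-valued base distribution, whose log-probability is also a single number per action.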

forward_and_log_det(x: Array | ndarray | bool | number) → Tuple[Array, Array]

Computes y = f(x) and log|det J(f)(x)|.

forward_log_det_jacobian(x: Array | ndarray | bool | number) → Array

Computes log|det J(f)(x)|. Componentwise, log|dy/dx| = log(h) - log(w) + log(sech^2(x/w)), where the last term is evaluated in the stable form log(sech^2 z) = 2*(log(2) - z - softplus(-2z)).

inverse_and_log_det(y: Array | ndarray | bool | number) → Tuple[Array, Array]

Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.
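For the box mapping, inverting y = c + h tanh(x/w) gives x = w arctanh((y - c)/h), and the inverse log-determinant is the negative of the forward one. The following is a minimal sketch of that relationship in plain JAX; it is not the library's implementation, and the function names and the default eps value are assumptions.

```python
import jax.numpy as jnp


def box_params(x_min, x_max, eps):
    """Center c and half-range h of the box."""
    c = 0.5 * (x_min + x_max)
    h = 0.5 * (1.0 - eps) * (x_max - x_min)
    return c, h


def box_forward(x, x_min, x_max, width=1.0, eps=1e-6):
    c, h = box_params(x_min, x_max, eps)
    return c + h * jnp.tanh(x / width)


def box_inverse_and_log_det(y, x_min, x_max, width=1.0, eps=1e-6):
    c, h = box_params(x_min, x_max, eps)
    t = (y - c) / h                    # equals tanh(x / w), inside (-1, 1)
    x = width * jnp.arctanh(t)
    # dx/dy = w / (h * (1 - t^2)), the reciprocal of the forward derivative
    log_det = jnp.log(width) - jnp.log(h) - jnp.log1p(-t * t)
    return x, log_det


x_min = jnp.array([-1.0, 0.0, 2.0])
x_max = jnp.array([1.0, 10.0, 3.0])
x = jnp.array([-2.0, 0.0, 1.5])

y = box_forward(x, x_min, x_max)                        # strictly inside the box
x_back, inv_log_det = box_inverse_and_log_det(y, x_min, x_max)
```

Because the half-range is shrunk by (1 - eps), y never reaches the bounds exactly, so the arctanh argument stays strictly inside (-1, 1).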

property kws: Dict
classmethod registry_name() → str
same_as(other: Bijector) → bool

Returns True if this bijector is guaranteed to be the same as other.

static sec2_log(x)
property type_name: str
class jaxdem.rl.actionSpaces.MaxNormSpace(*args, **kwargs)

Bases: Bijector, ActionSpace

Radial max-norm constraint for vector actions: scales the radius with a tanh squashing while preserving direction.

Mapping (vector case, \(x \in \mathbb{R}^d\))

\[\begin{split}r = \lVert \vec{x} \rVert_2,\qquad \hat{u} = \begin{cases} \frac{\vec{x}}{r}, & r>0,\\[4pt] 0, & r=0, \end{cases} \qquad y = s \tanh(r) \hat{u}, \quad s = (1-\varepsilon)\,\texttt{max_norm}.\end{split}\]

Equivalently, \(y = b(r)\,x\) with \(b(r)= s\,\tanh(r)/r\) for \(r>0\).

Jacobian determinant

For an isotropic radial map \(f(x)=b(r)\,x\) with \(x \in \mathbb{R}^d\), the Jacobian eigenvalues are \(b(r)\) (multiplicity \(d-1\)) on the tangent subspace and \(b(r) + r\,b'(r)\) along the radial direction. With \(b(r)=s\,\tanh(r)/r\), the radial eigenvalue is \(\tfrac{\mathrm{d}}{\mathrm{d}r}\bigl(r\,b(r)\bigr) = s\,\operatorname{sech}^2 r\), hence

\[\bigl|\det J_f(x)\bigr| = b(r)^{\,d-1}\,\bigl(b(r)+r\,b'(r)\bigr) = s^d \left(\frac{\tanh r}{r}\right)^{\!d-1} \operatorname{sech}^2 r.\]

Therefore

\[\log\lvert\det J_f(x)\rvert = d\log s + (d-1)\bigl(\log\tanh r - \log r\bigr) + \log(\operatorname{sech}^2 r).\]

The last term is evaluated with the stable identity \(\log(\operatorname{sech}^2 z)=2\,[\log 2 - z - \operatorname{softplus}(-2z)]\) for good numerical behavior.

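Near \(r\approx 0\), we use a second-order expansion to avoid division by \(r\). Since \(\log(\tanh r / r) \approx -\tfrac{1}{3} r^2\) and \(\log(\operatorname{sech}^2 r) \approx -r^2\), the closed form above reduces to

\[\log\lvert\det J_f(x)\rvert \approx d\log s - \tfrac{d+2}{3}\, r^2.\]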
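The closed-form log-determinant can be checked against automatic differentiation. The sketch below, written in plain JAX with hypothetical function names and an assumed default eps, builds the full Jacobian of the radial map with `jax.jacfwd` and compares its log-determinant with the formula, at a point away from the origin.

```python
import jax
import jax.numpy as jnp


def max_norm_forward(x, max_norm=1.0, eps=1e-6):
    """y = s * tanh(r) * x / r with r = ||x||_2 (valid for r > 0)."""
    s = (1.0 - eps) * max_norm
    r = jnp.linalg.norm(x)
    return s * jnp.tanh(r) * x / r


def max_norm_log_det(x, max_norm=1.0, eps=1e-6):
    """Closed form: d log s + (d-1)(log tanh r - log r) + log sech^2 r."""
    s = (1.0 - eps) * max_norm
    d = x.shape[-1]
    r = jnp.linalg.norm(x)
    log_sech2 = 2.0 * (jnp.log(2.0) - r - jax.nn.softplus(-2.0 * r))
    radial_ratio = jnp.log(jnp.tanh(r)) - jnp.log(r)
    return d * jnp.log(s) + (d - 1) * radial_ratio + log_sech2


x = jnp.array([0.3, -1.2, 0.8])

# Full (3, 3) Jacobian by forward-mode autodiff, then log|det|
jacobian = jax.jacfwd(lambda v: max_norm_forward(v, max_norm=2.0))(x)
_, autodiff_log_det = jnp.linalg.slogdet(jacobian)

closed_form_log_det = max_norm_log_det(x, max_norm=2.0)
```

The closed form needs only the norm of x, so it costs O(d) per action instead of the O(d^3) of a general determinant.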

Parameters:
  • max_norm (float) – target radius after squashing (default 1.0). The effective radius is \(s=(1-\varepsilon)\,\texttt{max_norm}\), which avoids the exact boundary.

  • eps (float) – numerical safety margin used near \(r=0\) and \(r\to\infty\).

  • event_ndims_in (int) – dimensionality of a single event seen by the bijector (1 for this bijector, which treats each action vector as one event).

  • event_ndims_out (Optional[int]) – standard Distrax/TFP bijector flag.

  • is_constant_jacobian (bool) – standard Distrax/TFP bijector flag.

  • is_constant_log_det (bool) – standard Distrax/TFP bijector flag.

Note

This bijector is vector-valued with event_ndims_in = 1 (i.e., it operates on length-\(d\) action vectors as a single event). Do not wrap it in Block unless you intend to apply it independently to multiple last-axis blocks.

forward_and_log_det(x: Array | ndarray | bool | number) → Tuple[Array, Array]

Computes y = f(x) and log|det J(f)(x)|.

forward_log_det_jacobian(x: Array | ndarray | bool | number) → Array

Computes log|det J(f)(x)|.

inverse_and_log_det(y: Array | ndarray | bool | number) → Tuple[Array, Array]

Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.
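Since the radial map preserves direction and sends radius r to s tanh(r), its inverse keeps the direction of y and recovers the radius as r = arctanh(‖y‖/s). The sketch below illustrates this round trip in plain JAX for a batch of vectors; it is an assumption-based illustration with hypothetical function names, not the library's implementation, and it only covers inputs with non-zero norm.

```python
import jax.numpy as jnp


def max_norm_forward(x, max_norm=1.0, eps=1e-6):
    """y = s * tanh(r) * x / r, applied along the last axis (r > 0)."""
    s = (1.0 - eps) * max_norm
    r = jnp.linalg.norm(x, axis=-1, keepdims=True)
    return s * jnp.tanh(r) * x / r


def max_norm_inverse(y, max_norm=1.0, eps=1e-6):
    """x = arctanh(rho / s) * y / rho with rho = ||y||_2 (0 < rho < s)."""
    s = (1.0 - eps) * max_norm
    rho = jnp.linalg.norm(y, axis=-1, keepdims=True)
    r = jnp.arctanh(rho / s)           # undo the tanh squashing of the radius
    return r * y / rho                 # same direction as y, radius r


x = jnp.array([[0.3, -1.2, 0.8],
               [2.0, 0.0, 0.0]])

y = max_norm_forward(x, max_norm=2.0)      # every row has norm below 2.0
x_back = max_norm_inverse(y, max_norm=2.0)
```

The log-determinant of the inverse at y is the negative of the forward log-determinant at the recovered x.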

property kws: Dict
classmethod registry_name() → str
same_as(other: Bijector) → bool

Returns True if this bijector is guaranteed to be the same as other.

static sec2_log(r)
property type_name: str

Modules

boxSpace

Implementation of bijector for box space.

freeSpace

Implementation of identity bijector for free space.

maxNormSpace

Implementation of bijector for max-norm space.