jaxdem.rl.actionSpaces.maxNormSpace

Implementation of a bijector for the max-norm action space.

Classes

MaxNormSpace(*args, **kwargs)

Radial max-norm constraint for vector actions: scales the radius with a tanh squashing while preserving direction.

class jaxdem.rl.actionSpaces.maxNormSpace.MaxNormSpace(*args, **kwargs)

Bases: Bijector, ActionSpace

Radial max-norm constraint for vector actions: scales the radius with a tanh squashing while preserving direction.

Mapping (vector case, \(x \in \mathbb{R}^d\)):

\[\begin{split}r = \lVert \vec{x} \rVert_2,\qquad \hat{u} = \begin{cases} \frac{\vec{x}}{r}, & r>0,\\[4pt] 0, & r=0, \end{cases} \qquad y = s \tanh(r) \hat{u}, \quad s = (1-\epsilon) \texttt{max_norm}.\end{split}\]

Equivalently, \(y = b(r)\,x\) with \(b(r)= s\,\tanh(r)/r\) for \(r>0\).
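The mapping can be sketched in a few lines. The version below uses NumPy for portability (the library itself is built on JAX, and the same functions exist in jax.numpy); `max_norm_forward` is a hypothetical helper, not part of jaxdem, and the default `eps` value is an illustrative assumption. It handles the \(r = 0\) case by setting the direction to zero, as in the definition above.

```python
import numpy as np

def max_norm_forward(x, max_norm=1.0, eps=1e-6):
    # Hypothetical sketch of y = s * tanh(r) * u with s = (1 - eps) * max_norm.
    x = np.asarray(x, dtype=float)
    s = (1.0 - eps) * max_norm
    r = np.linalg.norm(x, axis=-1, keepdims=True)
    # Safe unit direction: u = x / r for r > 0 and u = 0 at the origin.
    # Dividing by safe_r (never zero) keeps both np.where branches finite.
    safe_r = np.where(r > 0, r, 1.0)
    u = np.where(r > 0, x / safe_r, 0.0)
    return s * np.tanh(r) * u

y = max_norm_forward([3.0, 4.0], max_norm=2.0)
print(y, np.linalg.norm(y))  # norm stays below max_norm; direction is (0.6, 0.8)
```

Because \(\tanh r < 1\) for finite \(r\) and \(s < \texttt{max_norm}\), the output norm stays strictly inside the ball even when \(\tanh r\) rounds to exactly 1 in floating point.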

Jacobian determinant

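For an isotropic radial map \(f(x)=b(r)\,x\) with \(x \in \mathbb{R}^d\), the Jacobian eigenvalues are \(b(r)\) (multiplicity \(d-1\)) on the tangent subspace and \(b(r) + r\,b'(r)\) in the radial direction. Since \(b(r) + r\,b'(r) = \frac{d}{dr}\bigl(r\,b(r)\bigr) = \frac{d}{dr}\bigl(s\tanh r\bigr) = s\operatorname{sech}^2 r\), it follows that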

\[\bigl|\det J_f(x)\bigr| = b(r)^{\,d-1}\,\bigl(b(r)+r\,b'(r)\bigr) = s^d \left(\frac{\tanh r}{r}\right)^{\!d-1} \operatorname{sech}^2 r.\]

Therefore

\[\log\lvert\det J_f(x)\rvert = d\log s + (d-1)\bigl(\log\tanh r - \log r\bigr) + \log(\operatorname{sech}^2 r).\]
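As a sanity check, the closed form can be compared against a numerically computed Jacobian. This is a NumPy sketch with hypothetical helper names (`radial_forward`, `closed_form_log_det`, `fd_jacobian`), not jaxdem code; it assumes \(r > 0\) and uses central finite differences instead of automatic differentiation.

```python
import numpy as np

def radial_forward(x, s):
    # y = s * tanh(r) * x / r; this check assumes r = ||x|| > 0.
    r = np.linalg.norm(x)
    return s * np.tanh(r) * x / r

def closed_form_log_det(x, s):
    # d log s + (d - 1)(log tanh r - log r) + log sech^2 r, with log sech^2 r = -2 log cosh r.
    d = x.shape[-1]
    r = np.linalg.norm(x)
    return d * np.log(s) + (d - 1) * (np.log(np.tanh(r)) - np.log(r)) - 2.0 * np.log(np.cosh(r))

def fd_jacobian(f, x, h=1e-6):
    # Central finite differences; column j approximates df/dx_j.
    d = x.size
    jac = np.empty((d, d))
    for j in range(d):
        step = np.zeros(d)
        step[j] = h
        jac[:, j] = (f(x + step) - f(x - step)) / (2.0 * h)
    return jac

x = np.array([0.3, -1.2, 0.7])
s = 0.9
sign, numeric_log_det = np.linalg.slogdet(fd_jacobian(lambda z: radial_forward(z, s), x))
print(sign, numeric_log_det, closed_form_log_det(x, s))
```

The determinant is positive because every eigenvalue (\(b(r)\) and \(s\operatorname{sech}^2 r\)) is positive, so \(\lvert\det J_f\rvert = \det J_f\) here.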

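For numerical stability, the last term is evaluated with the identity \(\log(\operatorname{sech}^2 z) = 2\,[\log 2 - z - \operatorname{softplus}(-2z)]\), which never evaluates \(\cosh z\) and therefore cannot overflow for large \(z\).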
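The identity follows from \(\operatorname{sech} z = 2e^{-z}/(1 + e^{-2z})\). A minimal NumPy sketch is shown below; `log_sech2` is a hypothetical name, and while the static `sec2_log` method listed further down may serve this purpose, that is an assumption based only on its name. Here `np.logaddexp(0, t)` computes \(\operatorname{softplus}(t) = \log(1 + e^t)\) without overflow.

```python
import numpy as np

def log_sech2(z):
    # log(sech(z)^2) = 2 * (log 2 - z - softplus(-2 z)), exact for every real z.
    # np.logaddexp(0, t) = log(1 + exp(t)) is a numerically stable softplus.
    z = np.asarray(z, dtype=float)
    return 2.0 * (np.log(2.0) - z - np.logaddexp(0.0, -2.0 * z))

z = np.array([0.0, 0.5, 3.0])
print(log_sech2(z))
print(log_sech2(1000.0))  # finite; the naive log(1 / cosh(z)**2) overflows cosh here
```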

Near \(r\approx 0\), we use the second-order expansion

\[\log\lvert\det J_f(x)\rvert \approx d\log s - \tfrac{d+2}{3}\, r^2\]

to avoid division by \(r\).
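The quadratic coefficient comes from the Taylor series \(\tanh r / r = 1 - r^2/3 + O(r^4)\) and \(\log\cosh r = r^2/2 + O(r^4)\): the term \((d-1)\log(\tanh r / r)\) contributes \(-(d-1)\,r^2/3\), and \(\log\operatorname{sech}^2 r = -2\log\cosh r\) contributes \(-r^2\), for a total of \(-(d+2)\,r^2/3\). The NumPy sketch below (hypothetical helper names, not jaxdem API) checks this numerically for a few dimensions.

```python
import numpy as np

def exact_log_det(r, d, s):
    # Direct evaluation; at r = 0 it becomes log 0 - log 0, i.e. undefined.
    return d * np.log(s) + (d - 1) * (np.log(np.tanh(r)) - np.log(r)) - 2.0 * np.log(np.cosh(r))

def small_r_log_det(r, d, s):
    # Second-order expansion around r = 0; no division by r.
    return d * np.log(s) - (d + 2) / 3.0 * r**2

s, r = 0.999, 1e-2
errors = {d: abs(exact_log_det(r, d, s) - small_r_log_det(r, d, s)) for d in (1, 2, 3)}
print(errors)  # each error is O(r**4)
```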

Parameters:
  • max_norm (float) – target radius after squashing (default 1.0). To avoid reaching the boundary exactly, the scale actually used is \(s = (1-\varepsilon)\,\texttt{max_norm}\).

  • eps (float) – numerical safety margin used near \(r = 0\) and \(r \to \infty\).

  • event_ndims_in (int) – dimensionality of a single event seen by the bijector (defaults to 0 for a scalar transform).

  • event_ndims_out (Optional[int]) – standard Distrax/TFP bijector flag.

  • is_constant_jacobian (bool) – standard Distrax/TFP bijector flag.

  • is_constant_log_det (bool) – standard Distrax/TFP bijector flag.

Note

This bijector is vector-valued with event_ndims_in = 1, i.e., it treats a length-\(d\) action vector as a single event. Do not wrap it in Block unless you intend to apply it independently to multiple last-axis blocks.
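To see why the whole vector is one event, the NumPy sketch below compares the radial map applied row-wise to a batch of 2-D actions with per-coordinate tanh squashing, which is what treating each entry as its own 1-D event reduces to (for \(d = 1\), \(s\tanh(\lvert x\rvert)\operatorname{sign}(x) = s\tanh x\)). The values of \(s\) and the batch are illustrative assumptions, and no jaxdem API is used.

```python
import numpy as np

s = 0.999  # illustrative: (1 - eps) * max_norm with max_norm = 1, eps = 1e-3
x = np.array([[3.0, 4.0], [0.1, -0.2], [2.0, 0.0]])  # batch of three 2-D actions

# Vector event (event_ndims_in = 1): one radius per row, direction preserved.
r = np.linalg.norm(x, axis=-1, keepdims=True)
y_vector = s * np.tanh(r) * x / r

# Per-coordinate scalar squashing: each entry is treated as its own 1-D event.
y_per_coord = s * np.tanh(x)

print(np.linalg.norm(y_vector, axis=-1))     # every norm is below s
print(np.linalg.norm(y_per_coord, axis=-1))  # the first row exceeds s
```

Per-coordinate squashing bounds each component separately, so the output lies in a cube of half-width \(s\) rather than a ball of radius \(s\), and it does not preserve direction.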

property kws: Dict
static sec2_log(r)
forward_log_det_jacobian(x: Array | ndarray | bool | number) → Array

Computes log|det J(f)(x)|.

forward_and_log_det(x: Array | ndarray | bool | number) → Tuple[Array, Array]

Computes y = f(x) and log|det J(f)(x)|.

inverse_and_log_det(y: Array | ndarray | bool | number) → Tuple[Array, Array]

Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.
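Because \(\lVert y\rVert = s\tanh r\), the inverse recovers \(r = \operatorname{artanh}(\lVert y\rVert / s)\) and \(x = r\,y/\lVert y\rVert\), which is valid for \(0 < \lVert y\rVert < s\). By the inverse function theorem, \(\log\lvert\det J_{f^{-1}}(y)\rvert = -\log\lvert\det J_f(x)\rvert\). The NumPy round-trip sketch below illustrates both methods under those assumptions; the functions mirror the method names but are hypothetical stand-ins, not the library's implementation, and they skip the \(r \approx 0\) branch.

```python
import numpy as np

def forward_and_log_det(x, s):
    # y = s * tanh(r) * x / r per row; this sketch assumes every r > 0.
    d = x.shape[-1]
    r = np.linalg.norm(x, axis=-1, keepdims=True)
    y = s * np.tanh(r) * x / r
    r = r[..., 0]
    log_det = (d * np.log(s) + (d - 1) * (np.log(np.tanh(r)) - np.log(r))
               + 2.0 * (np.log(2.0) - r - np.logaddexp(0.0, -2.0 * r)))
    return y, log_det

def inverse_and_log_det(y, s):
    # Requires 0 < ||y|| < s for every row.
    rho = np.linalg.norm(y, axis=-1, keepdims=True)
    r = np.arctanh(rho / s)
    x = r * y / rho
    # Inverse log-det is minus the forward log-det evaluated at x.
    _, fwd_log_det = forward_and_log_det(x, s)
    return x, -fwd_log_det

s = 0.99
x = np.array([[0.5, -1.0, 2.0], [0.01, 0.02, 0.0]])
y, log_det = forward_and_log_det(x, s)
x_rec, inv_log_det = inverse_and_log_det(y, s)
print(np.max(np.abs(x_rec - x)), log_det + inv_log_det)
```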

same_as(other: Bijector) → bool

Returns True if this bijector is guaranteed to be the same as other.

classmethod registry_name() → str
property type_name: str