jaxdem.rl.actionSpaces.maxNormSpace
Implementation of the bijector for the max-norm action space.
Classes
MaxNormSpace | Radial max-norm constraint for vector actions: scales the radius with a tanh squashing while preserving direction.
- class jaxdem.rl.actionSpaces.maxNormSpace.MaxNormSpace(*args, **kwargs)
Bases: Bijector, ActionSpace
Radial max-norm constraint for vector actions: scales the radius with a tanh squashing while preserving direction.
Mapping (vector case, \(x \in \mathbb{R}^d\))
\[\begin{split}r = \lVert x \rVert_2,\qquad \hat{u} = \begin{cases} \frac{x}{r}, & r>0,\\[4pt] 0, & r=0, \end{cases} \qquad y = s \tanh(r)\, \hat{u}, \quad s = (1-\epsilon)\, \texttt{max_norm}.\end{split}\]
Equivalently, \(y = b(r)\,x\) with \(b(r)= s\,\tanh(r)/r\) for \(r>0\).
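The mapping above can be sketched in a few lines of plain Python. This is an illustration only, not the library's implementation: the function name `max_norm_forward` and the default `eps=1e-6` are assumptions made for this example, and the standard `math` module is used so the snippet runs without JAX.

```python
import math

def max_norm_forward(x, max_norm=1.0, eps=1e-6):
    """Map x in R^d to y with ||y|| < max_norm, keeping the direction of x."""
    s = (1.0 - eps) * max_norm            # effective scale s = (1 - eps) * max_norm
    r = math.sqrt(sum(v * v for v in x))  # r = ||x||_2
    if r == 0.0:
        return [0.0 for _ in x]           # u_hat = 0 at the origin, so y = 0
    scale = s * math.tanh(r) / r          # b(r) = s * tanh(r) / r
    return [scale * v for v in x]

# Hypothetical input: a 2-D action with norm 5, squashed into a ball of radius 2.
y = max_norm_forward([3.0, 4.0], max_norm=2.0)
norm_y = math.hypot(*y)                   # stays strictly below max_norm
```

Because every component is multiplied by the same positive factor \(b(r)\), the ratio between components of `y` matches that of the input, which is what "preserving direction" means here.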
Jacobian determinant
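For an isotropic radial map \(f(x)=b(r)\,x\) with \(x \in \mathbb{R}^d\), the Jacobian eigenvalues are \(b(r)\) (multiplicity \(d-1\)) on the tangent subspace and \(b(r) + r\,b'(r)\) along the radial direction. Since \(b(r)+r\,b'(r)=\tfrac{d}{dr}\bigl[r\,b(r)\bigr]=s\,\operatorname{sech}^2 r\), it follows that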
\[\bigl|\det J_f(x)\bigr| = b(r)^{\,d-1}\,\bigl(b(r)+r\,b'(r)\bigr) = s^d \left(\frac{\tanh r}{r}\right)^{\!d-1} \operatorname{sech}^2 r.\]
Therefore
\[\log\lvert\det J_f(x)\rvert = d\log s + (d-1)\bigl(\log\tanh r - \log r\bigr) + \log(\operatorname{sech}^2 r).\]
The last term is evaluated with the identity \(\log(\operatorname{sech}^2 z)=2\,[\log 2 - z - \operatorname{softplus}(-2z)]\), which stays numerically stable for large \(z\).
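As a sanity check on the closed form above, the following plain-Python sketch evaluates the log-determinant with the softplus identity and compares it against a central finite-difference Jacobian for \(d=2\). All function names, the test point, and the step size `h` are assumptions chosen for illustration; the sketch assumes \(r>0\) and is not the library's code.

```python
import math

def softplus(z):
    # log(1 + exp(z)), arranged so exp() never receives a large positive argument
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def log_sech2(z):
    # log(sech^2 z) = 2 * (log 2 - z - softplus(-2 z)); no cosh(z) is ever formed
    return 2.0 * (math.log(2.0) - z - softplus(-2.0 * z))

def forward(x, s):
    r = math.sqrt(sum(v * v for v in x))   # assumes r > 0
    b = s * math.tanh(r) / r
    return [b * v for v in x]

def forward_log_det(x, s=1.0):
    d = len(x)
    r = math.sqrt(sum(v * v for v in x))   # assumes r > 0
    return (d * math.log(s)
            + (d - 1) * (math.log(math.tanh(r)) - math.log(r))
            + log_sech2(r))

def fd_log_det_2d(x, s, h=1e-6):
    # Central differences: cols[j][i] approximates d f_i / d x_j
    cols = []
    for j in range(2):
        xp, xm = list(x), list(x)
        xp[j] += h
        xm[j] -= h
        fp, fm = forward(xp, s), forward(xm, s)
        cols.append([(fp[i] - fm[i]) / (2.0 * h) for i in range(2)])
    det = cols[0][0] * cols[1][1] - cols[1][0] * cols[0][1]
    return math.log(abs(det))

x0 = [0.3, -0.7]                            # arbitrary test point
exact = forward_log_det(x0, s=0.9)
numeric = fd_log_det_2d(x0, s=0.9)
gap = abs(exact - numeric)                  # expected to be tiny
```

Writing \(\log(\operatorname{sech}^2 z)\) this way matters for large arguments: a direct `math.cosh(z) ** 2` overflows once \(z\) reaches a few hundred, whereas `log_sech2` only ever exponentiates non-positive numbers.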
Near \(r\approx 0\), we use the second-order expansion
\[\log\lvert\det J_f(x)\rvert \approx d\log s - \tfrac{2}{3} r^2\]to avoid division by \(r\).
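For reference, the small-\(r\) behaviour of the exact expression can be derived term by term. This is a derivation sketch from the formulas on this page, not taken from the library source. Using \(\tanh r = r - \tfrac{1}{3}r^3 + O(r^5)\) and \(\log\cosh r = \tfrac{1}{2}r^2 + O(r^4)\),
\[\log\frac{\tanh r}{r} = -\tfrac{1}{3}r^2 + O(r^4), \qquad \log(\operatorname{sech}^2 r) = -2\log\cosh r = -r^2 + O(r^4),\]
so that
\[\log\lvert\det J_f(x)\rvert = d\log s - (d-1)\,\tfrac{1}{3}r^2 - r^2 + O(r^4) = d\log s - \tfrac{d+2}{3}\,r^2 + O(r^4).\]
The \(r^2\) coefficient obtained this way depends on \(d\) and does not match the \(-\tfrac{2}{3}\) quoted above for any \(d \ge 1\), so the small-\(r\) branch is worth checking against the implementation for the action dimension in use.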
- Parameters:
- max_norm (float) – target radius \(s\) after squashing (default 1.0). We actually use \(s=(1-\epsilon)\,\texttt{max_norm}\) to avoid the exact boundary.
- eps (float) – numerical safety margin used near \(r=0\) and \(r\to\infty\).
- event_ndims_in (int) – dimensionality of a single event seen by the bijector (defaults to 0 for a scalar transform).
- event_ndims_out (Optional[int]) – standard Distrax/TFP bijector flags.
- is_constant_jacobian (bool) – standard Distrax/TFP bijector flags.
- is_constant_log_det (bool) – standard Distrax/TFP bijector flags.
Note
This bijector is vector-valued with event_ndims_in = 1 (i.e., it operates on length-\(d\) action vectors as a single event). Do not wrap it in Block unless you intend to apply it independently to multiple last-axis blocks.
- forward_log_det_jacobian(x: Array | ndarray | bool | number) → Array
Computes log|det J(f)(x)|.
- forward_and_log_det(x: Array | ndarray | bool | number) → Tuple[Array, Array]
Computes y = f(x) and log|det J(f)(x)|.
- inverse_and_log_det(y: Array | ndarray | bool | number) → Tuple[Array, Array]
Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.
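The relationship between the forward and inverse methods can be illustrated with a plain-Python round trip. The sketch below is an assumption-laden illustration, not the library's code: the function bodies are written from the formulas on this page, they take lists rather than arrays, they assume \(0 < \lVert y \rVert < s\), and they use \(\log(1-\tanh^2 r)\) for the \(\operatorname{sech}^2\) term instead of the softplus identity. Inverting \(\lVert y \rVert = s\tanh r\) gives \(r = \operatorname{atanh}(\lVert y \rVert / s)\), and the inverse log-determinant is the negative of the forward one evaluated at the recovered \(x\).

```python
import math

def forward_and_log_det(x, s=1.0):
    d = len(x)
    r = math.sqrt(sum(v * v for v in x))      # assumes r > 0
    t = math.tanh(r)
    y = [s * t * v / r for v in x]            # y = s * tanh(r) * x / r
    log_det = (d * math.log(s)
               + (d - 1) * (math.log(t) - math.log(r))
               + math.log1p(-t * t))          # log(sech^2 r) = log(1 - tanh^2 r)
    return y, log_det

def inverse_and_log_det(y, s=1.0):
    d = len(y)
    rho = math.sqrt(sum(v * v for v in y))    # rho = ||y||, assumes 0 < rho < s
    t = rho / s                               # t = tanh(r)
    r = math.atanh(t)                         # recover the pre-squash radius
    x = [r * v / rho for v in y]              # same direction, radius r
    fwd_log_det = (d * math.log(s)
                   + (d - 1) * (math.log(t) - math.log(r))
                   + math.log1p(-t * t))
    return x, -fwd_log_det                    # log|det J(f^-1)(y)| = -log|det J(f)(x)|

x0 = [0.5, -1.2, 0.3]                         # arbitrary 3-D test action
y0, fld = forward_and_log_det(x0, s=0.9)
x1, ild = inverse_and_log_det(y0, s=0.9)
round_trip_error = max(abs(a - b) for a, b in zip(x0, x1))
```

The two log-determinants cancel (`fld + ild` is zero up to rounding), which is a convenient consistency test for any implementation of this bijector.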