Source code for jaxdem.rl.actionSpaces.maxNormSpace

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Implementation of the bijector for the max-norm action space.
"""

import jax
import jax.numpy as jnp

from typing import Tuple, Optional, Dict

import distrax
from distrax._src.bijectors.bijector import Array

from . import ActionSpace


@ActionSpace.register("MaxNorm")
class MaxNormSpace(distrax.Bijector, ActionSpace):
    r"""
    **Radial max-norm** constraint for vector actions: scales the radius with a
    `tanh` squashing while preserving direction.

    **Mapping** (vector case, :math:`x \in \mathbb{R}^d`)

    .. math::
        r = \lVert \vec{x} \rVert_2,\qquad
        \hat{u} =
        \begin{cases}
            \frac{\vec{x}}{r}, & r>0,\\[4pt]
            0, & r=0,
        \end{cases}
        \qquad
        y = s \tanh(r)\,\hat{u}, \quad s = (1-\epsilon)\,\texttt{max\_norm}.

    Equivalently, :math:`y = b(r)\,x` with :math:`b(r)= s\,\tanh(r)/r` for :math:`r>0`.

    **Jacobian determinant**

    For an isotropic radial map :math:`f(x)=b(r)\,x` with :math:`x \in \mathbb{R}^d`,
    the Jacobian eigenvalues are :math:`b(r)` (multiplicity :math:`d-1`) on the
    tangent subspace and :math:`b(r)+r\,b'(r)` in the radial direction, hence

    .. math::
        \bigl|\det J_f(x)\bigr|
        = b(r)^{\,d-1}\,\bigl(b(r)+r\,b'(r)\bigr)
        = s^d \left(\frac{\tanh r}{r}\right)^{\!d-1} \operatorname{sech}^2 r.

    Therefore

    .. math::
        \log\lvert\det J_f(x)\rvert
        = d\log s + (d-1)\bigl(\log\tanh r - \log r\bigr)
        + \log\bigl(\operatorname{sech}^2 r\bigr).

    We use the stable identity
    :math:`\log(\operatorname{sech}^2 z)=2\,[\log 2 - z - \operatorname{softplus}(-2z)]`
    for good numerical behavior. Near :math:`r\approx 0`, we use the second-order
    expansion

    .. math::
        \log\lvert\det J_f(x)\rvert \approx d\log s - \tfrac{2}{3} r^2

    to avoid division by :math:`r`.

    Parameters
    ----------
    max_norm : float
        Target radius :math:`s` after squashing (default 1.0). We actually use
        :math:`s=(1-\varepsilon)\,\texttt{max\_norm}` to avoid the exact boundary.
    eps : float
        Numerical safety margin used near :math:`r=0` and :math:`r\to\infty`.
    event_ndims_in : int
        Dimensionality of a *single event* seen by the bijector (defaults to 1
        for a vector transform).
    event_ndims_out : Optional[int]
        Standard Distrax/TFP bijector flag.
    is_constant_jacobian : bool
        Standard Distrax/TFP bijector flag.
    is_constant_log_det : Optional[bool]
        Standard Distrax/TFP bijector flag.

    Note
    ----
    This bijector is **vector-valued** with ``event_ndims_in = 1`` (i.e., it
    operates on length-:math:`d` action vectors as a single event). Do **not**
    wrap it in ``Block`` unless you intend to apply it independently to multiple
    last-axis blocks. A numerical sanity check of the log-determinant is sketched
    after the class definition.
    """

    __slots__ = ()

    def __init__(
        self,
        max_norm: float = 1.0,
        eps: float = 1e-6,
        event_ndims_in: int = 1,
        event_ndims_out: Optional[int] = None,
        is_constant_jacobian: bool = False,
        is_constant_log_det: Optional[bool] = None,
    ):
        super().__init__(
            event_ndims_in=event_ndims_in,
            event_ndims_out=event_ndims_out,
            is_constant_jacobian=is_constant_jacobian,
            is_constant_log_det=is_constant_log_det,
        )
        self.eps = float(eps)
        self.max_norm = float(max_norm)

    @property
    def kws(self) -> Dict:
        return dict(
            max_norm=self.max_norm,
            eps=self.eps,
            event_ndims_in=self.event_ndims_in,
            event_ndims_out=self.event_ndims_out,
            is_constant_jacobian=self.is_constant_jacobian,
            is_constant_log_det=self.is_constant_log_det,
        )

    @staticmethod
    def sec2_log(r):
        # Numerically stable log(sech^2(r)) = 2*(log 2 - r - softplus(-2r));
        # r is the (possibly batched) radius.
        return 2 * (jnp.log(2.0) - r - jax.nn.softplus(-2.0 * r))

    def forward_log_det_jacobian(self, x: Array) -> jax.Array:
        r = jnp.linalg.norm(x, axis=-1)  # shape (...,)
        x = jnp.atleast_1d(x)  # ensures x.ndim >= 1
        d = jnp.asarray(x.shape[-1], x.dtype)  # scalar, works under jit

        # Stable pieces
        log_s = jnp.log((1.0 - self.eps) * self.max_norm + self.eps)
        log_tanh_r = jnp.log(jnp.tanh(r) + self.eps)
        log_r = jnp.log(r + self.eps)
        log_sech2_r = MaxNormSpace.sec2_log(r)

        main = d * log_s + (d - 1.0) * (log_tanh_r - log_r) + log_sech2_r
        small = d * log_s - (2.0 / 3.0) * (r * r)
        return jnp.where(r < self.eps, small, main)

    def forward_and_log_det(self, x: Array) -> Tuple[jax.Array, jax.Array]:
        r = jnp.linalg.norm(x, axis=-1, keepdims=True)
        unit = jnp.where(r > 0.0, x / r, jnp.zeros_like(x))
        y = (1.0 - self.eps) * self.max_norm * jnp.tanh(r) * unit
        return y, self.forward_log_det_jacobian(x)

    def inverse_and_log_det(self, y: Array) -> Tuple[jax.Array, jax.Array]:
        r = jnp.linalg.norm(y, axis=-1, keepdims=True)
        u = (r / ((1.0 - self.eps) * self.max_norm)).clip(
            -1.0 + self.eps, 1.0 - self.eps
        )
        unit = jnp.where(r > 0.0, y / r, jnp.zeros_like(y))
        x = jnp.arctanh(u) * unit
        return x, -self.forward_log_det_jacobian(x)

    def same_as(self, other: distrax.Bijector) -> bool:
        return type(other) is MaxNormSpace
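
# --- Illustrative sanity check (not part of the library) ---------------------
# A minimal sketch, assuming MaxNormSpace can be constructed directly with its
# defaults: it compares the closed-form forward_log_det_jacobian against the
# log|det| of the full Jacobian computed with jax.jacfwd, which should agree
# up to the eps regularisation used inside the class.
if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (3,))  # a single 3-D action vector
    space = MaxNormSpace(max_norm=1.0)

    # Closed-form log-determinant used by the bijector.
    closed_form = space.forward_log_det_jacobian(x)

    # Brute-force reference: log|det J_f(x)| from the full Jacobian.
    jac = jax.jacfwd(lambda v: space.forward_and_log_det(v)[0])(x)
    _, brute_force = jnp.linalg.slogdet(jac)

    print("closed form :", closed_form)
    print("autodiff    :", brute_force)  # expected to match closely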
__all__ = ["MaxNormSpace"]
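
# --- Usage sketch (assumption: standard Distrax composition) -----------------
# A common way to use a bijector like this in an RL policy is to wrap a
# diagonal Gaussian with distrax.Transformed, so that sampled actions are
# squashed into the max-norm ball and the log-probability picks up the forward
# log-det correction automatically. This is only a sketch of that pattern, not
# an API provided by jaxdem itself.
if __name__ == "__main__":
    key = jax.random.PRNGKey(1)
    base = distrax.MultivariateNormalDiag(
        loc=jnp.zeros(3), scale_diag=jnp.ones(3)
    )
    policy = distrax.Transformed(base, MaxNormSpace(max_norm=2.0))

    action, log_prob = policy.sample_and_log_prob(seed=key)
    print("action      :", action)
    print("||action||  :", jnp.linalg.norm(action))  # stays below max_norm
    print("log_prob    :", log_prob)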