Source code for jaxdem.analysis.bessel

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project - https://github.com/cdelv/JaxDEM
"""Bessel functions for JAX.

Originally adapted from:
`https://github.com/benjaminpope/sibylla/blob/main/notebooks/bessel_test.ipynb`

Note:
This module does *not* change global JAX configuration (e.g. x64 enablement).
If you want 64-bit execution, set it in your application before importing JaxDEM:

    >>> import jax
    >>> jax.config.update("jax_enable_x64", True)
"""

from __future__ import annotations

from typing import TypeAlias

import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

# Alias for `jax.Array`, used in the annotations of the functions in this module.
JaxArray: TypeAlias = Array

# Coefficients of the rational function in z = x*x that `_j1_small` evaluates
# (the branch `j1` selects for |x| < 5). They are passed to `jnp.polyval`, so
# they are ordered from the highest-degree term down to the constant term.
# NOTE: the arrays are built with `jnp.array` at import time, so their dtype
# follows the JAX x64 setting that is active when this module is imported
# (see the module docstring).

# Numerator polynomial (degree 3 in z).
_RP1 = jnp.array(
    [
        -8.99971225705559398224e8,
        4.52228297998194034323e11,
        -7.27494245221818276015e13,
        3.68295732863852883286e15,
    ]
)

# Denominator polynomial (degree 8 in z); the explicit leading 1.0 makes it monic.
_RQ1 = jnp.array(
    [
        1.0,
        6.20836478118054335476e2,
        2.56987256757748830383e5,
        8.35146791431949253037e7,
        2.21511595479792499675e10,
        4.74914122079991414898e12,
        7.84369607876235854894e14,
        8.95222336184627338078e16,
        5.32278620332680085395e18,
    ]
)

# Coefficients used by `_j1_large_c` (the branch `j1` selects for |x| >= 5).
# Each pair forms a rational function of z = (5/x)**2, evaluated with
# `jnp.polyval` (highest-degree coefficient first).

# Numerator of the ratio that multiplies the cosine term in `_j1_large_c`.
_PP1 = jnp.array(
    [
        7.62125616208173112003e-4,
        7.31397056940917570436e-2,
        1.12719608129684925192e0,
        5.11207951146807644818e0,
        8.42404590141772420927e0,
        5.21451598682361504063e0,
        1.00000000000000000254e0,
    ]
)

# Denominator of the ratio that multiplies the cosine term in `_j1_large_c`.
_PQ1 = jnp.array(
    [
        5.71323128072548699714e-4,
        6.88455908754495404082e-2,
        1.10514232634061696926e0,
        5.07386386128601488557e0,
        8.39985554327604159757e0,
        5.20982848682361821619e0,
        9.99999999999999997461e-1,
    ]
)

# Numerator of the ratio that (scaled by 5/x) multiplies the sine term in
# `_j1_large_c`.
_QP1 = jnp.array(
    [
        5.10862594750176621635e-2,
        4.98213872951233449420e0,
        7.58238284132545283818e1,
        3.66779609360150777800e2,
        7.10856304998926107277e2,
        5.97489612400613639965e2,
        2.11688757100572135698e2,
        2.52070205858023719784e1,
    ]
)

# Denominator paired with `_QP1`; the explicit leading 1.0 makes it monic.
_QQ1 = jnp.array(
    [
        1.0,
        7.42373277035675149943e1,
        1.05644886038262816351e3,
        4.98641058337653607651e3,
        9.56231892404756170795e3,
        7.99704160447350683650e3,
        2.82619278517639096600e3,
        3.36093607810698293419e2,
    ]
)

# `_YP1` / `_YQ1` are not referenced by any function visible in this module.
# NOTE(review): presumably the Cephes numerator/denominator coefficients for
# the second-kind function Y1, kept for a future `y1` — confirm before use or
# removal.
_YP1 = jnp.array(
    [
        1.26320474790178026440e9,
        -6.47355876379160291031e11,
        1.14509511541823727583e14,
        -8.12770255501325109621e15,
        2.02439475713594898196e17,
        -7.78877196265950026825e17,
    ]
)
# NOTE(review): unlike `_RQ1` and `_QQ1`, this denominator table carries no
# explicit leading 1.0. If it is meant to be a monic polynomial (Cephes
# `p1evl` convention), a 1.0 must be prepended before it is handed to
# `jnp.polyval` — confirm against the reference implementation.
_YQ1 = jnp.array(
    [
        5.94301592346128195359e2,
        2.35564092943068577943e5,
        7.34811944459721705660e7,
        1.87601316108706159478e10,
        3.88231277496238566008e12,
        6.20557727146953693363e14,
        6.87141087355300489866e16,
        3.97270608116560655612e18,
    ]
)

# `_j1_small` multiplies its rational function by (x*x - _Z1) * (x*x - _Z2),
# so that branch is exactly zero wherever x*x equals one of these constants.
# NOTE(review): presumably the squares of the first two positive zeros of J1
# (sqrt(_Z1) ~ 3.8317, sqrt(_Z2) ~ 7.0156) — confirm against Cephes.
_Z1 = 1.46819706421238932572e1
_Z2 = 4.92184563216946036703e1
# Phase shifts and scale factor used by the |x| >= 5 branches:
# `_j0_large` subtracts `_PIO4` from x, `_j1_large_c` subtracts `_THPIO4`,
# and both multiply their result by `_SQ2OPI` before dividing by sqrt(x).
_PIO4 = 0.78539816339744830962  # pi/4
_THPIO4 = 2.35619449019234492885  # 3*pi/4
_SQ2OPI = 0.79788456080286535588  # sqrt(2/pi)

# Coefficients used by `_j0_large` (the branch `j0` selects for |x| >= 5).
# Each pair forms a rational function of q = 25 / (x*x), evaluated with
# `jnp.polyval` (highest-degree coefficient first).

# Numerator of the ratio that multiplies the cosine term in `_j0_large`.
_PP0 = jnp.array(
    [
        7.96936729297347051624e-4,
        8.28352392107440799803e-2,
        1.23953371646414299388e0,
        5.44725003058768775090e0,
        8.74716500199817011941e0,
        5.30324038235394892183e0,
        9.99999999999999997821e-1,
    ]
)

# Denominator of the ratio that multiplies the cosine term in `_j0_large`.
_PQ0 = jnp.array(
    [
        9.24408810558863637013e-4,
        8.56288474354474431428e-2,
        1.25352743901058953537e0,
        5.47097740330417105182e0,
        8.76190883237069594232e0,
        5.30605288235394617618e0,
        1.00000000000000000218e0,
    ]
)

# Numerator of the ratio that (scaled by 5/x) multiplies the sine term in
# `_j0_large`.
_QP0 = jnp.array(
    [
        -1.13663838898469149931e-2,
        -1.28252718670509318512e0,
        -1.95539544257735972385e1,
        -9.32060152123768231369e1,
        -1.77681167980488050595e2,
        -1.47077505154951170175e2,
        -5.14105326766599330220e1,
        -6.05014350600728481186e0,
    ]
)

# Denominator paired with `_QP0`; the explicit leading 1.0 makes it monic.
_QQ0 = jnp.array(
    [
        1.0,
        6.43178256118178023184e1,
        8.56430025976980587198e2,
        3.88240183605401609683e3,
        7.24046774195652478189e3,
        5.93072701187316984827e3,
        2.06209331660327847417e3,
        2.42005740240291393179e2,
    ]
)

# `_YP0` / `_YQ0` are not referenced by any function visible in this module.
# NOTE(review): presumably the Cephes numerator/denominator coefficients for
# the second-kind function Y0, kept for a future `y0` — confirm before use or
# removal.
_YP0 = jnp.array(
    [
        1.55924367855235737965e4,
        -1.46639295903971606143e7,
        5.43526477051876500413e9,
        -9.82136065717911466409e11,
        8.75906394395366999549e13,
        -3.46628303384729719441e15,
        4.42733268572569800351e16,
        -1.84950800436986690637e16,
    ]
)

# NOTE(review): unlike `_RQ0` and `_QQ0`, this denominator table carries no
# explicit leading 1.0. If it is meant to be a monic polynomial (Cephes
# `p1evl` convention), a 1.0 must be prepended before it is handed to
# `jnp.polyval` — confirm against the reference implementation.
_YQ0 = jnp.array(
    [
        1.04128353664259848412e3,
        6.26107330137134956842e5,
        2.68919633393814121987e8,
        8.64002487103935000337e10,
        2.02979612750105546709e13,
        3.17157752842975028269e15,
        2.50596256172653059228e17,
    ]
)

# `_j0_small` multiplies its rational function by (x*x - _DR10) * (x*x - _DR20),
# so that branch is exactly zero wherever x*x equals one of these constants.
# NOTE(review): presumably the squares of the first two positive zeros of J0
# (sqrt(_DR10) ~ 2.4048, sqrt(_DR20) ~ 5.5201) — confirm against Cephes.
_DR10 = 5.78318596294678452118e0
_DR20 = 3.04712623436620863991e1

# Coefficients of the rational function in z = x*x that `_j0_small` evaluates
# (the branch `j0` selects for |x| < 5), highest-degree coefficient first.
# Numerator polynomial (degree 3 in z).
_RP0 = jnp.array(
    [
        -4.79443220978201773821e9,
        1.95617491946556577543e12,
        -2.49248344360967716204e14,
        9.70862251047306323952e15,
    ]
)

# Denominator polynomial (degree 8 in z); the explicit leading 1.0 makes it monic.
_RQ0 = jnp.array(
    [
        1.0,
        4.99563147152651017219e2,
        1.73785401676374683123e5,
        4.84409658339962045305e7,
        1.11855537045356834862e10,
        2.11277520115489217587e12,
        3.10518229857422583814e14,
        3.18121955943204943306e16,
        1.71086294081043136091e18,
    ]
)


def _j1_small(x: ArrayLike) -> JaxArray:
    """Evaluate the branch of J1 that :func:`j1` selects for ``abs(x) < 5``.

    The result is the rational function ``_RP1(z) / _RQ1(z)`` in ``z = x * x``
    multiplied by ``x * (z - _Z1) * (z - _Z2)``.

    Args:
        x: Argument of the Bessel function.

    Returns:
        The approximation of ``J1(x)`` produced by this branch.
    """
    x_sq = x * x
    numerator = jnp.polyval(_RP1, x_sq)
    denominator = jnp.polyval(_RQ1, x_sq)
    ratio = numerator / denominator
    # Multiply left to right so the floating-point rounding sequence is fixed.
    return ratio * x * (x_sq - _Z1) * (x_sq - _Z2)


def _j1_large_c(x: ArrayLike) -> JaxArray:
    """Evaluate the branch of J1 that :func:`j1` selects for ``abs(x) >= 5``.

    With ``w = 5 / x`` and ``z = w * w`` the result is::

        sqrt(2/pi) / sqrt(x) * (P(z) * cos(x - 3*pi/4) - w * Q(z) * sin(x - 3*pi/4))

    where ``P = _PP1 / _PQ1`` and ``Q = _QP1 / _QQ1`` are rational functions.

    Args:
        x: Argument of the Bessel function. The expression divides by ``x``
            and takes ``sqrt(x)``.

    Returns:
        The approximation of ``J1(x)`` produced by this branch.
    """
    inv_scaled = 5.0 / x
    inv_scaled_sq = inv_scaled * inv_scaled
    p_ratio = jnp.polyval(_PP1, inv_scaled_sq) / jnp.polyval(_PQ1, inv_scaled_sq)
    q_ratio = jnp.polyval(_QP1, inv_scaled_sq) / jnp.polyval(_QQ1, inv_scaled_sq)
    shifted = x - _THPIO4
    cos_part = p_ratio * jnp.cos(shifted)
    sin_part = inv_scaled * q_ratio * jnp.sin(shifted)
    combined = cos_part - sin_part
    return combined * _SQ2OPI / jnp.sqrt(x)


def j1(x: ArrayLike) -> JaxArray:
    """Bessel function of the first kind of order one, ``J1(x)``.

    Uses the implementation from CEPHES, translated to JAX: ``_j1_small`` is
    used where ``abs(x) < 5`` and ``_j1_large_c`` elsewhere, with the odd
    symmetry ``J1(-x) == -J1(x)`` applied for negative arguments.

    Args:
        x: Real-valued scalar or array argument.

    Returns:
        ``J1(x)`` evaluated element-wise, with the same shape as ``x``.
    """
    x = jnp.asarray(x)
    abs_x = jnp.abs(x)
    is_small = abs_x < 5.0

    # ``jnp.where`` evaluates BOTH branches on every element. If the branch
    # that is not selected produces inf/NaN (the large branch divides by x,
    # the small branch's high-degree polynomials overflow for big |x| in
    # float32), the forward value is still fine but reverse-mode gradients
    # become NaN (0 * inf). Feed each branch only arguments from its own
    # domain and a harmless finite placeholder everywhere else.
    small_arg = jnp.where(is_small, x, 0.0)
    large_arg = jnp.where(is_small, 5.0, abs_x)

    # ``_j1_small`` is odd in its argument (an even rational function times
    # x), so it takes the signed value directly. This yields the same number
    # as ``sign(x) * _j1_small(abs(x))`` while keeping the derivative at
    # x == 0 intact, where the sign/abs formulation has zero slope.
    return jnp.where(
        is_small,
        _j1_small(small_arg),
        jnp.sign(x) * _j1_large_c(large_arg),
    )
def _j0_small(x: JaxArray) -> JaxArray:
    """Evaluate the branch of J0 that :func:`j0` selects for ``abs(x) < 5``.

    For ``x < 1e-5`` the result is ``1 - x*x / 4``. Otherwise it is the
    rational function ``_RP0(z) / _RQ0(z)`` in ``z = x * x`` multiplied by
    ``(z - _DR10) * (z - _DR20)``.

    Args:
        x: Argument of the Bessel function; :func:`j0` passes ``abs(x)``.

    Returns:
        The approximation of ``J0(x)`` produced by this branch.
    """
    x_sq = x * x
    zero_factor = (x_sq - _DR10) * (x_sq - _DR20)
    rational = zero_factor * jnp.polyval(_RP0, x_sq) / jnp.polyval(_RQ0, x_sq)
    near_zero = 1 - x_sq / 4.0
    # Both candidates are computed and the choice is made element-wise with
    # ``jnp.where`` rather than an ``if`` with an early return.
    return jnp.where(x < 1e-5, near_zero, rational)


def _j0_large(x: ArrayLike) -> JaxArray:
    """Evaluate the branch of J0 that :func:`j0` selects for ``abs(x) >= 5``.

    With ``w = 5 / x`` and ``q = 25 / (x * x)`` the result is::

        sqrt(2/pi) / sqrt(x) * (P(q) * cos(x - pi/4) - w * Q(q) * sin(x - pi/4))

    where ``P = _PP0 / _PQ0`` and ``Q = _QP0 / _QQ0`` are rational functions.

    Args:
        x: Argument of the Bessel function. The expression divides by ``x``
            and takes ``sqrt(x)``.

    Returns:
        The approximation of ``J0(x)`` produced by this branch.
    """
    inv_scaled = 5.0 / x
    inv_sq = 25.0 / (x * x)
    p_ratio = jnp.polyval(_PP0, inv_sq) / jnp.polyval(_PQ0, inv_sq)
    q_ratio = jnp.polyval(_QP0, inv_sq) / jnp.polyval(_QQ0, inv_sq)
    shifted = x - _PIO4
    cos_part = p_ratio * jnp.cos(shifted)
    sin_part = inv_scaled * q_ratio * jnp.sin(shifted)
    combined = cos_part - sin_part
    return combined * _SQ2OPI / jnp.sqrt(x)
def j0(x: ArrayLike) -> JaxArray:
    """Bessel function of the first kind of order zero, ``J0(x)``, for all real x.

    ``J0`` is even, so the computation is carried out on ``abs(x)``:
    ``_j0_small`` is used where ``abs(x) < 5`` and ``_j0_large`` elsewhere.

    Args:
        x: Real-valued scalar or array argument.

    Returns:
        ``J0(x)`` evaluated element-wise, with the same shape as ``x``.
    """
    abs_x = jnp.abs(x)
    is_small = abs_x < 5.0

    # ``jnp.where`` evaluates BOTH branches on every element. If the branch
    # that is not selected produces inf/NaN (``_j0_large`` divides by x,
    # ``_j0_small``'s high-degree polynomials overflow for big |x| in
    # float32), the forward value is still fine but reverse-mode gradients
    # become NaN (0 * inf). Feed each branch only arguments from its own
    # domain and a harmless finite placeholder everywhere else.
    small_arg = jnp.where(is_small, abs_x, 0.0)
    large_arg = jnp.where(is_small, 5.0, abs_x)

    return jnp.where(is_small, _j0_small(small_arg), _j0_large(large_arg))
# Public API of this module: only the two full-range Bessel functions are
# exported; the underscore-prefixed helpers and coefficient tables are private.
__all__ = ["j0", "j1"]