This repository contains a JAX implementation of spherical Bessel functions that is fully compatible with jax.jit and jax.grad. It is based on the implementation by @sousaw in NeuralIL.
The package can be installed from PyPI using pip:
pip install psbessax
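The snippet below does not call psbessax itself. It is a minimal sketch in plain JAX of what jit/grad compatibility means for a spherical Bessel function of the first kind. `spherical_jn_sketch` is a hypothetical name, not part of psbessax. The sketch assumes x > 0 and a static integer order n, and it uses the standard upward recurrence j_{k+1}(x) = (2k + 1)/x j_k(x) - j_{k-1}(x). That recurrence is a textbook choice that may differ from the method psbessax or NeuralIL actually uses.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper, not the psbessax API. The order n is a static Python int,
# so the recurrence below is unrolled at trace time and the result is jit-compiled per n.
@partial(jax.jit, static_argnums=0)
def spherical_jn_sketch(n: int, x: jax.Array) -> jax.Array:
    """Spherical Bessel function j_n(x) of the first kind, assuming x > 0.

    Upward recurrence loses accuracy when x is much smaller than n.
    """
    j_prev = jnp.sin(x) / x                      # j_0(x)
    if n == 0:
        return j_prev
    j_curr = jnp.sin(x) / x**2 - jnp.cos(x) / x  # j_1(x)
    for k in range(1, n):
        # j_{k+1}(x) = (2k + 1) / x * j_k(x) - j_{k-1}(x)
        j_prev, j_curr = j_curr, (2 * k + 1) / x * j_curr - j_prev
    return j_curr


if __name__ == "__main__":
    xs = jnp.linspace(0.5, 10.0, 5)
    # Elementwise j_3 over an array (all operations are elementwise).
    print(spherical_jn_sketch(3, xs))
    # Elementwise derivative d j_3 / dx, obtained with grad + vmap.
    print(jax.vmap(jax.grad(lambda t: spherical_jn_sketch(3, t)))(xs))
```

Because every step is a jax.numpy operation, jax.grad differentiates through the recurrence and jax.vmap maps it over arrays. A production implementation would also need to handle x = 0 and small x, where the closed forms above divide by zero or suffer cancellation.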