Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions desc/_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Various checks to automate installation issues and decrease user support requests.

These tests run when DESC boots to notify users of potential issues
or to replace warning messages in downstream libraries with more
verbose warnings.
"""

import warnings

import numpy as np
from jax import grad

from desc.backend import jnp, rfft
from desc.integrals._interp_utils import nufft1d2r


def _c_1d(x):
return jnp.cos(7 * x) + jnp.sin(x) - 33.2


@grad
def _true_g_c_1d(xq):
return _c_1d(xq).sum()


def check_jax_finufft(func=_c_1d, g_func=_true_g_c_1d):
"""Runs tests/test_interp_utils.py::TestFastInterp::test_non_uniform_real_FFT."""
n = 15
f = 2 * rfft(func(jnp.linspace(0, 2 * jnp.pi, n, endpoint=False)), norm="forward")
f = f.at[..., (0, -1) if (n % 2 == 0) else 0].divide(2)
xq = jnp.array([7.34, 1.10134, 2.28])

msg = (
"If you want to use NUFFTs, follow the DESC installation instructions.\n"
"Otherwise you must pass the parameter nufft_eps=0.\n"
"This applies to effective ripple, Gamma_c, and any other\n"
"computations that involve bounce integrals.\n"
)
# https://github.com/unalmis/jax-finufft/blob/main/tests/interpolation_test.py#L13
RTOL = 2e-6

try:
np.testing.assert_allclose(nufft1d2r(xq, f), func(xq), rtol=RTOL)

@grad
def g(xq):
return nufft1d2r(xq, f, eps=1e-7).sum()

np.testing.assert_allclose(g(xq), g_func(xq), rtol=RTOL)
except NameError:
warnings.warn("\njax-finufft is not installed.\n" + msg)
except NotImplementedError:
warnings.warn("\njax-finufft is not installed on GPU.\n" + msg)
except AssertionError as e:
import jax_finufft

Check warning on line 55 in desc/_checks.py

View check run for this annotation

Codecov / codecov/patch

desc/_checks.py#L50-L55

Added lines #L50 - L55 were not covered by tests

warnings.warn(

Check warning on line 57 in desc/_checks.py

View check run for this annotation

Codecov / codecov/patch

desc/_checks.py#L57

Added line #L57 was not covered by tests
f"\njax-finufft version <= {jax_finufft.__version__} has incorrect maths.\n"
+ "\n\n"
+ e
+ "\n\n"
+ msg
)
14 changes: 6 additions & 8 deletions desc/integrals/_interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
For example, we prefer to not use Horner's method.
"""

import warnings
from functools import partial

import numpy as np
Expand All @@ -16,13 +15,7 @@
try:
from jax_finufft import nufft2, options
except ImportError:
warnings.warn(
"\njax-finufft is not installed.\n"
"If you want to use NUFFTs, follow the DESC installation instructions.\n"
"Otherwise you must set the parameter nufft_eps to zero\n"
"when computing effective ripple, Gamma_c, and any other\n"
"computations that involve bounce integrals.\n"
)
pass

Check warning on line 18 in desc/integrals/_interp_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/_interp_utils.py#L18

Added line #L18 was not covered by tests

from desc.backend import dct, jnp, rfft, rfft2, take
from desc.integrals.quad_utils import bijection_from_disc
Expand Down Expand Up @@ -972,3 +965,8 @@
return jnp.concatenate(
[jnp.sin(nx[..., n_rfft.size - is_even - 1 : 0 : -1]), jnp.cos(nx)], axis=-1
)


from desc._checks import check_jax_finufft # noqa: E402

check_jax_finufft()
33 changes: 22 additions & 11 deletions tests/test_interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ class TestFastInterp:

@pytest.mark.unit
@pytest.mark.parametrize("M", [1, 8, 9])
def test_fft_shift(self, M):
@staticmethod
def test_fft_shift(M):
"""Test frequency shifting."""
a = np.fft.rfftfreq(M, 1 / M)
np.testing.assert_allclose(a, np.arange(M // 2 + 1))
Expand All @@ -132,7 +133,8 @@ def test_fft_shift(self, M):
(*_test_inputs_1D[5], False),
],
)
def test_non_uniform_FFT(self, func, n, domain, imag_undersampled):
@staticmethod
def test_non_uniform_FFT(func, n, domain, imag_undersampled):
"""Test non-uniform FFT interpolation."""
x = np.linspace(domain[0], domain[1], n, endpoint=False)
c = func(x)
Expand All @@ -148,7 +150,8 @@ def test_non_uniform_FFT(self, func, n, domain, imag_undersampled):

@pytest.mark.unit
@pytest.mark.parametrize("func, n, domain", _test_inputs_1D)
def test_non_uniform_real_FFT(self, func, n, domain):
@staticmethod
def test_non_uniform_real_FFT(func, n, domain):
"""Test non-uniform real FFT interpolation."""
x = jnp.linspace(domain[0], domain[1], n, endpoint=False)
c = func(x)
Expand All @@ -170,7 +173,8 @@ def true_g(xq):

@pytest.mark.unit
@pytest.mark.parametrize("func, m, n, domain_x, domain_y", _test_inputs_2D)
def test_non_uniform_real_FFT_2D(self, func, m, n, domain_x, domain_y):
@staticmethod
def test_non_uniform_real_FFT_2D(func, m, n, domain_x, domain_y):
"""Test non-uniform real FFT 2D interpolation."""
x = jnp.linspace(domain_x[0], domain_x[1], m, endpoint=False)
y = jnp.linspace(domain_y[0], domain_y[1], n, endpoint=False)
Expand Down Expand Up @@ -206,7 +210,8 @@ def true_g(xq, yq):

@pytest.mark.unit
@pytest.mark.parametrize("func, n, domain", _test_inputs_1D)
def test_non_uniform_real_MMT(self, func, n, domain):
@staticmethod
def test_non_uniform_real_MMT(func, n, domain):
"""Test non-uniform real MMT interpolation."""
x = np.linspace(domain[0], domain[1], n, endpoint=False)
c = func(x)
Expand All @@ -219,7 +224,8 @@ def test_non_uniform_real_MMT(self, func, n, domain):

@pytest.mark.unit
@pytest.mark.parametrize("func, m, n, domain_x, domain_y", _test_inputs_2D)
def test_non_uniform_real_MMT_2D(self, func, m, n, domain_x, domain_y):
@staticmethod
def test_non_uniform_real_MMT_2D(func, m, n, domain_x, domain_y):
"""Test non-uniform real MMT 2D interpolation."""
x = np.linspace(domain_x[0], domain_x[1], m, endpoint=False)
y = np.linspace(domain_y[0], domain_y[1], n, endpoint=False)
Expand Down Expand Up @@ -252,7 +258,8 @@ def test_non_uniform_real_MMT_2D(self, func, m, n, domain_x, domain_y):
)

@pytest.mark.unit
def test_nufft2_vec(self):
@staticmethod
def test_nufft2_vec():
"""Test vectorized JAX-finufft vectorized interpolation."""
func_1, n, domain = _test_inputs_1D[0]
func_2 = lambda x: -77 * np.sin(7 * x) + 18 * np.cos(x) + 100 # noqa: E731
Expand Down Expand Up @@ -285,7 +292,8 @@ def test_nufft2_vec(self):

@pytest.mark.unit
@pytest.mark.parametrize("N", [2, 6, 7])
def test_cheb_pts(self, N):
@staticmethod
def test_cheb_pts(N):
"""Test we use Chebyshev points compatible with DCT."""
np.testing.assert_allclose(cheb_pts(N), chebpts1(N)[::-1], atol=1e-15)
np.testing.assert_allclose(
Expand All @@ -306,7 +314,8 @@ def test_cheb_pts(self, N):
(identity, 4, True),
],
)
def test_dct(self, f, M, lobatto):
@staticmethod
def test_dct(f, M, lobatto):
"""Test discrete cosine transform interpolation.

Parameters
Expand Down Expand Up @@ -354,7 +363,8 @@ def test_dct(self, f, M, lobatto):
"f, M",
[(_f_non_periodic, 5), (_f_non_periodic, 6), (_f_algebraic, 7)],
)
def test_interp_dct(self, f, M):
@staticmethod
def test_interp_dct(f, M):
"""Test non-uniform DCT interpolation."""
c0 = chebinterpolate(f, M - 1)
assert not np.allclose(
Expand Down Expand Up @@ -386,7 +396,8 @@ def test_interp_dct(self, f, M):
"func, m, n",
[(_c_2d, 2 * _c_2d_nyquist_freq()[0] + 1, 2 * _c_2d_nyquist_freq()[1] + 1)],
)
def test_fourier_chebyshev(self, func, m, n):
@staticmethod
def test_fourier_chebyshev(func, m, n):
"""Tests for coverage of FourierChebyshev series."""
x = fourier_pts(m)
y = cheb_pts(n)
Expand Down