Skip to content
Merged
8 changes: 7 additions & 1 deletion xarray/indexes/nd_point_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,13 @@ class ScipyKDTreeAdapter(TreeAdapter):
_kdtree: KDTree

def __init__(self, points: np.ndarray, options: Mapping[str, Any]):
from scipy.spatial import KDTree
try:
from scipy.spatial import KDTree
except ImportError as err:
raise ImportError(
"`NDPointIndex` requires `scipy` when used with `ScipyKDTreeAdapter`. "
"Please ensure that `scipy` is installed and importable."
) from err

self._kdtree = KDTree(points, **options)

Expand Down
16 changes: 14 additions & 2 deletions xarray/tests/test_nd_point_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,19 @@

import xarray as xr
from xarray.indexes import NDPointIndex
from xarray.tests import assert_identical
from xarray.indexes.nd_point_index import ScipyKDTreeAdapter
from xarray.tests import assert_identical, has_scipy, requires_scipy

pytest.importorskip("scipy")

@pytest.mark.skipif(has_scipy, reason="requires scipy to be missing")
def test_scipy_kdtree_adapter_missing_scipy():
points = np.random.rand(4, 2)

with pytest.raises(ImportError, match=r"scipy"):
ScipyKDTreeAdapter(points, options={})


@requires_scipy
def test_tree_index_init() -> None:
from xarray.indexes.nd_point_index import ScipyKDTreeAdapter

Expand All @@ -26,6 +34,7 @@ def test_tree_index_init() -> None:
assert ds_indexed1.xindexes["xx"].equals(ds_indexed2.xindexes["yy"])


@requires_scipy
def test_tree_index_init_errors() -> None:
xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0])
ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)})
Expand All @@ -39,6 +48,7 @@ def test_tree_index_init_errors() -> None:
ds2.set_xindex(("xx", "yy"), NDPointIndex)


@requires_scipy
def test_tree_index_sel() -> None:
xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0])
ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex(
Expand Down Expand Up @@ -112,6 +122,7 @@ def test_tree_index_sel() -> None:
assert_identical(actual, expected)


@requires_scipy
def test_tree_index_sel_errors() -> None:
xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0])
ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex(
Expand All @@ -137,6 +148,7 @@ def test_tree_index_sel_errors() -> None:
)


@requires_scipy
def test_tree_index_equals() -> None:
xx1, yy1 = np.meshgrid([1.0, 2.0], [3.0, 4.0])
ds1 = xr.Dataset(
Expand Down
Loading