Skip to content

Commit 34efef2

Browse files
authored
Fix performance regression in interp from #9881 (#10370)
1 parent 7f6add6 commit 34efef2

File tree

3 files changed

+37
-6
lines changed

3 files changed

+37
-6
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ Performance
6969
in :py:class:`~xarray.indexing.VectorizedIndexer` and :py:class:`~xarray.indexing.OuterIndexer`
7070
(:issue:`10316`).
7171
By `Jesse Rusak <https://github.com/jder>`_.
72+
- Fix performance regression in interp where more data was loaded than was necessary (:issue:`10287`).
73+
By `Deepak Cherian <https://github.com/dcherian>`_.
7274
- Speed up encoding of :py:class:`cftime.datetime` objects by roughly a factor
7375
of three (:pull:`8324`). By `Antoine Gibek <https://github.com/antscloud>`_.
7476

xarray/core/dataset.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3804,15 +3804,16 @@ def _validate_interp_indexer(x, new_x):
38043804
for k, v in indexers.items()
38053805
}
38063806

3807+
# optimization: subset to coordinate range of the target index
3808+
if method in ["linear", "nearest"]:
3809+
for k, v in validated_indexers.items():
3810+
obj, newidx = missing._localize(obj, {k: v})
3811+
validated_indexers[k] = newidx[k]
3812+
38073813
has_chunked_array = bool(
38083814
any(is_chunked_array(v._data) for v in obj._variables.values())
38093815
)
38103816
if has_chunked_array:
3811-
# optimization: subset to coordinate range of the target index
3812-
if method in ["linear", "nearest"]:
3813-
for k, v in validated_indexers.items():
3814-
obj, newidx = missing._localize(obj, {k: v})
3815-
validated_indexers[k] = newidx[k]
38163817
# optimization: create dask coordinate arrays once per Dataset
38173818
# rather than once per Variable when dask.array.unify_chunks is called later
38183819
# GH4739
@@ -3828,7 +3829,7 @@ def _validate_interp_indexer(x, new_x):
38283829
continue
38293830

38303831
use_indexers = (
3831-
dask_indexers if is_duck_dask_array(var.data) else validated_indexers
3832+
dask_indexers if is_duck_dask_array(var._data) else validated_indexers
38323833
)
38333834

38343835
dtype_kind = var.dtype.kind

xarray/tests/test_missing.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from __future__ import annotations
22

33
import itertools
4+
from unittest import mock
45

56
import numpy as np
67
import pandas as pd
78
import pytest
89

910
import xarray as xr
11+
from xarray.core import indexing
1012
from xarray.core.missing import (
1113
NumpyInterpolator,
1214
ScipyInterpolator,
@@ -772,3 +774,29 @@ def test_interpolators_complex_out_of_bounds():
772774
f = interpolator(xi, yi, method=method)
773775
actual = f(x)
774776
assert_array_equal(actual, expected)
777+
778+
779+
@requires_scipy
780+
def test_indexing_localize():
781+
# regression test for GH10287
782+
ds = xr.Dataset(
783+
{
784+
"sigma_a": xr.DataArray(
785+
data=np.ones((16, 8, 36811)),
786+
dims=["p", "t", "w"],
787+
coords={"w": np.linspace(0, 30000, 36811)},
788+
)
789+
}
790+
)
791+
792+
original_func = indexing.NumpyIndexingAdapter.__getitem__
793+
794+
def wrapper(self, indexer):
795+
return original_func(self, indexer)
796+
797+
with mock.patch.object(
798+
indexing.NumpyIndexingAdapter, "__getitem__", side_effect=wrapper, autospec=True
799+
) as mock_func:
800+
ds["sigma_a"].interp(w=15000.5)
801+
actual_indexer = mock_func.mock_calls[0].args[1]._key
802+
assert actual_indexer == (slice(None), slice(None), slice(18404, 18408))

0 commit comments

Comments
 (0)