Skip to content

Commit 7432032

Browse files
dcherianIllviljanmax-sixtyshoyer
authored
Faster interp (#4740)
Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Co-authored-by: Stephan Hoyer <shoyer@gmail.com>
1 parent 8101b8d commit 7432032

File tree

3 files changed

+44
-14
lines changed

3 files changed

+44
-14
lines changed

doc/whats-new.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,15 @@ New Features
102102
expand, ``False`` to always collapse, or ``default`` to expand unless over a
103103
pre-defined limit (:pull:`5126`).
104104
By `Tom White <https://github.com/tomwhite>`_.
105+
- Significant speedups in :py:meth:`Dataset.interp` and :py:meth:`DataArray.interp`.
106+
(:issue:`4739`, :pull:`4740`). By `Deepak Cherian <https://github.com/dcherian>`_.
107+
- Prevent passing `concat_dim` to :py:func:`xarray.open_mfdataset` when
108+
`combine='by_coords'` is specified, which should never have been possible (as
109+
:py:func:`xarray.combine_by_coords` has no `concat_dim` argument to pass to).
110+
Also removes unneeded internal reordering of datasets in
111+
:py:func:`xarray.open_mfdataset` when `combine='by_coords'` is specified.
112+
Fixes (:issue:`5230`).
113+
By `Tom Nicholas <https://github.com/TomNicholas>`_.
105114
- Implement ``__setitem__`` for :py:class:`core.indexing.DaskIndexingAdapter` if
106115
dask version supports item assignment. (:issue:`5171`, :pull:`5174`)
107116
By `Tammas Loughran <https://github.com/tammasloughran>`_.

xarray/core/dataset.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2992,17 +2992,38 @@ def _validate_interp_indexer(x, new_x):
29922992
)
29932993
return x, new_x
29942994

2995+
validated_indexers = {
2996+
k: _validate_interp_indexer(maybe_variable(obj, k), v)
2997+
for k, v in indexers.items()
2998+
}
2999+
3000+
# optimization: subset to coordinate range of the target index
3001+
if method in ["linear", "nearest"]:
3002+
for k, v in validated_indexers.items():
3003+
obj, newidx = missing._localize(obj, {k: v})
3004+
validated_indexers[k] = newidx[k]
3005+
3006+
# optimization: create dask coordinate arrays once per Dataset
3007+
# rather than once per Variable when dask.array.unify_chunks is called later
3008+
# GH4739
3009+
if obj.__dask_graph__():
3010+
dask_indexers = {
3011+
k: (index.to_base_variable().chunk(), dest.to_base_variable().chunk())
3012+
for k, (index, dest) in validated_indexers.items()
3013+
}
3014+
29953015
variables: Dict[Hashable, Variable] = {}
29963016
for name, var in obj._variables.items():
29973017
if name in indexers:
29983018
continue
29993019

3020+
if is_duck_dask_array(var.data):
3021+
use_indexers = dask_indexers
3022+
else:
3023+
use_indexers = validated_indexers
3024+
30003025
if var.dtype.kind in "uifc":
3001-
var_indexers = {
3002-
k: _validate_interp_indexer(maybe_variable(obj, k), v)
3003-
for k, v in indexers.items()
3004-
if k in var.dims
3005-
}
3026+
var_indexers = {k: v for k, v in use_indexers.items() if k in var.dims}
30063027
variables[name] = missing.interp(var, var_indexers, method, **kwargs)
30073028
elif all(d not in indexers for d in var.dims):
30083029
# keep unrelated object array

xarray/core/missing.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -617,10 +617,6 @@ def interp(var, indexes_coords, method, **kwargs):
617617
for indexes_coords in decompose_interp(indexes_coords):
618618
var = result
619619

620-
# simple speed up for the local interpolation
621-
if method in ["linear", "nearest"]:
622-
var, indexes_coords = _localize(var, indexes_coords)
623-
624620
# target dimensions
625621
dims = list(indexes_coords)
626622
x, new_x = zip(*[indexes_coords[d] for d in dims])
@@ -691,21 +687,22 @@ def interp_func(var, x, new_x, method, kwargs):
691687
if is_duck_dask_array(var):
692688
import dask.array as da
693689

694-
nconst = var.ndim - len(x)
690+
ndim = var.ndim
691+
nconst = ndim - len(x)
695692

696-
out_ind = list(range(nconst)) + list(range(var.ndim, var.ndim + new_x[0].ndim))
693+
out_ind = list(range(nconst)) + list(range(ndim, ndim + new_x[0].ndim))
697694

698695
# blockwise args format
699696
x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)]
700697
x_arginds = [item for pair in x_arginds for item in pair]
701698
new_x_arginds = [
702-
[_x, [var.ndim + index for index in range(_x.ndim)]] for _x in new_x
699+
[_x, [ndim + index for index in range(_x.ndim)]] for _x in new_x
703700
]
704701
new_x_arginds = [item for pair in new_x_arginds for item in pair]
705702

706703
args = (
707704
var,
708-
range(var.ndim),
705+
range(ndim),
709706
*x_arginds,
710707
*new_x_arginds,
711708
)
@@ -717,7 +714,7 @@ def interp_func(var, x, new_x, method, kwargs):
717714
new_x = rechunked[1 + (len(rechunked) - 1) // 2 :]
718715

719716
new_axes = {
720-
var.ndim + i: new_x[0].chunks[i]
717+
ndim + i: new_x[0].chunks[i]
721718
if new_x[0].chunks is not None
722719
else new_x[0].shape[i]
723720
for i in range(new_x[0].ndim)
@@ -743,6 +740,9 @@ def interp_func(var, x, new_x, method, kwargs):
743740
concatenate=True,
744741
dtype=dtype,
745742
new_axes=new_axes,
743+
# TODO: uncomment when min dask version is > 2.15
744+
# meta=var._meta,
745+
align_arrays=False,
746746
)
747747

748748
return _interpnd(var, x, new_x, func, kwargs)

0 commit comments

Comments
 (0)