-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Rewrite interp to use apply_ufunc
#9881
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c1603c5
6109f6e
61c5b2c
81b73b9
652bcc1
03f0b36
9b915b2
be5c783
a5e1854
eef94fa
79a7e56
6e22072
8d4503a
586f638
972e9fb
38e66fc
43a0691
97a388e
ef24840
437219f
c152ca3
6c1dd95
1b9845d
81b8a90
6f6dd0a
e6ec62b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2921,19 +2921,11 @@ def _validate_interp_indexers( | |
"""Variant of _validate_indexers to be used for interpolation""" | ||
for k, v in self._validate_indexers(indexers): | ||
if isinstance(v, Variable): | ||
if v.ndim == 1: | ||
yield k, v.to_index_variable() | ||
else: | ||
yield k, v | ||
elif isinstance(v, int): | ||
yield k, v | ||
elif is_scalar(v): | ||
yield k, Variable((), v, attrs=self.coords[k].attrs) | ||
elif isinstance(v, np.ndarray): | ||
if v.ndim == 0: | ||
yield k, Variable((), v, attrs=self.coords[k].attrs) | ||
elif v.ndim == 1: | ||
yield k, IndexVariable((k,), v, attrs=self.coords[k].attrs) | ||
else: | ||
raise AssertionError() # Already tested by _validate_indexers | ||
yield k, Variable(dims=(k,), data=v, attrs=self.coords[k].attrs) | ||
else: | ||
raise TypeError(type(v)) | ||
|
||
|
@@ -4127,18 +4119,6 @@ def interp( | |
|
||
coords = either_dict_or_kwargs(coords, coords_kwargs, "interp") | ||
indexers = dict(self._validate_interp_indexers(coords)) | ||
|
||
if coords: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handled by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For posterity the bad thing about this approach is that it can greatly expand the number of core dimensions for the problem, limiting the potential for parallelism. Consider the problem in #6799 (comment). In the following, dimension names are listed out in
|
||
# This avoids broadcasting over coordinates that are both in | ||
# the original array AND in the indexing array. It essentially | ||
# forces interpolation along the shared coordinates. | ||
sdims = ( | ||
set(self.dims) | ||
.intersection(*[set(nx.dims) for nx in indexers.values()]) | ||
.difference(coords.keys()) | ||
) | ||
indexers.update({d: self.variables[d] for d in sdims}) | ||
|
||
obj = self if assume_sorted else self.sortby(list(coords)) | ||
|
||
def maybe_variable(obj, k): | ||
|
@@ -4169,16 +4149,18 @@ def _validate_interp_indexer(x, new_x): | |
for k, v in indexers.items() | ||
} | ||
|
||
# optimization: subset to coordinate range of the target index | ||
if method in ["linear", "nearest"]: | ||
for k, v in validated_indexers.items(): | ||
obj, newidx = missing._localize(obj, {k: v}) | ||
validated_indexers[k] = newidx[k] | ||
|
||
# optimization: create dask coordinate arrays once per Dataset | ||
# rather than once per Variable when dask.array.unify_chunks is called later | ||
# GH4739 | ||
if obj.__dask_graph__(): | ||
has_chunked_array = bool( | ||
any(is_chunked_array(v._data) for v in obj._variables.values()) | ||
) | ||
if has_chunked_array: | ||
# optimization: subset to coordinate range of the target index | ||
if method in ["linear", "nearest"]: | ||
for k, v in validated_indexers.items(): | ||
obj, newidx = missing._localize(obj, {k: v}) | ||
validated_indexers[k] = newidx[k] | ||
# optimization: create dask coordinate arrays once per Dataset | ||
# rather than once per Variable when dask.array.unify_chunks is called later | ||
# GH4739 | ||
dask_indexers = { | ||
k: (index.to_base_variable().chunk(), dest.to_base_variable().chunk()) | ||
for k, (index, dest) in validated_indexers.items() | ||
|
@@ -4190,10 +4172,9 @@ def _validate_interp_indexer(x, new_x): | |
if name in indexers: | ||
continue | ||
|
||
if is_duck_dask_array(var.data): | ||
use_indexers = dask_indexers | ||
else: | ||
use_indexers = validated_indexers | ||
use_indexers = ( | ||
dask_indexers if is_duck_dask_array(var.data) else validated_indexers | ||
) | ||
|
||
dtype_kind = var.dtype.kind | ||
if dtype_kind in "uifc": | ||
|
Uh oh!
There was an error while loading. Please reload this page.