Skip to content

Commit 245697e

Browse files
committed
Clear up broadcasting
1 parent be5c783 commit 245697e

File tree

1 file changed

+36
-20
lines changed

1 file changed

+36
-20
lines changed

xarray/core/missing.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import datetime as dt
44
import itertools
55
import warnings
6+
from collections import ChainMap
67
from collections.abc import Callable, Generator, Hashable, Sequence
78
from functools import partial
89
from numbers import Number
@@ -710,59 +711,68 @@ def interpolate_variable(
710711
func, kwargs = _get_interpolator_nd(method, **kwargs)
711712

712713
in_coords, result_coords = zip(*(v for v in indexes_coords.values()), strict=True)
713-
# broadcast out manually to minize confusing behaviour
714-
broadcast_result_coords = broadcast_variables(*result_coords)
715-
result_dims = broadcast_result_coords[0].dims
716714

717715
# input coordinates along which we are interpolating are core dimensions
718716
# the corresponding output coordinates may or may not have the same name,
719717
# so `all_in_core_dims` is also `exclude_dims`
720718
all_in_core_dims = set(indexes_coords)
721719

720+
result_dims = OrderedSet(itertools.chain(*(_.dims for _ in result_coords)))
721+
result_sizes = ChainMap(*(_.sizes for _ in result_coords))
722+
722723
# any dimensions on the output that are present on the input, but are not being
723724
# interpolated along are broadcast or loop dimensions along which we automatically
724725
# vectorize. Consider the problem in
725726
# https://github.com/pydata/xarray/issues/6799#issuecomment-2474126217
726727
# In the following, dimension names are listed out in [].
727728
# # da[time, q, lat, lon].interp(q=bar[lat,lon]). Here `lat`, `lon`
728-
# are input dimensions, present on the output, along which we vectorize.
729+
# are input dimensions, present on the output, but are not the coordinates
730+
# we are explicitly interpolating. These are the dimensions along which we vectorize.
729731
# We track these as "result broadcast dimensions".
730732
# `q` is the only input core dimension, and changes size (disappears)
731733
# so it is in exclude_dims.
732-
result_broadcast_dims = set(
733-
itertools.chain(dim for dim in result_dims if dim not in all_in_core_dims)
734-
)
734+
vectorize_dims = (result_dims - all_in_core_dims) & set(var.dims)
735735

736736
# remove any output broadcast dimensions from the list of core dimensions
737-
output_core_dims = tuple(d for d in result_dims if d not in result_broadcast_dims)
737+
output_core_dims = tuple(d for d in result_dims if d not in vectorize_dims)
738738
input_core_dims = (
739739
# all coordinates on the input that we interpolate along
740740
[tuple(indexes_coords)]
741741
# the input coordinates are always 1D at the moment, so we just need to list out their names
742742
+ [tuple(_.dims) for _ in in_coords]
743743
# The last set of inputs are the coordinates we are interpolating to.
744-
# These have been broadcast already for ease.
745-
+ [output_core_dims] * len(result_coords)
744+
+ [
745+
tuple(d for d in coord.dims if d not in vectorize_dims)
746+
for coord in result_coords
747+
]
746748
)
747-
output_sizes = {k: broadcast_result_coords[0].sizes[k] for k in output_core_dims}
749+
output_sizes = {k: result_sizes[k] for k in output_core_dims}
748750

749751
# scipy.interpolate.interp1d always forces to float.
750752
dtype = float if not issubclass(var.dtype.type, np.inexact) else var.dtype
751753
result = apply_ufunc(
752754
_interpnd,
753755
var,
754756
*in_coords,
755-
*broadcast_result_coords,
757+
*result_coords,
756758
input_core_dims=input_core_dims,
757759
output_core_dims=[output_core_dims],
758760
exclude_dims=all_in_core_dims,
759761
dask="parallelized",
760-
kwargs=dict(interp_func=func, interp_kwargs=kwargs),
762+
kwargs=dict(
763+
interp_func=func,
764+
interp_kwargs=kwargs,
765+
# we leave broadcasting up to dask if possible
766+
# but we need broadcasted values in _interpnd, so propagate that
767+
# context (dimension names), and broadcast there
768+
# This would be unnecessary if we could tell apply_ufunc
769+
# to insert size-1 broadcast dimensions
770+
result_coord_core_dims=input_core_dims[-len(result_coords) :],
771+
),
761772
# TODO: deprecate and have the user rechunk themselves
762773
dask_gufunc_kwargs=dict(output_sizes=output_sizes, allow_rechunk=True),
763774
output_dtypes=[dtype],
764-
# if there are any broadcast dims on the result, we must vectorize on them
765-
vectorize=bool(result_broadcast_dims),
775+
vectorize=bool(vectorize_dims),
766776
keep_attrs=True,
767777
)
768778
return result
@@ -787,7 +797,11 @@ def _interp1d(
787797

788798

789799
def _interpnd(
790-
data: np.ndarray, *coords: np.ndarray, interp_func: InterpCallable, interp_kwargs
800+
data: np.ndarray,
801+
*coords: np.ndarray,
802+
interp_func: InterpCallable,
803+
interp_kwargs,
804+
result_coord_core_dims,
791805
) -> np.ndarray:
792806
"""
793807
Core nD array interpolation routine.
@@ -801,10 +815,12 @@ def _interpnd(
801815
# Convert everything to Variables, since that makes applying
802816
# `_localize` and `_floatize_x` much easier
803817
x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])]
804-
new_x = [
805-
Variable([f"dim_{ndim + dim}" for dim in range(_x.ndim)], _x)
806-
for _x in coords[n_x:]
807-
]
818+
new_x = broadcast_variables(
819+
*(
820+
Variable(dims, _x)
821+
for dims, _x in zip(result_coord_core_dims, coords[n_x:], strict=True)
822+
)
823+
)
808824
var = Variable([f"dim_{dim}" for dim in range(ndim)], data)
809825

810826
if interp_kwargs.get("method") in ["linear", "nearest"]:

0 commit comments

Comments
 (0)