Skip to content

Commit a5e1854

Browse files
committed
Clear up broadcasting
1 parent be5c783 commit a5e1854

File tree

1 file changed

+38
-24
lines changed

1 file changed

+38
-24
lines changed

xarray/core/missing.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import datetime as dt
44
import itertools
55
import warnings
6+
from collections import ChainMap
67
from collections.abc import Callable, Generator, Hashable, Sequence
78
from functools import partial
89
from numbers import Number
@@ -710,59 +711,66 @@ def interpolate_variable(
710711
func, kwargs = _get_interpolator_nd(method, **kwargs)
711712

712713
in_coords, result_coords = zip(*(v for v in indexes_coords.values()), strict=True)
713-
# broadcast out manually to minize confusing behaviour
714-
broadcast_result_coords = broadcast_variables(*result_coords)
715-
result_dims = broadcast_result_coords[0].dims
716714

717715
# input coordinates along which we are interpolation are core dimensions
718716
# the corresponding output coordinates may or may not have the same name,
719717
# so `all_in_core_dims` is also `exclude_dims`
720718
all_in_core_dims = set(indexes_coords)
721719

720+
result_dims = OrderedSet(itertools.chain(*(_.dims for _ in result_coords)))
721+
result_sizes = ChainMap(*(_.sizes for _ in result_coords))
722+
722723
# any dimensions on the output that are present on the input, but are not being
723-
# interpolated along are broadcast or loop dimensions along which we automatically
724-
# vectorize. Consider the problem in
725-
# https://github.com/pydata/xarray/issues/6799#issuecomment-2474126217
724+
# interpolated along are dimensions along which we automatically vectorize.
725+
# Consider the problem in https://github.com/pydata/xarray/issues/6799#issuecomment-2474126217
726726
# In the following, dimension names are listed out in [].
727727
# # da[time, q, lat, lon].interp(q=bar[lat,lon]). Here `lat`, `lon`
728-
# are input dimensions, present on the output, along which we vectorize.
729-
# We track these as "result broadcast dimensions".
728+
# are input dimensions, present on the output, but are not the coordinates
729+
# we are explicitly interpolating. These are the dimensions along which we vectorize.
730730
# `q` is the only input core dimensions, and changes size (disappears)
731731
# so it is in exclude_dims.
732-
result_broadcast_dims = set(
733-
itertools.chain(dim for dim in result_dims if dim not in all_in_core_dims)
734-
)
732+
vectorize_dims = (result_dims - all_in_core_dims) & set(var.dims)
735733

736734
# remove any output broadcast dimensions from the list of core dimensions
737-
output_core_dims = tuple(d for d in result_dims if d not in result_broadcast_dims)
735+
output_core_dims = tuple(d for d in result_dims if d not in vectorize_dims)
738736
input_core_dims = (
739737
# all coordinates on the input that we interpolate along
740738
[tuple(indexes_coords)]
741739
# the input coordinates are always 1D at the moment, so we just need to list out their names
742740
+ [tuple(_.dims) for _ in in_coords]
743741
# The last set of inputs are the coordinates we are interpolating to.
744-
# These have been broadcast already for ease.
745-
+ [output_core_dims] * len(result_coords)
742+
+ [
743+
tuple(d for d in coord.dims if d not in vectorize_dims)
744+
for coord in result_coords
745+
]
746746
)
747-
output_sizes = {k: broadcast_result_coords[0].sizes[k] for k in output_core_dims}
747+
output_sizes = {k: result_sizes[k] for k in output_core_dims}
748748

749749
# scipy.interpolate.interp1d always forces to float.
750750
dtype = float if not issubclass(var.dtype.type, np.inexact) else var.dtype
751751
result = apply_ufunc(
752752
_interpnd,
753753
var,
754754
*in_coords,
755-
*broadcast_result_coords,
755+
*result_coords,
756756
input_core_dims=input_core_dims,
757757
output_core_dims=[output_core_dims],
758758
exclude_dims=all_in_core_dims,
759759
dask="parallelized",
760-
kwargs=dict(interp_func=func, interp_kwargs=kwargs),
760+
kwargs=dict(
761+
interp_func=func,
762+
interp_kwargs=kwargs,
763+
# we leave broadcasting up to dask if possible
764+
# but we need broadcasted values in _interpnd, so propagate that
765+
# context (dimension names), and broadcast there
766+
# This would be unnecessary if we could tell apply_ufunc
767+
# to insert size-1 broadcast dimensions
768+
result_coord_core_dims=input_core_dims[-len(result_coords) :],
769+
),
761770
# TODO: deprecate and have the user rechunk themselves
762771
dask_gufunc_kwargs=dict(output_sizes=output_sizes, allow_rechunk=True),
763772
output_dtypes=[dtype],
764-
# if there are any broadcast dims on the result, we must vectorize on them
765-
vectorize=bool(result_broadcast_dims),
773+
vectorize=bool(vectorize_dims),
766774
keep_attrs=True,
767775
)
768776
return result
@@ -787,7 +795,11 @@ def _interp1d(
787795

788796

789797
def _interpnd(
790-
data: np.ndarray, *coords: np.ndarray, interp_func: InterpCallable, interp_kwargs
798+
data: np.ndarray,
799+
*coords: np.ndarray,
800+
interp_func: InterpCallable,
801+
interp_kwargs,
802+
result_coord_core_dims,
791803
) -> np.ndarray:
792804
"""
793805
Core nD array interpolation routine.
@@ -801,10 +813,12 @@ def _interpnd(
801813
# Convert everything to Variables, since that makes applying
802814
# `_localize` and `_floatize_x` much easier
803815
x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])]
804-
new_x = [
805-
Variable([f"dim_{ndim + dim}" for dim in range(_x.ndim)], _x)
806-
for _x in coords[n_x:]
807-
]
816+
new_x = broadcast_variables(
817+
*(
818+
Variable(dims, _x)
819+
for dims, _x in zip(result_coord_core_dims, coords[n_x:], strict=True)
820+
)
821+
)
808822
var = Variable([f"dim_{dim}" for dim in range(ndim)], data)
809823

810824
if interp_kwargs.get("method") in ["linear", "nearest"]:

0 commit comments

Comments
 (0)