3
3
import datetime as dt
4
4
import itertools
5
5
import warnings
6
+ from collections import ChainMap
6
7
from collections .abc import Callable , Generator , Hashable , Sequence
7
8
from functools import partial
8
9
from numbers import Number
@@ -710,59 +711,68 @@ def interpolate_variable(
710
711
func , kwargs = _get_interpolator_nd (method , ** kwargs )
711
712
712
713
in_coords , result_coords = zip (* (v for v in indexes_coords .values ()), strict = True )
713
- # broadcast out manually to minize confusing behaviour
714
- broadcast_result_coords = broadcast_variables (* result_coords )
715
- result_dims = broadcast_result_coords [0 ].dims
716
714
717
715
# input coordinates along which we are interpolation are core dimensions
718
716
# the corresponding output coordinates may or may not have the same name,
719
717
# so `all_in_core_dims` is also `exclude_dims`
720
718
all_in_core_dims = set (indexes_coords )
721
719
720
+ result_dims = OrderedSet (itertools .chain (* (_ .dims for _ in result_coords )))
721
+ result_sizes = ChainMap (* (_ .sizes for _ in result_coords ))
722
+
722
723
# any dimensions on the output that are present on the input, but are not being
723
724
# interpolated along are broadcast or loop dimensions along which we automatically
724
725
# vectorize. Consider the problem in
725
726
# https://github.com/pydata/xarray/issues/6799#issuecomment-2474126217
726
727
# In the following, dimension names are listed out in [].
727
728
# # da[time, q, lat, lon].interp(q=bar[lat,lon]). Here `lat`, `lon`
728
- # are input dimensions, present on the output, along which we vectorize.
729
+ # are input dimensions, present on the output, but are not the coordinates
730
+ # we are explicitly interpolating. These are the dimensions along which we vectorize.
729
731
# We track these as "result broadcast dimensions".
730
732
# `q` is the only input core dimensions, and changes size (disappears)
731
733
# so it is in exclude_dims.
732
- result_broadcast_dims = set (
733
- itertools .chain (dim for dim in result_dims if dim not in all_in_core_dims )
734
- )
734
+ vectorize_dims = (result_dims - all_in_core_dims ) & set (var .dims )
735
735
736
736
# remove any output broadcast dimensions from the list of core dimensions
737
- output_core_dims = tuple (d for d in result_dims if d not in result_broadcast_dims )
737
+ output_core_dims = tuple (d for d in result_dims if d not in vectorize_dims )
738
738
input_core_dims = (
739
739
# all coordinates on the input that we interpolate along
740
740
[tuple (indexes_coords )]
741
741
# the input coordinates are always 1D at the moment, so we just need to list out their names
742
742
+ [tuple (_ .dims ) for _ in in_coords ]
743
743
# The last set of inputs are the coordinates we are interpolating to.
744
- # These have been broadcast already for ease.
745
- + [output_core_dims ] * len (result_coords )
744
+ + [
745
+ tuple (d for d in coord .dims if d not in vectorize_dims )
746
+ for coord in result_coords
747
+ ]
746
748
)
747
- output_sizes = {k : broadcast_result_coords [ 0 ]. sizes [k ] for k in output_core_dims }
749
+ output_sizes = {k : result_sizes [k ] for k in output_core_dims }
748
750
749
751
# scipy.interpolate.interp1d always forces to float.
750
752
dtype = float if not issubclass (var .dtype .type , np .inexact ) else var .dtype
751
753
result = apply_ufunc (
752
754
_interpnd ,
753
755
var ,
754
756
* in_coords ,
755
- * broadcast_result_coords ,
757
+ * result_coords ,
756
758
input_core_dims = input_core_dims ,
757
759
output_core_dims = [output_core_dims ],
758
760
exclude_dims = all_in_core_dims ,
759
761
dask = "parallelized" ,
760
- kwargs = dict (interp_func = func , interp_kwargs = kwargs ),
762
+ kwargs = dict (
763
+ interp_func = func ,
764
+ interp_kwargs = kwargs ,
765
+ # we leave broadcasting up to dask if possible
766
+ # but we need broadcasted values in _interpnd, so propagate that
767
+ # context (dimension names), and broadcast there
768
+ # This would be unnecessary if we could tell apply_ufunc
769
+ # to insert size-1 broadcast dimensions
770
+ result_coord_core_dims = input_core_dims [- len (result_coords ) :],
771
+ ),
761
772
# TODO: deprecate and have the user rechunk themselves
762
773
dask_gufunc_kwargs = dict (output_sizes = output_sizes , allow_rechunk = True ),
763
774
output_dtypes = [dtype ],
764
- # if there are any broadcast dims on the result, we must vectorize on them
765
- vectorize = bool (result_broadcast_dims ),
775
+ vectorize = bool (vectorize_dims ),
766
776
keep_attrs = True ,
767
777
)
768
778
return result
@@ -787,7 +797,11 @@ def _interp1d(
787
797
788
798
789
799
def _interpnd (
790
- data : np .ndarray , * coords : np .ndarray , interp_func : InterpCallable , interp_kwargs
800
+ data : np .ndarray ,
801
+ * coords : np .ndarray ,
802
+ interp_func : InterpCallable ,
803
+ interp_kwargs ,
804
+ result_coord_core_dims ,
791
805
) -> np .ndarray :
792
806
"""
793
807
Core nD array interpolation routine.
@@ -801,10 +815,12 @@ def _interpnd(
801
815
# Convert everything to Variables, since that makes applying
802
816
# `_localize` and `_floatize_x` much easier
803
817
x = [Variable ([f"dim_{ nconst + dim } " ], _x ) for dim , _x in enumerate (coords [:n_x ])]
804
- new_x = [
805
- Variable ([f"dim_{ ndim + dim } " for dim in range (_x .ndim )], _x )
806
- for _x in coords [n_x :]
807
- ]
818
+ new_x = broadcast_variables (
819
+ * (
820
+ Variable (dims , _x )
821
+ for dims , _x in zip (result_coord_core_dims , coords [n_x :], strict = True )
822
+ )
823
+ )
808
824
var = Variable ([f"dim_{ dim } " for dim in range (ndim )], data )
809
825
810
826
if interp_kwargs .get ("method" ) in ["linear" , "nearest" ]:
0 commit comments