3
3
import datetime as dt
4
4
import itertools
5
5
import warnings
6
+ from collections import ChainMap
6
7
from collections .abc import Callable , Generator , Hashable , Sequence
7
8
from functools import partial
8
9
from numbers import Number
@@ -710,59 +711,66 @@ def interpolate_variable(
710
711
func , kwargs = _get_interpolator_nd (method , ** kwargs )
711
712
712
713
in_coords , result_coords = zip (* (v for v in indexes_coords .values ()), strict = True )
713
- # broadcast out manually to minize confusing behaviour
714
- broadcast_result_coords = broadcast_variables (* result_coords )
715
- result_dims = broadcast_result_coords [0 ].dims
716
714
717
715
# input coordinates along which we are interpolation are core dimensions
718
716
# the corresponding output coordinates may or may not have the same name,
719
717
# so `all_in_core_dims` is also `exclude_dims`
720
718
all_in_core_dims = set (indexes_coords )
721
719
720
+ result_dims = OrderedSet (itertools .chain (* (_ .dims for _ in result_coords )))
721
+ result_sizes = ChainMap (* (_ .sizes for _ in result_coords ))
722
+
722
723
# any dimensions on the output that are present on the input, but are not being
723
- # interpolated along are broadcast or loop dimensions along which we automatically
724
- # vectorize. Consider the problem in
725
- # https://github.com/pydata/xarray/issues/6799#issuecomment-2474126217
724
+ # interpolated along are dimensions along which we automatically vectorize.
725
+ # Consider the problem in https://github.com/pydata/xarray/issues/6799#issuecomment-2474126217
726
726
# In the following, dimension names are listed out in [].
727
727
# # da[time, q, lat, lon].interp(q=bar[lat,lon]). Here `lat`, `lon`
728
- # are input dimensions, present on the output, along which we vectorize.
729
- # We track these as "result broadcast dimensions" .
728
+ # are input dimensions, present on the output, but are not the coordinates
729
+ # we are explicitly interpolating. These are the dimensions along which we vectorize .
730
730
# `q` is the only input core dimensions, and changes size (disappears)
731
731
# so it is in exclude_dims.
732
- result_broadcast_dims = set (
733
- itertools .chain (dim for dim in result_dims if dim not in all_in_core_dims )
734
- )
732
+ vectorize_dims = (result_dims - all_in_core_dims ) & set (var .dims )
735
733
736
734
# remove any output broadcast dimensions from the list of core dimensions
737
- output_core_dims = tuple (d for d in result_dims if d not in result_broadcast_dims )
735
+ output_core_dims = tuple (d for d in result_dims if d not in vectorize_dims )
738
736
input_core_dims = (
739
737
# all coordinates on the input that we interpolate along
740
738
[tuple (indexes_coords )]
741
739
# the input coordinates are always 1D at the moment, so we just need to list out their names
742
740
+ [tuple (_ .dims ) for _ in in_coords ]
743
741
# The last set of inputs are the coordinates we are interpolating to.
744
- # These have been broadcast already for ease.
745
- + [output_core_dims ] * len (result_coords )
742
+ + [
743
+ tuple (d for d in coord .dims if d not in vectorize_dims )
744
+ for coord in result_coords
745
+ ]
746
746
)
747
- output_sizes = {k : broadcast_result_coords [ 0 ]. sizes [k ] for k in output_core_dims }
747
+ output_sizes = {k : result_sizes [k ] for k in output_core_dims }
748
748
749
749
# scipy.interpolate.interp1d always forces to float.
750
750
dtype = float if not issubclass (var .dtype .type , np .inexact ) else var .dtype
751
751
result = apply_ufunc (
752
752
_interpnd ,
753
753
var ,
754
754
* in_coords ,
755
- * broadcast_result_coords ,
755
+ * result_coords ,
756
756
input_core_dims = input_core_dims ,
757
757
output_core_dims = [output_core_dims ],
758
758
exclude_dims = all_in_core_dims ,
759
759
dask = "parallelized" ,
760
- kwargs = dict (interp_func = func , interp_kwargs = kwargs ),
760
+ kwargs = dict (
761
+ interp_func = func ,
762
+ interp_kwargs = kwargs ,
763
+ # we leave broadcasting up to dask if possible
764
+ # but we need broadcasted values in _interpnd, so propagate that
765
+ # context (dimension names), and broadcast there
766
+ # This would be unnecessary if we could tell apply_ufunc
767
+ # to insert size-1 broadcast dimensions
768
+ result_coord_core_dims = input_core_dims [- len (result_coords ) :],
769
+ ),
761
770
# TODO: deprecate and have the user rechunk themselves
762
771
dask_gufunc_kwargs = dict (output_sizes = output_sizes , allow_rechunk = True ),
763
772
output_dtypes = [dtype ],
764
- # if there are any broadcast dims on the result, we must vectorize on them
765
- vectorize = bool (result_broadcast_dims ),
773
+ vectorize = bool (vectorize_dims ),
766
774
keep_attrs = True ,
767
775
)
768
776
return result
@@ -787,7 +795,11 @@ def _interp1d(
787
795
788
796
789
797
def _interpnd (
790
- data : np .ndarray , * coords : np .ndarray , interp_func : InterpCallable , interp_kwargs
798
+ data : np .ndarray ,
799
+ * coords : np .ndarray ,
800
+ interp_func : InterpCallable ,
801
+ interp_kwargs ,
802
+ result_coord_core_dims ,
791
803
) -> np .ndarray :
792
804
"""
793
805
Core nD array interpolation routine.
@@ -801,10 +813,12 @@ def _interpnd(
801
813
# Convert everything to Variables, since that makes applying
802
814
# `_localize` and `_floatize_x` much easier
803
815
x = [Variable ([f"dim_{ nconst + dim } " ], _x ) for dim , _x in enumerate (coords [:n_x ])]
804
- new_x = [
805
- Variable ([f"dim_{ ndim + dim } " for dim in range (_x .ndim )], _x )
806
- for _x in coords [n_x :]
807
- ]
816
+ new_x = broadcast_variables (
817
+ * (
818
+ Variable (dims , _x )
819
+ for dims , _x in zip (result_coord_core_dims , coords [n_x :], strict = True )
820
+ )
821
+ )
808
822
var = Variable ([f"dim_{ dim } " for dim in range (ndim )], data )
809
823
810
824
if interp_kwargs .get ("method" ) in ["linear" , "nearest" ]:
0 commit comments