Skip to content

Commit e8af057

Browse files
authored
Improve interp performance (#7843)
1 parent 6cd6122 commit e8af057

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

xarray/core/missing.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
639639
var.transpose(*original_dims).data, x, destination, method, kwargs
640640
)
641641

642-
result = Variable(new_dims, interped, attrs=var.attrs)
642+
result = Variable(new_dims, interped, attrs=var.attrs, fastpath=True)
643643

644644
# dimension of the output array
645645
out_dims: OrderedSet = OrderedSet()
@@ -648,7 +648,8 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
648648
out_dims.update(indexes_coords[d][1].dims)
649649
else:
650650
out_dims.add(d)
651-
result = result.transpose(*out_dims)
651+
if len(out_dims) > 1:
652+
result = result.transpose(*out_dims)
652653
return result
653654

654655

@@ -709,28 +710,24 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
709710
]
710711
new_x_arginds = [item for pair in new_x_arginds for item in pair]
711712

712-
args = (
713-
var,
714-
range(ndim),
715-
*x_arginds,
716-
*new_x_arginds,
717-
)
713+
args = (var, range(ndim), *x_arginds, *new_x_arginds)
718714

719715
_, rechunked = chunkmanager.unify_chunks(*args)
720716

721717
args = tuple(elem for pair in zip(rechunked, args[1::2]) for elem in pair)
722718

723719
new_x = rechunked[1 + (len(rechunked) - 1) // 2 :]
724720

721+
new_x0_chunks = new_x[0].chunks
722+
new_x0_shape = new_x[0].shape
723+
new_x0_chunks_is_not_none = new_x0_chunks is not None
725724
new_axes = {
726-
ndim + i: new_x[0].chunks[i]
727-
if new_x[0].chunks is not None
728-
else new_x[0].shape[i]
725+
ndim + i: new_x0_chunks[i] if new_x0_chunks_is_not_none else new_x0_shape[i]
729726
for i in range(new_x[0].ndim)
730727
}
731728

732729
# if useful, re-use localize for each chunk of new_x
733-
localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None)
730+
localize = (method in ["linear", "nearest"]) and new_x0_chunks_is_not_none
734731

735732
# scipy.interpolate.interp1d always forces to float.
736733
# Use the same check for blockwise as well:

0 commit comments

Comments
 (0)