Description
What is your issue?
I'm trying to perform 2D interpolation on a large 3D array that is heavily chunked along the interpolation dimensions and not the third dimension. The application could be extracting a timeseries from a reanalysis dataset chunked in space but not time, to compare to observed station data with more precise coordinates.
I use the advanced interpolation method as described in the documentation, with the interpolation coordinates specified by DataArrays with a shared dimension, like so:
```python
%load_ext memory_profiler

import numpy as np
import dask.array as da
import xarray as xr

# Synthetic dataset chunked in the two interpolation dimensions
nt = 40000
nx = 200
ny = 200
ds = xr.Dataset(
    data_vars={
        'foo': (
            ('t', 'x', 'y'),
            da.random.random(size=(nt, nx, ny), chunks=(-1, 10, 10))),
    },
    coords={
        't': np.linspace(0, 1, nt),
        'x': np.linspace(0, 1, nx),
        'y': np.linspace(0, 1, ny),
    },
)

# Interpolate to some random 2D locations
ni = 10
xx = xr.DataArray(np.random.random(ni), dims='z', name='x')
yy = xr.DataArray(np.random.random(ni), dims='z', name='y')
interpolated = ds.foo.interp(x=xx, y=yy)
%memit interpolated.compute()
```
With just 10 interpolation points, this example calculation uses about `1.5 * ds.nbytes` of memory, and saturates around `2 * ds.nbytes` by about 100 interpolation points.
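(For scale, the synthetic dataset above works out to roughly 40000 * 200 * 200 * 8 bytes ≈ 12.8 GB; that's just back-of-the-envelope arithmetic, not a profiler figure.)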
This could definitely work better, as each interpolated point usually requires only a single chunk of the input dataset, and at most four if it lands right on the corner of a chunk. For example, we can instead do the interpolation in a loop and get very reasonable memory usage, but this isn't very scalable:
```python
interpolated = []
for n in range(ni):
    interpolated.append(ds.foo.interp(x=xx.isel(z=n), y=yy.isel(z=n)))
interpolated = xr.concat(interpolated, dim='z')
%memit interpolated.compute()
```
I tried adding a `.chunk({'z': 1})` to the interpolation coordinates, but this doesn't help. We can also do `.sel(x=xx, y=yy, method='nearest')` with very good performance.
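For reference, the nearest-neighbour comparison I mean is just the following, using the same `xx`/`yy` as above:

```python
# Pointwise nearest-neighbour extraction for comparison; each point should only
# need the chunk it falls in, so memory stays low.
nearest = ds.foo.sel(x=xx, y=yy, method='nearest')
%memit nearest.compute()
```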
Any tips to make this calculation work better with existing options, or otherwise ways we might improve the `interp` method to handle this case? Given the performance behavior, I'm guessing we may be doing sequential interpolation over the dimensions, basically an `interp1d` call for all the `xx` points and from there another to the `yy` points, which for even a small number of points would require nearly all chunks to be loaded in. But I haven't explored the code enough yet to understand the details.
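One quick way to probe this without reading the code (a sketch, assuming the lazy chunk layout is a reasonable proxy for how much source data each task pulls in) is to compare the dask chunks of the source and the uncomputed result:

```python
# Rough diagnostic (an assumption, not a reading of xarray's internals): if the
# interpolated dimensions collapse into a single output chunk, each task likely
# depends on most of the source chunks, which would match the memory behaviour.
lazy = ds.foo.interp(x=xx, y=yy)
print(ds.foo.chunks)  # source chunk layout along (t, x, y)
print(lazy.chunks)    # output chunk layout along (t, z)
```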