
interp performance with chunked dimensions #6799

@slevang

What is your issue?

I'm trying to perform 2D interpolation on a large 3D array that is heavily chunked along the interpolation dimensions and not the third dimension. A typical application would be extracting a time series from a reanalysis dataset chunked in space but not in time, to compare against observed station data at more precise coordinates.

I use the advanced interpolation method described in the documentation, with the interpolation coordinates specified as DataArrays sharing a common dimension, like so:

%load_ext memory_profiler
import numpy as np
import dask.array as da
import xarray as xr

# Synthetic dataset chunked in the two interpolation dimensions
nt = 40000
nx = 200
ny = 200
ds = xr.Dataset(
    data_vars={
        'foo': (
            ('t', 'x', 'y'),
            da.random.random(size=(nt, nx, ny), chunks=(-1, 10, 10))),
    },
    coords={
        't': np.linspace(0, 1, nt),
        'x': np.linspace(0, 1, nx),
        'y': np.linspace(0, 1, ny),
    }
)

# Interpolate to some random 2D locations
ni = 10
xx = xr.DataArray(np.random.random(ni), dims='z', name='x')
yy = xr.DataArray(np.random.random(ni), dims='z', name='y')
interpolated = ds.foo.interp(x=xx, y=yy)
%memit interpolated.compute()

With just 10 interpolation points, this example calculation uses about 1.5 * ds.nbytes of memory, and memory use saturates at around 2 * ds.nbytes by about 100 interpolation points.

This could definitely work better: each interpolated point usually requires only a single chunk of the input dataset, and at most 4 if it falls right on the corner of a chunk. For example, we can instead do it in a loop and get very reasonable memory usage, but this isn't very scalable:

interpolated = []
for n in range(ni):
    interpolated.append(ds.foo.interp(x=xx.isel(z=n), y=yy.isel(z=n)))
interpolated = xr.concat(interpolated, dim='z')
%memit interpolated.compute()

I tried adding a .chunk({'z': 1}) to the interpolation coordinates, but this doesn't help. We can also do .sel(x=xx, y=yy, method='nearest') with very good performance.
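
Concretely, those two attempts look roughly like this (a sketch using the same ds, xx, and yy as above):

# Chunking the target coordinates along z: no improvement in memory use
interpolated = ds.foo.interp(x=xx.chunk({'z': 1}), y=yy.chunk({'z': 1}))
%memit interpolated.compute()

# Nearest-neighbour selection at the same target points: fast and low-memory
nearest = ds.foo.sel(x=xx, y=yy, method='nearest')
%memit nearest.compute()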

Any tips to make this calculation work better with existing options, or ways we might improve the interp method to handle this case? Given the performance behavior, I'm guessing we may be doing sequential interpolation over the dimensions: basically an interp1d call over all the xx points, and from there another one to the yy points, which even for a small number of points would require nearly all chunks to be loaded in. But I haven't explored the code enough yet to understand the details.
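
If that is what's happening, a rough numpy/scipy sketch of the pattern I have in mind would be something like the following (purely illustrative; the function here is my own, not xarray internals):

import numpy as np
from scipy.interpolate import interp1d

# Guess at the access pattern: first interp1d along x evaluated at all xx
# points while keeping the full y axis, then a second interp1d along y to
# pick out each yy point. The intermediate array spans every y value for
# every target point, so nearly all (x, y) chunks feed into it.
def sequential_interp(values, x, y, xx, yy):
    # values: (nt, nx, ny); x: (nx,); y: (ny,); xx, yy: (ni,)
    step1 = interp1d(x, values, axis=1)(xx)  # (nt, ni, ny)
    out = np.empty((values.shape[0], len(xx)))
    for i in range(len(xx)):
        out[:, i] = interp1d(y, step1[:, i, :], axis=-1)(yy[i])  # (nt,)
    return out  # (nt, ni)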
