Skip to content

Commit 278c17c

Browse files
authored
Support interpolation on dask arrays (#52)
* Changes found on laptop
* Add notebook
* Add dask to requirements
* Black
* Clean up
* Clean up
* Take out rechunking the input array based on the size of the output array
1 parent bc6343a commit 278c17c

File tree

5 files changed

+128
-6
lines changed

5 files changed

+128
-6
lines changed

requirements/py38.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ dependencies:
1414
- numpy
1515
- cython
1616

17+
# Optional dependencies.
18+
- dask
19+
1720
# Developer dependencies.
1821
- black=21.5b0
1922
- flake8=3.9.2

requirements/py39.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ dependencies:
1414
- numpy
1515
- cython
1616

17+
# Optional dependencies.
18+
- dask
19+
1720
# Developer dependencies.
1821
- black=21.5b0
1922
- flake8=3.9.2

stratify/_vinterp.pyx

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# z_src - the values of Z where fz_src is defined
44
# z_target - the desired values of Z to generate new data for.
55
# fz_src - the data, defined at each z_src
6+
import functools
67
import numpy as np
78

89
cimport cython
@@ -502,6 +503,66 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
502503
modes. NaN extrapolation is the default.
503504
504505
"""
506+
func = functools.partial(
507+
_interpolate,
508+
axis=axis,
509+
rising=rising,
510+
interpolation=interpolation,
511+
extrapolation=extrapolation
512+
)
513+
if not hasattr(fz_src, 'compute'):
514+
# Numpy array
515+
return func(z_target, z_src, fz_src)
516+
517+
# Dask array
518+
import dask.array as da
519+
520+
# Ensure z_target is an array.
521+
if not isinstance(z_target, (np.ndarray, da.Array)):
522+
z_target = np.array(z_target)
523+
524+
# Compute chunk sizes
525+
if axis < 0:
526+
axis += fz_src.ndim
527+
in_chunks = list(fz_src.chunks)
528+
in_chunks[axis] = fz_src.shape[axis]
529+
530+
out_chunks = list(in_chunks)
531+
if z_target.ndim == 1:
532+
out_chunks[axis] = z_target.shape[0]
533+
else:
534+
out_chunks[axis] = z_target.shape[axis]
535+
536+
# Ensure `fz_src` is not chunked along `axis`.
537+
fz_src = fz_src.rechunk(in_chunks)
538+
539+
# Ensure z_src is a dask array with the correct chunks.
540+
if isinstance(z_src, da.Array):
541+
z_src = z_src.rechunk(in_chunks)
542+
else:
543+
z_src = da.asarray(z_src, chunks=in_chunks)
544+
545+
# Compute with 1-dimensional target array.
546+
if z_target.ndim == 1:
547+
func = functools.partial(func, z_target)
548+
return da.map_blocks(func, z_src, fz_src,
549+
chunks=out_chunks, dtype=fz_src.dtype,
550+
meta=np.array((), dtype=fz_src.dtype))
551+
552+
# Ensure z_target is a dask array with the correct chunks
553+
if isinstance(z_target, da.Array):
554+
z_target = z_target.rechunk(out_chunks)
555+
else:
556+
z_target = da.asarray(z_target, chunks=out_chunks)
557+
558+
# Compute with multi-dimensional target array.
559+
return da.map_blocks(func, z_target, z_src, fz_src,
560+
chunks=out_chunks, dtype=fz_src.dtype,
561+
meta=np.array((), dtype=fz_src.dtype))
562+
563+
564+
def _interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
565+
interpolation='linear', extrapolation='nan'):
505566
if interpolation in interp_schemes:
506567
interpolation = interp_schemes[interpolation]()
507568
if extrapolation in extrap_schemes:

stratify/tests/performance.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,45 @@
22
Functions that may be used to measure performance of a component.
33
44
"""
5+
import dask.array as da
56
import numpy as np
67

78
import stratify
89

910

10-
def src_data(shape=(400, 500, 100)):
11+
def src_data(shape=(400, 500, 100), lazy=False):
1112
z = np.tile(np.linspace(0, 100, shape[-1]), np.prod(shape[:2])).reshape(shape)
12-
fz = np.arange(np.prod(shape)).reshape(shape)
13+
if lazy:
14+
fz = da.arange(np.prod(shape), dtype=np.float64).reshape(shape)
15+
else:
16+
fz = np.arange(np.prod(shape), dtype=np.float64).reshape(shape)
1317
return z, fz
1418

1519

1620
def interp_and_extrap(
1721
shape,
22+
lazy,
1823
interp=stratify.INTERPOLATE_LINEAR,
1924
extrap=stratify.EXTRAPOLATE_NEAREST,
2025
):
21-
z, fz = src_data(shape)
22-
stratify.interpolate(
23-
np.linspace(-20, 120, 50),
26+
z, fz = src_data(shape, lazy)
27+
tgt = np.linspace(-20, 120, 50)
28+
result = stratify.interpolate(
29+
tgt,
2430
z,
2531
fz,
2632
interpolation=interp,
2733
extrapolation=extrap,
2834
)
35+
if isinstance(result, da.Array):
36+
print("lazy calculation")
37+
print(result.chunks)
38+
result.compute()
39+
else:
40+
print("non-lazy calculation")
2941

3042

3143
if __name__ == "__main__":
32-
interp_and_extrap(shape=(500, 600, 100))
44+
import sys
45+
lazy = "lazy" in sys.argv[1:]
46+
interp_and_extrap(shape=(500, 600, 100), lazy=lazy)

stratify/tests/test_vinterp.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22

3+
import dask.array as da
34
import numpy as np
45
from numpy.testing import assert_array_almost_equal, assert_array_equal
56

@@ -46,6 +47,17 @@ def interpolate(self, x_target, x_src, rising=None):
4647
)
4748
assert_array_equal(r1, r2)
4849

50+
lazy_fx_src = da.asarray(fx_src, chunks=tuple(range(1, x_src.ndim + 1)))
51+
r3 = stratify.interpolate(
52+
x_target,
53+
x_src,
54+
lazy_fx_src,
55+
rising=rising,
56+
interpolation=index_interp,
57+
extrapolation=extrap_direct,
58+
)
59+
assert_array_equal(r1, r3.compute())
60+
4961
return r1
5062

5163
def test_interp_only(self):
@@ -533,6 +545,35 @@ def test_target_z_3d_axis_0(self):
533545
)
534546
assert_array_equal(result, f_source)
535547

548+
def test_dask(self):
549+
z_target = z_source = f_source = np.arange(3) * np.ones([4, 2, 3])
550+
reference = vinterp.interpolate(
551+
z_target, z_source, f_source, extrapolation="linear"
552+
)
553+
# Test with various combinations of lazy input
554+
f_src = da.asarray(f_source, chunks=(2, 1, 2))
555+
for z_tgt in (z_target, z_target.tolist(), da.asarray(z_target)):
556+
for z_src in (z_source, da.asarray(z_source)):
557+
result = vinterp.interpolate(
558+
z_tgt, z_src, f_src, extrapolation="linear"
559+
)
560+
assert_array_equal(reference, result.compute())
561+
562+
def test_dask_1d_target(self):
563+
z_target = np.array([0.5])
564+
z_source = f_source = np.arange(3) * np.ones([4, 2, 3])
565+
reference = vinterp.interpolate(
566+
z_target, z_source, f_source, axis=1, extrapolation="linear"
567+
)
568+
# Test with various combinations of lazy input
569+
f_src = da.asarray(f_source, chunks=(2, 1, 2))
570+
for z_tgt in (z_target, z_target.tolist(), da.asarray(z_target)):
571+
for z_src in (z_source, da.asarray(z_source)):
572+
result = vinterp.interpolate(
573+
z_tgt, z_src, f_src, axis=1, extrapolation="linear"
574+
)
575+
assert_array_equal(reference, result.compute())
576+
536577

537578
if __name__ == "__main__":
538579
unittest.main()

0 commit comments

Comments
 (0)