Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions requirements/py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ dependencies:
- numpy
- cython

# Optional dependencies.
- dask

# Developer dependencies.
- black=21.5b0
- flake8=3.9.2
Expand Down
3 changes: 3 additions & 0 deletions requirements/py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ dependencies:
- numpy
- cython

# Optional dependencies.
- dask

# Developer dependencies.
- black=21.5b0
- flake8=3.8.2
Expand Down
61 changes: 61 additions & 0 deletions stratify/_vinterp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# z_src - the values of Z where fz_src is defined
# z_target - the desired values of Z to generate new data for.
# fz_src - the data, defined at each z_src
import functools
import numpy as np

cimport cython
Expand Down Expand Up @@ -502,6 +503,66 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
modes. NaN extrapolation is the default.

"""
func = functools.partial(
_interpolate,
axis=axis,
rising=rising,
interpolation=interpolation,
extrapolation=extrapolation
)
if not hasattr(fz_src, 'compute'):
# Numpy array
return func(z_target, z_src, fz_src)

# Dask array
import dask.array as da

# Ensure z_target is an array.
if not isinstance(z_target, (np.ndarray, da.Array)):
z_target = np.array(z_target)

# Compute chunk sizes
if axis < 0:
axis += fz_src.ndim
in_chunks = list(fz_src.chunks)
in_chunks[axis] = fz_src.shape[axis]

out_chunks = list(in_chunks)
if z_target.ndim == 1:
out_chunks[axis] = z_target.shape[0]
else:
out_chunks[axis] = z_target.shape[axis]

# Ensure `fz_src` is not chunked along `axis`.
fz_src = fz_src.rechunk(in_chunks)

# Ensure z_src is a dask array with the correct chunks.
if isinstance(z_src, da.Array):
z_src = z_src.rechunk(in_chunks)
else:
z_src = da.asarray(z_src, chunks=in_chunks)

# Compute with 1-dimensional target array.
if z_target.ndim == 1:
func = functools.partial(func, z_target)
return da.map_blocks(func, z_src, fz_src,
chunks=out_chunks, dtype=fz_src.dtype,
meta=np.array((), dtype=fz_src.dtype))

# Ensure z_target is a dask array with the correct chunks
if isinstance(z_target, da.Array):
z_target = z_target.rechunk(out_chunks)
else:
z_target = da.asarray(z_target, chunks=out_chunks)

# Compute with multi-dimensional target array.
return da.map_blocks(func, z_target, z_src, fz_src,
chunks=out_chunks, dtype=fz_src.dtype,
meta=np.array((), dtype=fz_src.dtype))


def _interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
interpolation='linear', extrapolation='nan'):
if interpolation in interp_schemes:
interpolation = interp_schemes[interpolation]()
if extrapolation in extrap_schemes:
Expand Down
26 changes: 20 additions & 6 deletions stratify/tests/performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,45 @@
Functions that may be used to measure performance of a component.

"""
import dask.array as da
import numpy as np

import stratify


def src_data(shape=(400, 500, 100), lazy=False):
    """
    Build a synthetic source coordinate array and a matching data array.

    Args:

    * shape - shape shared by both returned arrays. The last dimension
      holds the coordinate ramp.
    * lazy - when True the data array ``fz`` is a dask array, otherwise
      it is a numpy array. The coordinate array ``z`` is always numpy.

    Returns:
        A ``(z, fz)`` tuple. ``z`` repeats ``np.linspace(0, 100, shape[-1])``
        along the last axis; ``fz`` holds the float64 values 0, 1, 2, ...
        laid out in C order.

    """
    # Tile over every leading dimension rather than only the first two, so
    # that a shape of any rank reshapes cleanly. int() is needed because
    # np.prod of an empty tuple (1-d shape) is the float 1.0.
    n_columns = int(np.prod(shape[:-1]))
    z = np.tile(np.linspace(0, 100, shape[-1]), n_columns).reshape(shape)

    n_points = int(np.prod(shape))
    if lazy:
        fz = da.arange(n_points, dtype=np.float64).reshape(shape)
    else:
        fz = np.arange(n_points, dtype=np.float64).reshape(shape)
    return z, fz


def interp_and_extrap(
    shape,
    lazy=False,
    interp=stratify.INTERPOLATE_LINEAR,
    extrap=stratify.EXTRAPOLATE_NEAREST,
):
    """
    Run a single interpolation over synthetic data, for timing purposes.

    The target levels run from -20 to 120 while the source coordinate
    produced by ``src_data`` spans 0 to 100, so some target levels lie
    outside the source range.

    Args:

    * shape - shape of the synthetic source arrays.
    * lazy - when True the source data is a dask array. Defaults to False
      so that callers passing only ``shape`` keep working.
    * interp - interpolation scheme handed to ``stratify.interpolate``.
    * extrap - extrapolation scheme handed to ``stratify.interpolate``.

    Nothing is returned; which code path ran is reported with ``print``.

    """
    z, fz = src_data(shape, lazy)
    tgt = np.linspace(-20, 120, 50)
    result = stratify.interpolate(
        tgt,
        z,
        fz,
        interpolation=interp,
        extrapolation=extrap,
    )
    if isinstance(result, da.Array):
        print("lazy calculation")
        print(result.chunks)
        # A dask result is only a task graph; compute it so the
        # interpolation work is actually performed.
        result.compute()
    else:
        print("non-lazy calculation")


if __name__ == "__main__":
    import sys

    # Pass the word "lazy" on the command line to exercise the dask code
    # path; without it the plain numpy path is run.
    lazy = "lazy" in sys.argv[1:]
    interp_and_extrap(shape=(500, 600, 100), lazy=lazy)
41 changes: 41 additions & 0 deletions stratify/tests/test_vinterp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest

import dask.array as da
import numpy as np
from numpy.testing import assert_array_almost_equal, assert_array_equal

Expand Down Expand Up @@ -46,6 +47,17 @@ def interpolate(self, x_target, x_src, rising=None):
)
assert_array_equal(r1, r2)

lazy_fx_src = da.asarray(fx_src, chunks=tuple(range(1, x_src.ndim + 1)))
r3 = stratify.interpolate(
x_target,
x_src,
lazy_fx_src,
rising=rising,
interpolation=index_interp,
extrapolation=extrap_direct,
)
assert_array_equal(r1, r3.compute())

return r1

def test_interp_only(self):
Expand Down Expand Up @@ -533,6 +545,35 @@ def test_target_z_3d_axis_0(self):
)
assert_array_equal(result, f_source)

def test_dask(self):
    # One array serves as target levels, source levels and source data.
    ramp = np.arange(3) * np.ones([4, 2, 3])
    expected = vinterp.interpolate(
        ramp, ramp, ramp, extrapolation="linear"
    )
    # The data is always lazy; the target and source levels are tried as
    # every mix of numpy array, nested list and dask array.
    lazy_data = da.asarray(ramp, chunks=(2, 1, 2))
    combinations = [
        (tgt, src)
        for tgt in (ramp, ramp.tolist(), da.asarray(ramp))
        for src in (ramp, da.asarray(ramp))
    ]
    for tgt, src in combinations:
        actual = vinterp.interpolate(
            tgt, src, lazy_data, extrapolation="linear"
        )
        assert_array_equal(expected, actual.compute())

def test_dask_1d_target(self):
    # A single 1-d target level, interpolated along axis 1.
    levels = np.array([0.5])
    ramp = np.arange(3) * np.ones([4, 2, 3])
    expected = vinterp.interpolate(
        levels, ramp, ramp, axis=1, extrapolation="linear"
    )
    # The data is always lazy; the target and source levels are tried as
    # every mix of numpy array, list and dask array.
    lazy_data = da.asarray(ramp, chunks=(2, 1, 2))
    combinations = [
        (tgt, src)
        for tgt in (levels, levels.tolist(), da.asarray(levels))
        for src in (ramp, da.asarray(ramp))
    ]
    for tgt, src in combinations:
        actual = vinterp.interpolate(
            tgt, src, lazy_data, axis=1, extrapolation="linear"
        )
        assert_array_equal(expected, actual.compute())


if __name__ == "__main__":
unittest.main()