Skip to content

Commit

Permalink
Add spline_filter and spline_filter1d (#215)
Browse files Browse the repository at this point in the history
* add spline_filter and spline_filter1d to ndinterp

* call da.from_array on the input if it is not already a Dask array

* support use of prefilter kwarg for affine_transform

exclude use of boundary modes that require additional padding

* test affine_transform with additional prefilter and mode combinations

* avoid warnings about deprecated import from collections

* fix copy/paste error in _dispatch_ndinterp.py

* minor cleanup in test_affine_transformation.py

* TST: add spline_filter tests

* add order-dependent default depth

The required depth is related to the magnitude of the spline prefilter poles

* TST: test output argument of spline_filter functions

* DOC: update coverage.rst

* TST: fix test_affine_transform_prefilter_modes

input_output_shape_per_dim requires sizes larger than the depth used by spline_filter

* add scipy and dask version-dependent skip

older dask didn't automatically rechunk during map_overlap

fix veriable names:
_supported_periodic_modes -> _supported_prefilter_modes
_unsupported_periodic_modes -> _unsupported_prefilter_modes

* add one more scipy version-dependent skip

* TST: improve coverage by testing corner cases

* enable mode='grid-wrap' for spline_filter

* rename temporary variable to _depth to avoid shadowing function argument

* fix typo in test_spline_filter_array_output_unsupported

* fix: allow grid-warp in spline_filter1d as well

* pep8 fixes

* more informative mode-related error messages

* update scipy.ndimage import style
  • Loading branch information
grlee77 authored May 19, 2021
1 parent 329afe7 commit bbe73c6
Show file tree
Hide file tree
Showing 6 changed files with 575 additions and 51 deletions.
42 changes: 42 additions & 0 deletions dask_image/dispatch/_dispatch_ndinterp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

dispatch_affine_transform = Dispatcher(name="dispatch_affine_transform")


# ================== affine_transform ==================
@dispatch_affine_transform.register(np.ndarray)
def numpy_affine_transform(*args, **kwargs):
Expand All @@ -30,8 +31,49 @@ def cupy_affine_transform(*args, **kwargs):
return cupyx.scipy.ndimage.affine_transform


dispatch_spline_filter = Dispatcher(name="dispatch_spline_filter")


# ================== spline_filter ==================
@dispatch_spline_filter.register(np.ndarray)
def numpy_spline_filter(*args, **kwargs):
return ndimage.spline_filter


@dispatch_spline_filter.register_lazy("cupy")
def register_cupy_spline_filter():
import cupy
import cupyx.scipy.ndimage

@dispatch_spline_filter.register(cupy.ndarray)
def cupy_spline_filter(*args, **kwargs):

return cupyx.scipy.ndimage.spline_filter


dispatch_spline_filter1d = Dispatcher(name="dispatch_spline_filter1d")


# ================== spline_filter1d ==================
@dispatch_spline_filter1d.register(np.ndarray)
def numpy_spline_filter1d(*args, **kwargs):
return ndimage.spline_filter1d


@dispatch_spline_filter1d.register_lazy("cupy")
def register_cupy_spline_filter1d():
import cupy
import cupyx.scipy.ndimage

@dispatch_spline_filter1d.register(cupy.ndarray)
def cupy_spline_filter1d(*args, **kwargs):

return cupyx.scipy.ndimage.spline_filter1d


dispatch_asarray = Dispatcher(name="dispatch_asarray")


# ===================== asarray ========================
@dispatch_asarray.register(np.ndarray)
def numpy_asarray(*args, **kwargs):
Expand Down
14 changes: 7 additions & 7 deletions dask_image/ndfilters/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import division

import collections
import collections.abc
import inspect
import numbers
import re
Expand Down Expand Up @@ -61,13 +61,13 @@ def _get_depth_boundary(ndim, depth, boundary=None):

if isinstance(depth, numbers.Number):
depth = ndim * (depth,)
if not isinstance(depth, collections.Sized):
if not isinstance(depth, collections.abc.Sized):
raise TypeError("Unexpected type for `depth`.")
if len(depth) != ndim:
raise ValueError("Expected `depth` to have a length equal to `ndim`.")
if isinstance(depth, collections.Sequence):
if isinstance(depth, collections.abc.Sequence):
depth = dict(zip(range(ndim), depth))
if not isinstance(depth, collections.Mapping):
if not isinstance(depth, collections.abc.Mapping):
raise TypeError("Unexpected type for `depth`.")

if not all(map(lambda d: isinstance(d, numbers.Integral), depth.values())):
Expand All @@ -79,15 +79,15 @@ def _get_depth_boundary(ndim, depth, boundary=None):

if (boundary is None) or isinstance(boundary, strlike):
boundary = ndim * (boundary,)
if not isinstance(boundary, collections.Sized):
if not isinstance(boundary, collections.abc.Sized):
raise TypeError("Unexpected type for `boundary`.")
if len(boundary) != ndim:
raise ValueError(
"Expected `boundary` to have a length equal to `ndim`."
)
if isinstance(boundary, collections.Sequence):
if isinstance(boundary, collections.abc.Sequence):
boundary = dict(zip(range(ndim), boundary))
if not isinstance(boundary, collections.Mapping):
if not isinstance(boundary, collections.abc.Mapping):
raise TypeError("Unexpected type for `boundary`.")

type_check = lambda b: (b is None) or isinstance(b, strlike) # noqa: E731
Expand Down
186 changes: 166 additions & 20 deletions dask_image/ndinterp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
# -*- coding: utf-8 -*-

import functools
import math
from itertools import product
import numpy as np

import dask.array as da
import numpy as np
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph

import scipy
from scipy.ndimage import affine_transform as ndimage_affine_transform
import warnings


from ..dispatch._dispatch_ndinterp import (
dispatch_affine_transform,
dispatch_asarray,
dispatch_spline_filter,
dispatch_spline_filter1d,
)
from ..ndfilters._utils import _get_depth_boundary


__all__ = [
Expand Down Expand Up @@ -105,21 +108,28 @@ def affine_transform(
offset = matrix[:image.ndim, image.ndim]
matrix = matrix[:image.ndim, :image.ndim]

cval = kwargs.pop('cval', 0)
mode = kwargs.pop('mode', 'constant')
prefilter = kwargs.pop('prefilter', False)

supported_modes = ['constant', 'nearest']
if scipy.__version__ > np.lib.NumpyVersion('1.6.0'):
supported_modes += ['grid-constant']
if mode in ['wrap', 'reflect', 'mirror', 'grid-mirror', 'grid-wrap']:
raise NotImplementedError(
f"Mode {mode} is not currently supported. It must be one of "
f"{supported_modes}.")

# process kwargs
# prefilter is not yet supported
if 'prefilter' in kwargs:
if kwargs['prefilter'] and order > 1:
warnings.warn('Currently, `dask_image.ndinterp.affine_transform` '
'doesn\'t support `prefilter=True`. Proceeding with'
' `prefilter=False`, which if order > 1 can lead '
'to the output containing more blur than with '
'prefiltering.', UserWarning)
del kwargs['prefilter']

if 'mode' in kwargs:
if kwargs['mode'] in ['wrap', 'reflect', 'mirror']:
raise(NotImplementedError("Mode %s is not currently supported."
% kwargs['mode']))
if prefilter and order > 1:
# prefilter is not yet supported for all modes
if mode in ['nearest', 'grid-constant']:
raise NotImplementedError(
f"order > 1 with mode='{mode}' is not supported. Currently "
f"prefilter is only supported with mode='constant'."
)
image = spline_filter(image, order, output=np.float64,
mode=mode)

n = image.ndim
image_shape = image.shape
Expand Down Expand Up @@ -212,8 +222,8 @@ def affine_transform(
tuple(out_chunk_shape), # output_shape
None, # out
order,
'constant' if 'mode' not in kwargs else kwargs['mode'],
0. if 'cval' not in kwargs else kwargs['cval'],
mode,
cval,
False # prefilter
)

Expand All @@ -232,3 +242,139 @@ def affine_transform(
meta=meta)

return transformed


# magnitude of the maximum filter pole for each order
# (obtained from scipy/ndimage/src/ni_splines.c)
_maximum_pole = {
2: 0.171572875253809902396622551580603843,
3: 0.267949192431122706472553658494127633,
4: 0.361341225900220177092212841325675255,
5: 0.430575347099973791851434783493520110,
}


def _get_default_depth(order, tol=1e-8):
"""Determine the approximate depth needed for a given tolerance.
Here depth is chosen as the smallest integer such that ``|p| ** n < tol``
where `|p|` is the magnitude of the largest pole in the IIR filter.
"""
return math.ceil(np.log(tol) / np.log(_maximum_pole[order]))


def spline_filter(
image,
order=3,
output=np.float64,
mode='mirror',
output_chunks=None,
*,
depth=None,
**kwargs
):

if not type(image) == da.core.Array:
image = da.from_array(image)

# use dispatching mechanism to determine backend
spline_filter_method = dispatch_spline_filter(image)

try:
dtype = np.dtype(output)
except TypeError: # pragma: no cover
raise TypeError( # pragma: no cover
"Could not coerce the provided output to a dtype. "
"Passing array to output is not currently supported."
)

if depth is None:
depth = _get_default_depth(order)

if mode == 'wrap':
raise NotImplementedError(
"mode='wrap' is unsupported. It is recommended to use 'grid-wrap' "
"instead."
)

# Note: depths of 12 and 24 give results matching SciPy to approximately
# single and double precision accuracy, respectively.
boundary = "periodic" if mode == 'grid-wrap' else "none"
depth, boundary = _get_depth_boundary(image.ndim, depth, boundary)

# cannot pass a func kwarg named "output" to map_overlap
spline_filter_method = functools.partial(spline_filter_method,
output=dtype)

result = image.map_overlap(
spline_filter_method,
depth=depth,
boundary=boundary,
dtype=dtype,
meta=image._meta,
# spline_filter kwargs
order=order,
mode=mode,
)

return result


def spline_filter1d(
image,
order=3,
axis=-1,
output=np.float64,
mode='mirror',
output_chunks=None,
*,
depth=None,
**kwargs
):

if not type(image) == da.core.Array:
image = da.from_array(image)

# use dispatching mechanism to determine backend
spline_filter1d_method = dispatch_spline_filter1d(image)

try:
dtype = np.dtype(output)
except TypeError: # pragma: no cover
raise TypeError( # pragma: no cover
"Could not coerce the provided output to a dtype. "
"Passing array to output is not currently supported."
)

if depth is None:
depth = _get_default_depth(order)

# use depth 0 on all axes except the filtered axis
if not np.isscalar(depth):
raise ValueError("depth must be a scalar value")
depths = [0] * image.ndim
depths[axis] = depth

if mode == 'wrap':
raise NotImplementedError(
"mode='wrap' is unsupported. It is recommended to use 'grid-wrap' "
"instead."
)

# cannot pass a func kwarg named "output" to map_overlap
spline_filter1d_method = functools.partial(spline_filter1d_method,
output=dtype)

result = image.map_overlap(
spline_filter1d_method,
depth=tuple(depths),
boundary="periodic" if mode == 'grid-wrap' else "none",
dtype=dtype,
meta=image._meta,
# spline_filter1d kwargs
order=order,
axis=axis,
mode=mode,
)

return result
12 changes: 6 additions & 6 deletions docs/coverage.rst
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
*****************
Function Coverage
Function Coverage
*****************

Coverage of dask-image vs scipy ndimage functions
*************************************************

This table shows which SciPy ndimage functions are supported by dask-image.
This table shows which SciPy ndimage functions are supported by dask-image.

.. list-table::
.. list-table::
:widths: 25 25 50
:header-rows: 0

Expand Down Expand Up @@ -205,10 +205,10 @@ This table shows which SciPy ndimage functions are supported by dask-image.
- ✓
* - ``spline_filter``
- ✓
-
-
* - ``spline_filter1d``
- ✓
-
-
* - ``standard_deviation``
- ✓
- ✓
Expand All @@ -233,4 +233,4 @@ This table shows which SciPy ndimage functions are supported by dask-image.
* - ``zoom``
- ✓
-

Loading

0 comments on commit bbe73c6

Please sign in to comment.