Skip to content

Commit

Permalink
GPU support for ndmorph, binary morphological functions (#157)
Browse files Browse the repository at this point in the history
* GPU support for ndmorph subpackage

* Fix ndmorph default structure for cupy arrays
  • Loading branch information
GenevieveBuckley authored Mar 1, 2021
1 parent 91fe6e1 commit 91c4955
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 4 deletions.
64 changes: 64 additions & 0 deletions dask_image/dispatch/_dispatch_ndmorph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-

import numpy as np
import scipy.ndimage

from ._dispatcher import Dispatcher

__all__ = [
"dispatch_binary_dilation",
"dispatch_binary_erosion",
"dispatch_binary_structure",
]

dispatch_binary_dilation = Dispatcher(name="dispatch_binary_dilation")
dispatch_binary_erosion = Dispatcher(name="dispatch_binary_erosion")
dispatch_binary_structure = Dispatcher(name='dispatch_binary_structure')


# ================== binary_dilation ==================
@dispatch_binary_dilation.register(np.ndarray)
def numpy_binary_dilation(*args, **kwargs):
return scipy.ndimage.binary_dilation


@dispatch_binary_dilation.register_lazy("cupy")
def register_cupy_binary_dilation():
import cupy
import cupyx.scipy.ndimage

@dispatch_binary_dilation.register(cupy.ndarray)
def cupy_binary_dilation(*args, **kwargs):
return cupyx.scipy.ndimage.binary_dilation


# ================== binary_erosion ==================
@dispatch_binary_erosion.register(np.ndarray)
def numpy_binary_erosion(*args, **kwargs):
return scipy.ndimage.binary_erosion


@dispatch_binary_erosion.register_lazy("cupy")
def register_cupy_binary_erosion():
import cupy
import cupyx.scipy.ndimage

@dispatch_binary_erosion.register(cupy.ndarray)
def cupy_binary_erosion(*args, **kwargs):
return cupyx.scipy.ndimage.binary_erosion


# ================== generate_binary_structure ==================
@dispatch_binary_structure.register(np.ndarray)
def numpy_binary_structure(*args, **kwargs):
return scipy.ndimage.generate_binary_structure


@dispatch_binary_structure.register_lazy("cupy")
def register_cupy_binary_structure():
import cupy
import cupyx.scipy.ndimage

@dispatch_binary_structure.register(cupy.ndarray)
def cupy_binary_structure(*args, **kwargs):
return cupyx.scipy.ndimage.generate_binary_structure
14 changes: 12 additions & 2 deletions dask_image/ndmorph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@

from . import _utils
from . import _ops
from ..dispatch._dispatch_ndmorph import (
dispatch_binary_dilation,
dispatch_binary_erosion)

__all__ = [
"binary_closing",
"binary_dilation",
"binary_erosion",
"binary_opening",
]


@_utils._update_wrapper(scipy.ndimage.binary_closing)
Expand Down Expand Up @@ -43,7 +53,7 @@ def binary_dilation(image,
border_value = _utils._get_border_value(border_value)

result = _ops._binary_op(
scipy.ndimage.binary_dilation,
dispatch_binary_dilation(image),
image,
structure=structure,
iterations=iterations,
Expand All @@ -67,7 +77,7 @@ def binary_erosion(image,
border_value = _utils._get_border_value(border_value)

result = _ops._binary_op(
scipy.ndimage.binary_erosion,
dispatch_binary_erosion(image),
image,
structure=structure,
iterations=iterations,
Expand Down
6 changes: 4 additions & 2 deletions dask_image/ndmorph/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import dask.array

from ..dispatch._dispatch_ndmorph import dispatch_binary_structure
from ..ndfilters._utils import (
_update_wrapper,
_get_depth_boundary,
Expand All @@ -24,8 +25,9 @@
def _get_structure(image, structure):
# Create square connectivity as default
if structure is None:
structure = scipy.ndimage.generate_binary_structure(image.ndim, 1)
elif isinstance(structure, (numpy.ndarray, dask.array.Array)):
generate_binary_structure = dispatch_binary_structure(image)
structure = generate_binary_structure(image.ndim, 1)
elif hasattr(structure, 'ndim'):
if structure.ndim != image.ndim:
raise RuntimeError(
"`structure` must have the same rank as `image`."
Expand Down
31 changes: 31 additions & 0 deletions tests/test_dask_image/test_ndmorph/test_cupy_ndmorph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import dask.array as da
import numpy as np
import pytest

from dask_image import ndmorph

cupy = pytest.importorskip("cupy", minversion="7.7.0")


@pytest.fixture
def array():
s = (10, 10)
a = da.from_array(cupy.arange(int(np.prod(s)),
dtype=cupy.float32).reshape(s), chunks=5)
return a


@pytest.mark.cupy
@pytest.mark.parametrize("func", [
ndmorph.binary_closing,
ndmorph.binary_dilation,
ndmorph.binary_erosion,
ndmorph.binary_opening,
])
def test_cupy_ndmorph(array, func):
"""Test convolve & correlate filters with cupy input arrays."""
result = func(array)
result.compute()

0 comments on commit 91c4955

Please sign in to comment.