Skip to content

Commit

Permalink
ndmeasure.label: fix for arbitrary struct elements (#321)
Browse files Browse the repository at this point in the history
* ndmeasure.label: fix for arbitrary struct elements

* Fix case of structure=None and add test
  • Loading branch information
m-albert authored Jul 27, 2023
1 parent 88d81c6 commit 91ea280
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 12 deletions.
59 changes: 48 additions & 11 deletions dask_image/ndmeasure/_utils/_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,11 @@ def label_adjacency_graph(labels, structure, nlabels):
This matrix has value 1 at (i, j) if label i is connected to
label j in the global volume, 0 everywhere else.
"""
faces = _chunk_faces(labels.chunks, labels.shape)

if structure is None:
structure = scipy.ndimage.generate_binary_structure(labels.ndim, 1)

faces = _chunk_faces(labels.chunks, labels.shape, structure)
all_mappings = [da.empty((2, 0), dtype=LABEL_DTYPE, chunks=1)]
for face_slice in faces:
face = labels[face_slice]
Expand All @@ -163,7 +167,7 @@ def label_adjacency_graph(labels, structure, nlabels):
return mat


def _chunk_faces(chunks, shape):
def _chunk_faces(chunks, shape, structure):
"""
Return slices for two-pixel-wide boundaries between chunks.
Expand All @@ -173,6 +177,8 @@ def _chunk_faces(chunks, shape):
The chunk specification of the array.
shape : tuple of int
The shape of the array.
structure: array of bool
Structuring element, shape (3,) * ndim.
Returns
-------
Expand All @@ -182,8 +188,10 @@ def _chunk_faces(chunks, shape):
Examples
--------
>>> import dask.array as da
>>> import scipy.ndimage as ndi
>>> a = da.arange(110, chunks=110).reshape((10, 11)).rechunk(5)
>>> chunk_faces(a.chunks, a.shape)
>>> structure = ndi.generate_binary_structure(2, 1)
>>> chunk_faces(a.chunks, a.shape, structure)
[(slice(4, 6, None), slice(0, 5, None)),
(slice(4, 6, None), slice(5, 10, None)),
(slice(4, 6, None), slice(10, 11, None)),
Expand All @@ -192,16 +200,45 @@ def _chunk_faces(chunks, shape):
(slice(5, 10, None), slice(4, 6, None)),
(slice(5, 10, None), slice(9, 11, None))]
"""
slices = da.core.slices_from_chunks(chunks)

ndim = len(shape)
numblocks = tuple(list(len(c) for c in chunks))

slices = da.core.slices_from_chunks(chunks)

# arrange block/chunk indices on grid
block_summary = np.arange(len(slices)).reshape(numblocks)

faces = []
for ax in range(ndim):
for sl in slices:
if sl[ax].stop == shape[ax]:
continue
slice_to_append = list(sl)
slice_to_append[ax] = slice(sl[ax].stop - 1, sl[ax].stop + 1)
faces.append(tuple(slice_to_append))
for ind_curr_block, curr_block in enumerate(np.ndindex(numblocks)):

for pos_structure_coord in np.array(np.where(structure)).T:

# only consider forward neighbors
if min(pos_structure_coord) < 1 or \
max(pos_structure_coord) < 2: continue

neigh_block = [curr_block[dim] + pos_structure_coord[dim] - 1
for dim in range(ndim)]

if max([neigh_block[dim] >= numblocks[dim] for dim in range(ndim)]): continue

# get neighbor slice index
ind_neigh_block = block_summary[tuple(neigh_block)]

curr_slice = []
for dim in range(ndim):
# keep slice if not on boundary
if slices[ind_curr_block][dim] == slices[ind_neigh_block][dim]:
curr_slice.append(slices[ind_curr_block][dim])
# otherwise, add two-pixel-wide boundary
else:
curr_slice.append(slice(
slices[ind_curr_block][dim].stop - 1,
slices[ind_curr_block][dim].stop + 1))

faces.append(tuple(curr_slice))

return faces


Expand Down
36 changes: 35 additions & 1 deletion tests/test_dask_image/test_ndmeasure/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def _assert_equivalent_labeling(labels0, labels1):
(42, 0.4, (15, 16), (15, 16), 1),
(42, 0.4, (15, 16), (4, 5), 1),
(42, 0.4, (15, 16), (4, 5), 2),
(42, 0.4, (15, 16), (4, 5), None),
(42, 0.4, (15, 16), (8, 5), 1),
(42, 0.4, (15, 16), (8, 5), 2),
(42, 0.3, (10, 8, 6), (5, 4, 3), 1),
Expand All @@ -350,7 +351,10 @@ def test_label(seed, prob, shape, chunks, connectivity):
a = np.random.random(shape) < prob
d = da.from_array(a, chunks=chunks)

s = scipy.ndimage.generate_binary_structure(a.ndim, connectivity)
if connectivity is None:
s = None
else:
s = scipy.ndimage.generate_binary_structure(a.ndim, connectivity)

a_l, a_nl = scipy.ndimage.label(a, s)
d_l, d_nl = dask_image.ndmeasure.label(d, s)
Expand All @@ -362,6 +366,36 @@ def test_label(seed, prob, shape, chunks, connectivity):
_assert_equivalent_labeling(a_l, d_l.compute())


@pytest.mark.parametrize(
"ndim", (2, 3, 4, 5)
)
def test_label_full_struct_element(ndim):

full_s = scipy.ndimage.generate_binary_structure(ndim, ndim)
orth_s = scipy.ndimage.generate_binary_structure(ndim, ndim - 1)

# create a mask that represents a single connected component
# under the full (highest rank) structuring element
# but several connected components under the orthogonal
# structuring element
mask = full_s ^ orth_s
mask[tuple([1] * ndim)] = True

# create dask array with chunk boundary
# that passes through the mask
mask_da = da.from_array(mask, chunks=[2] * ndim)

labels_ndi, N_ndi = scipy.ndimage.label(mask, structure=full_s)
labels_di_da, N_di_da = dask_image.ndmeasure.label(
mask_da, structure=full_s)

assert N_ndi == N_di_da.compute()

_assert_equivalent_labeling(
labels_ndi,
labels_di_da.compute())


@pytest.mark.parametrize(
"shape, chunks, ind", [
((15, 16), (4, 5), None),
Expand Down

0 comments on commit 91ea280

Please sign in to comment.