Skip to content

Commit

Permalink
fix,test: support for binary_image in ccl stack
Browse files Browse the repository at this point in the history
  • Loading branch information
william-silversmith committed Oct 9, 2024
1 parent aafabfb commit 454d9cd
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
9 changes: 5 additions & 4 deletions automated_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1491,8 +1491,9 @@ def test_pytorch_integration_ccl_doesnt_crash():
assert isinstance(out, torch.Tensor)
assert torch.all(out == labels)

@pytest.mark.parametrize("binary_image", [False, True])
@pytest.mark.parametrize("connectivity", [6, 26])
def test_connected_components_stack(connectivity):
def test_connected_components_stack(binary_image, connectivity):
stack = [
np.ones([100,100,100], dtype=np.uint32)
for i in range(4)
Expand All @@ -1505,7 +1506,7 @@ def test_connected_components_stack(connectivity):
for i in range(2)
]

arr = cc3d.connected_components_stack(stack, connectivity=connectivity)
arr = cc3d.connected_components_stack(stack, connectivity=connectivity, binary_image=binary_image)

ans = np.ones([100,100,601], dtype=np.uint32)
ans[:,:,400] = 0
Expand All @@ -1516,7 +1517,7 @@ def test_connected_components_stack(connectivity):

image = np.random.randint(0,100, size=[100,100,11], dtype=np.uint8)

ans = cc3d.connected_components(image, connectivity=connectivity)
ans = cc3d.connected_components(image, connectivity=connectivity, binary_image=binary_image)
ans, _ = fastremap.renumber(ans[:])

stack = [
Expand All @@ -1526,7 +1527,7 @@ def test_connected_components_stack(connectivity):
image[:,:,9:11],
]

res = cc3d.connected_components_stack(stack, connectivity=connectivity)
res = cc3d.connected_components_stack(stack, connectivity=connectivity, binary_image=binary_image)
res, _ = fastremap.renumber(res[:])

assert np.all(res[:] == ans[:,:,:11])
Expand Down
8 changes: 6 additions & 2 deletions cc3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ def connected_components_stack(
for x in range(image.shape[0]):
if bottom_cc_labels[x,y] == 0 or top_cc_labels[x,y] == 0:
continue
if bottom_cc_img[x,y] == image[x,y,0]:
if ((not binary_image and bottom_cc_img[x,y] == image[x,y,0])
or (binary_image and bottom_cc_img[x,y] and image[x,y,0])):

equivalences.union(bottom_cc_labels[x,y], top_cc_labels[x,y])
else:
for y in range(image.shape[1]):
Expand All @@ -248,7 +250,9 @@ def connected_components_stack(
if top_cc_labels[x0,y0] == 0:
continue

if bottom_cc_img[x,y] == image[x0,y0,0]:
if ((not binary_image and bottom_cc_img[x,y] == image[x0,y0,0])
or (binary_image and bottom_cc_img[x,y] and image[x0,y0,0])):

equivalences.union(
bottom_cc_labels[x,y], top_cc_labels[x0,y0]
)
Expand Down

0 comments on commit 454d9cd

Please sign in to comment.