Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved error reporting and tests for prune_paths() methods #212

Merged
merged 6 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions src/skan/csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from skimage import morphology
from skimage.graph import central_pixel
from skimage.util._map_array import map_array, ArrayMap
from typing import List
import numba
import warnings

from .nputil import _raveled_offsets_and_distances
from .summary_utils import find_main_branches


def _weighted_abs_diff(values0, values1, distances):
def _weighted_abs_diff(values0: np.ndarray, values1: np.ndarray, distances: np.ndarray) -> np.ndarray:
"""A default edge function for complete image graphs.

A pixel graph on an image with no edge values and no mask is a very
Expand Down Expand Up @@ -521,6 +522,7 @@ def __init__(
np.full(skeleton_image.ndim, spacing)
)
if keep_images:
self.keep_images = keep_images
self.skeleton_image = skeleton_image
self.source_image = source_image

Expand Down Expand Up @@ -550,7 +552,7 @@ def path(self, index):
start, stop = self.paths.indptr[index:index + 2]
return self.paths.indices[start:stop]

def path_coordinates(self, index):
def path_coordinates(self, index: int):
"""Return the image coordinates of the pixels in the path.

Parameters
Expand All @@ -566,7 +568,7 @@ def path_coordinates(self, index):
path_indices = self.path(index)
return self.coordinates[path_indices]

def path_with_data(self, index):
def path_with_data(self, index: int):
"""Return pixel indices and corresponding pixel values on a path.

Parameters
Expand Down Expand Up @@ -652,9 +654,26 @@ def path_stdev(self):
means = self.path_means()
return np.sqrt(np.clip(sumsq/lengths - means*means, 0, None))

def prune_paths(self, indices) -> 'Skeleton':
def prune_paths(self, indices: List[int]) -> 'Skeleton':
"""Prune nodes from the skeleton.

Parameters
----------
indices: List[int]
List of indices to be removed.

Retruns
-------
Skeleton
A new Skeleton object pruned.
"""
# warning: slow
image_cp = np.copy(self.skeleton_image)
if np.any(np.array(indices) > self.paths.shape[1]):
raise ValueError(
f'The path index {i} does not exist in the '
'summary dataframe. Resummarise the skeleton.'
)
for i in indices:
pixel_ids_to_wipe = self.path(i)
junctions = self.degrees[pixel_ids_to_wipe] > 2
Expand All @@ -668,6 +687,7 @@ def prune_paths(self, indices) -> 'Skeleton':
new_skeleton,
spacing=self.spacing,
source_image=self.source_image,
keep_images=self.keep_images
)

def __array__(self, dtype=None):
Expand All @@ -676,8 +696,8 @@ def __array__(self, dtype=None):


def summarize(
skel: Skeleton, *, value_is_height=False, find_main_branch=False
):
skel: Skeleton, *, value_is_height: bool=False, find_main_branch: bool=False
) -> pd.DataFrame:
"""Compute statistics for every skeleton and branch in ``skel``.

Parameters
Expand Down Expand Up @@ -1037,10 +1057,8 @@ def _simplify_graph(skel):

src_relab, dst_relab = fw_map[src], fw_map[dst]

edges = sparse.coo_matrix(
(distance, (src_relab, dst_relab)),
shape=(n_nodes, n_nodes)
)
edges = sparse.coo_matrix((distance, (src_relab, dst_relab)),
shape=(n_nodes, n_nodes))
dir_csgraph = edges.tocsr()
simp_csgraph = dir_csgraph + dir_csgraph.T # make undirected

Expand Down
58 changes: 57 additions & 1 deletion src/skan/test/test_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numpy.testing import assert_equal, assert_almost_equal
from skimage.draw import line

from skan import csr
from skan import csr, summarize
from skan._testdata import (
tinycycle, tinyline, skeleton0, skeleton1, skeleton2, skeleton3d,
topograph1d, skeleton4
Expand Down Expand Up @@ -179,6 +179,62 @@ def test_transpose_image():
)


@pytest.mark.parametrize(
"skeleton,prune_branch,target",
[
(
skeleton1, 1,
np.array([[0, 1, 1, 1, 1, 1, 0], [1, 0, 0, 0, 0, 0, 1],
[0, 1, 1, 0, 1, 1, 0], [0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0]])
),
(
skeleton1, 2,
np.array([[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0], [1, 0, 0, 2, 0, 0, 0],
[1, 0, 0, 0, 2, 2, 2]])
),
# There are no isolated cycles to be pruned
(
skeleton1, 3,
np.array([[0, 1, 1, 1, 1, 1, 0], [1, 0, 0, 0, 0, 0, 1],
[0, 3, 2, 0, 1, 1, 0], [3, 0, 0, 4, 0, 0, 0],
[3, 0, 0, 0, 4, 4, 4]])
),
]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eek, looks like I need to fiddle with my yapf config a bit more, I very much dislike this way of closing/matching parens... 😅 (Ignore this comment for this PR, just thinking out loud.)

)
def test_prune_paths(
skeleton: np.ndarray, prune_branch: int, target: np.ndarray
) -> None:
"""Test pruning of paths."""
s = csr.Skeleton(skeleton, keep_images=True)
summary = summarize(s)
indices_to_remove = summary.loc[summary['branch-type'] == prune_branch
].index
pruned = s.prune_paths(indices_to_remove)
np.testing.assert_array_equal(pruned, target)


def test_prune_paths_exception_single_point() -> None:
"""Test exceptions raised when pruning leaves a single point and Skeleton object
can not be created and returned."""
s = csr.Skeleton(skeleton0)
summary = summarize(s)
indices_to_remove = summary.loc[summary['branch-type'] == 1].index
with pytest.raises(ValueError):
s.prune_paths(indices_to_remove)


def test_prune_paths_exception_invalid_path_index() -> None:
"""Test exceptions raised when trying to prune paths that do not exist in the summary. This can arise if skeletons
are not updated correctly during iterative pruning."""
s = csr.Skeleton(skeleton0)
summary = summarize(s)
indices_to_remove = [6]
with pytest.raises(ValueError):
s.prune_paths(indices_to_remove)


def test_fast_graph_center_idx():
s = csr.Skeleton(skeleton0)
i = csr._fast_graph_center_idx(s)
Expand Down