Skip to content

Commit

Permalink
Merge pull request scikit-image#5 from jni/slic-speed-fix-dims
Browse files Browse the repository at this point in the history
Fix image dimension sanitizing at function start
  • Loading branch information
ahojnnes committed Sep 16, 2013
2 parents bc2f23d + ea1566f commit 06e99a5
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
31 changes: 16 additions & 15 deletions skimage/segmentation/slic_superpixels.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,12 @@ def slic(image, n_segments=100, compactness=10., max_iter=10, sigma=0,
ValueError
If:
- the image dimension is not 2 or 3 and `multichannel == False`, OR
- the image dimension is not 3 or 4 and `multichannel == True`, OR
- `multichannel == True` and the length of the last dimension of
the image is not 3, OR
- the image dimension is not 3 or 4 and `multichannel == True`
Notes
-----
If `sigma > 0` as is default, the image is smoothed using a Gaussian kernel
prior to segmentation.
If `sigma > 0`, the image is smoothed using a Gaussian kernel prior to
segmentation.
The image is rescaled to be in [0, 1] prior to processing.
Expand Down Expand Up @@ -88,15 +86,18 @@ def slic(image, n_segments=100, compactness=10., max_iter=10, sigma=0,
compactness = ratio

image = img_as_float(image)
image = np.atleast_3d(image)

if image.ndim == 3:
if multichannel:
# Make 2D image 3D with depth = 1
image = image[np.newaxis, ...]
else:
# Add channel as single last dimension
image = image[..., np.newaxis]
is2d = False
if image.ndim == 2:
# 2D grayscale image
image = image[np.newaxis, ..., np.newaxis]
is2d = True
elif image.ndim == 3 and multichannel:
# Make 2D multichannel image 3D with depth = 1
image = image[np.newaxis, ...]
is2d = True
elif image.ndim == 3 and not multichannel:
# Add channel as single last dimension
image = image[..., np.newaxis]

if not isinstance(sigma, coll.Iterable):
sigma = np.array([sigma, sigma, sigma])
Expand Down Expand Up @@ -135,7 +136,7 @@ def slic(image, n_segments=100, compactness=10., max_iter=10, sigma=0,

labels = _slic_cython(image, segments, max_iter)

if labels.shape[0] == 1:
if is2d:
labels = labels[0]

return labels
2 changes: 2 additions & 0 deletions skimage/segmentation/tests/test_slic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_color_2d():

# we expect 4 segments
assert_equal(len(np.unique(seg)), 4)
assert_equal(seg.shape, img.shape[:-1])
assert_array_equal(seg[:10, :10], 0)
assert_array_equal(seg[10:, :10], 2)
assert_array_equal(seg[:10, 10:], 1)
Expand All @@ -39,6 +40,7 @@ def test_gray_2d():
multichannel=False, convert2lab=False)

assert_equal(len(np.unique(seg)), 4)
assert_equal(seg.shape, img.shape)
assert_array_equal(seg[:10, :10], 0)
assert_array_equal(seg[10:, :10], 2)
assert_array_equal(seg[:10, 10:], 1)
Expand Down

0 comments on commit 06e99a5

Please sign in to comment.