Skip to content

Commit

Permalink
correct the padding related calculation errors in SSIM (#2721)
Browse files Browse the repository at this point in the history
* correct the padding related calculation errors in SSIM

* fix doctests

* changelog

* change test

* add test

* fix syntax in test_ssim.py

* fix syntax in test_ssim.py

* fix unittests

* fix tolerance

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
(cherry picked from commit efdb111)
  • Loading branch information
petertheprocess authored and Borda committed Sep 13, 2024
1 parent af5720e commit db564f4
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 35 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721))


## [1.4.1] - 2024-08-02
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _multiscale_structural_similarity_index_measure(
>>> preds = rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> _multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
tensor(0.9627)
tensor(0.9628)
"""
_deprecated_root_import_func("multiscale_structural_similarity_index_measure", "image")
Expand Down
24 changes: 12 additions & 12 deletions src/torchmetrics/functional/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,15 @@ def _ssim_update(
dtype = preds.dtype
gauss_kernel_size = [int(3.5 * s + 0.5) * 2 + 1 for s in sigma]

pad_h = (gauss_kernel_size[0] - 1) // 2
pad_w = (gauss_kernel_size[1] - 1) // 2
if gaussian_kernel:
pad_h = (gauss_kernel_size[0] - 1) // 2
pad_w = (gauss_kernel_size[1] - 1) // 2
else:
pad_h = (kernel_size[0] - 1) // 2
pad_w = (kernel_size[1] - 1) // 2

if is_3d:
pad_d = (gauss_kernel_size[2] - 1) // 2
pad_d = (kernel_size[2] - 1) // 2
preds = _reflection_pad_3d(preds, pad_d, pad_w, pad_h)
target = _reflection_pad_3d(target, pad_d, pad_w, pad_h)
if gaussian_kernel:
Expand Down Expand Up @@ -164,25 +168,21 @@ def _ssim_update(

ssim_idx_full_image = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower)

if is_3d:
ssim_idx = ssim_idx_full_image[..., pad_h:-pad_h, pad_w:-pad_w, pad_d:-pad_d]
else:
ssim_idx = ssim_idx_full_image[..., pad_h:-pad_h, pad_w:-pad_w]

if return_contrast_sensitivity:
contrast_sensitivity = upper / lower
if is_3d:
contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w, pad_d:-pad_d]
else:
contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w]
return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1), contrast_sensitivity.reshape(

return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), contrast_sensitivity.reshape(
contrast_sensitivity.shape[0], -1
).mean(-1)

if return_full_image:
return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1), ssim_idx_full_image
return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), ssim_idx_full_image

return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1)
return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1)


def _ssim_compute(
Expand Down Expand Up @@ -507,7 +507,7 @@ def multiscale_structural_similarity_index_measure(
>>> preds = rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
tensor(0.9627)
tensor(0.9628)
References:
[1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C.
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/image/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class _MultiScaleStructuralSimilarityIndexMeasure(MultiScaleStructuralSimilarity
>>> target = preds * 0.75
>>> ms_ssim = _MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> ms_ssim(preds, target)
tensor(0.9627)
tensor(0.9628)
"""

Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric):
>>> target = preds * 0.75
>>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> ms_ssim(preds, target)
tensor(0.9627)
tensor(0.9628)
"""

Expand Down
57 changes: 38 additions & 19 deletions tests/unittests/image/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,16 @@ def test_ssim_invalid_inputs(pred, target, kernel, sigma, match):
structural_similarity_index_measure(pred, target, kernel_size=kernel, sigma=sigma)


def test_ssim_unequal_kernel_size():
@pytest.mark.parametrize(
("sigma", "kernel_size", "result"),
[
((0.25, 0.5), None, torch.tensor(0.20977394)),
((0.5, 0.25), None, torch.tensor(0.13884821)),
(None, (3, 5), torch.tensor(0.05032664)),
(None, (5, 3), torch.tensor(0.03472072)),
],
)
def test_ssim_unequal_kernel_size(sigma, kernel_size, result):
"""Test the case where kernel_size[0] != kernel_size[1]."""
preds = torch.tensor([
[
Expand Down Expand Up @@ -306,24 +315,18 @@ def test_ssim_unequal_kernel_size():
]
]
])
# kernel order matters
assert torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=(0.25, 0.5)),
torch.tensor(0.08869550),
)
assert not torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=(0.5, 0.25)),
torch.tensor(0.08869550),
)

assert torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=(3, 5)),
torch.tensor(0.05131844),
)
assert not torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=(5, 3)),
torch.tensor(0.05131844),
)
if sigma is not None:
assert torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=sigma),
result,
atol=1e-04,
)
else:
assert torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=kernel_size),
result,
atol=1e-04,
)


@pytest.mark.parametrize(
Expand All @@ -341,3 +344,19 @@ def test_full_image_output(preds, target):
assert len(out) == 2
assert out[0].numel() == 1
assert out[1].shape == preds[0].shape


def test_ssim_for_correct_padding():
"""Check that padding is correctly added and removed for SSIM.
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2718
"""
preds = torch.rand([3, 3, 256, 256])
# let the edge of the image be 0
target = preds.clone()
target[:, :, 0, :] = 0
target[:, :, -1, :] = 0
target[:, :, :, 0] = 0
target[:, :, :, -1] = 0
assert structural_similarity_index_measure(preds, target) < 1.0

0 comments on commit db564f4

Please sign in to comment.