From db564f4761d1f7beda9d8c0a1d9880a98968e4c8 Mon Sep 17 00:00:00 2001 From: TanShaochang <30321432+petertheprocess@users.noreply.github.com> Date: Mon, 9 Sep 2024 08:58:45 +0200 Subject: [PATCH] correct the padding related calculation errors in SSIM (#2721) * correct the padding related calculation errors in SSIM * fix doctests * changelog * change test * add test * fix syntax in test_ssim.py * fix syntax in test_ssim.py * fix unittests * fix tolerance --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> (cherry picked from commit efdb111d924c37d8054be709e17cb28948fb19c8) --- CHANGELOG.md | 2 +- .../functional/image/_deprecated.py | 2 +- src/torchmetrics/functional/image/ssim.py | 24 ++++---- src/torchmetrics/image/_deprecated.py | 2 +- src/torchmetrics/image/ssim.py | 2 +- tests/unittests/image/test_ssim.py | 57 ++++++++++++------- 6 files changed, 54 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 41be5c7e62c..bea26412357 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721)) ## [1.4.1] - 2024-08-02 diff --git a/src/torchmetrics/functional/image/_deprecated.py b/src/torchmetrics/functional/image/_deprecated.py index fafc09cabaa..892d07afaa6 100644 --- a/src/torchmetrics/functional/image/_deprecated.py +++ b/src/torchmetrics/functional/image/_deprecated.py @@ -168,7 +168,7 @@ def _multiscale_structural_similarity_index_measure( >>> preds = rand([3, 3, 256, 256]) >>> target = preds * 0.75 >>> _multiscale_structural_similarity_index_measure(preds, target, data_range=1.0) - tensor(0.9627) + tensor(0.9628) """ _deprecated_root_import_func("multiscale_structural_similarity_index_measure", "image") diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index d89dd828a3e..c61ef833fe3 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -124,11 +124,15 @@ def _ssim_update( dtype = preds.dtype gauss_kernel_size = [int(3.5 * s + 0.5) * 2 + 1 for s in sigma] - pad_h = (gauss_kernel_size[0] - 1) // 2 - pad_w = (gauss_kernel_size[1] - 1) // 2 + if gaussian_kernel: + pad_h = (gauss_kernel_size[0] - 1) // 2 + pad_w = (gauss_kernel_size[1] - 1) // 2 + else: + pad_h = (kernel_size[0] - 1) // 2 + pad_w = (kernel_size[1] - 1) // 2 if is_3d: - pad_d = (gauss_kernel_size[2] - 1) // 2 + pad_d = (kernel_size[2] - 1) // 2 preds = _reflection_pad_3d(preds, pad_d, pad_w, pad_h) target = _reflection_pad_3d(target, pad_d, pad_w, pad_h) if gaussian_kernel: @@ -164,25 +168,21 @@ def _ssim_update( ssim_idx_full_image = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower) - if is_3d: - ssim_idx = ssim_idx_full_image[..., pad_h:-pad_h, pad_w:-pad_w, pad_d:-pad_d] - else: - ssim_idx = ssim_idx_full_image[..., pad_h:-pad_h, pad_w:-pad_w] - if return_contrast_sensitivity: contrast_sensitivity = upper / lower if is_3d: contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w, pad_d:-pad_d] else: contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w] - return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1), contrast_sensitivity.reshape( + + return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), contrast_sensitivity.reshape( contrast_sensitivity.shape[0], -1 ).mean(-1) if return_full_image: - return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1), ssim_idx_full_image + return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), ssim_idx_full_image - return ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1) + return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1) def _ssim_compute( @@ -507,7 +507,7 @@ def multiscale_structural_similarity_index_measure( >>> preds = rand([3, 3, 256, 256]) >>> target = preds * 0.75 >>> multiscale_structural_similarity_index_measure(preds, target, data_range=1.0) - tensor(0.9627) + tensor(0.9628) References: [1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C. diff --git a/src/torchmetrics/image/_deprecated.py b/src/torchmetrics/image/_deprecated.py index 86fb04c7f1c..8b382b89cf7 100644 --- a/src/torchmetrics/image/_deprecated.py +++ b/src/torchmetrics/image/_deprecated.py @@ -44,7 +44,7 @@ class _MultiScaleStructuralSimilarityIndexMeasure(MultiScaleStructuralSimilarity >>> target = preds * 0.75 >>> ms_ssim = _MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) >>> ms_ssim(preds, target) - tensor(0.9627) + tensor(0.9628) """ diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index f0c03a163d6..648f9c26029 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -274,7 +274,7 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric): >>> target = preds * 0.75 >>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) >>> ms_ssim(preds, target) - tensor(0.9627) + tensor(0.9628) """ diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 6b464f2a97b..49954f45cd7 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -278,7 +278,16 @@ def test_ssim_invalid_inputs(pred, target, kernel, sigma, match): structural_similarity_index_measure(pred, target, kernel_size=kernel, sigma=sigma) -def test_ssim_unequal_kernel_size(): +@pytest.mark.parametrize( + ("sigma", "kernel_size", "result"), + [ + ((0.25, 0.5), None, torch.tensor(0.20977394)), + ((0.5, 0.25), None, torch.tensor(0.13884821)), + (None, (3, 5), torch.tensor(0.05032664)), + (None, (5, 3), torch.tensor(0.03472072)), + ], +) +def test_ssim_unequal_kernel_size(sigma, kernel_size, result): """Test the case where kernel_size[0] != kernel_size[1].""" preds = torch.tensor([ [ @@ -306,24 +315,18 @@ def test_ssim_unequal_kernel_size(): ] ] ]) - # kernel order matters - assert torch.isclose( - structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=(0.25, 0.5)), - torch.tensor(0.08869550), - ) - assert not torch.isclose( - structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=(0.5, 0.25)), - torch.tensor(0.08869550), - ) - - assert torch.isclose( - structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=(3, 5)), - torch.tensor(0.05131844), - ) - assert not torch.isclose( - structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=(5, 3)), - torch.tensor(0.05131844), - ) + if sigma is not None: + assert torch.isclose( + structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=sigma), + result, + atol=1e-04, + ) + else: + assert torch.isclose( + structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=kernel_size), + result, + atol=1e-04, + ) @pytest.mark.parametrize( @@ -341,3 +344,19 @@ def test_full_image_output(preds, target): assert len(out) == 2 assert out[0].numel() == 1 assert out[1].shape == preds[0].shape + + +def test_ssim_for_correct_padding(): + """Check that padding is correctly added and removed for SSIM. + + See issue: https://github.com/Lightning-AI/torchmetrics/issues/2718 + + """ + preds = torch.rand([3, 3, 256, 256]) + # let the edge of the image be 0 + target = preds.clone() + target[:, :, 0, :] = 0 + target[:, :, -1, :] = 0 + target[:, :, :, 0] = 0 + target[:, :, :, -1] = 0 + assert structural_similarity_index_measure(preds, target) < 1.0