From 828c71b0d4333aa2c8538e503ca0a08dd28bca8e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 29 May 2023 20:59:15 +0200 Subject: [PATCH] Fix number of issues in `SpectralDistortionIndex` (#1808) --- CHANGELOG.md | 3 ++ src/torchmetrics/functional/image/d_lambda.py | 33 +++++++++++++++---- src/torchmetrics/functional/image/uqi.py | 2 +- tests/unittests/image/test_d_lambda.py | 14 ++++---- 4 files changed, 38 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f527ee3704..c804e7bcfb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -206,6 +206,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed lookup for punkt sources being downloaded in `RougeScore` ([#1789](https://github.com/Lightning-AI/torchmetrics/pull/1789)) +- Fixed several bugs in `SpectralDistortionIndex` metric ([#1808](https://github.com/Lightning-AI/torchmetrics/pull/1808)) + + - Fixed bug for corner cases in `MatthewsCorrCoef` ([#1812](https://github.com/Lightning-AI/torchmetrics/pull/1812)) diff --git a/src/torchmetrics/functional/image/d_lambda.py b/src/torchmetrics/functional/image/d_lambda.py index 9741a7e1784..f6856459566 100644 --- a/src/torchmetrics/functional/image/d_lambda.py +++ b/src/torchmetrics/functional/image/d_lambda.py @@ -19,7 +19,6 @@ from typing_extensions import Literal from torchmetrics.functional.image.uqi import universal_image_quality_index -from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.distributed import reduce @@ -34,11 +33,15 @@ def _spectral_distortion_index_update(preds: Tensor, target: Tensor) -> Tuple[Te raise TypeError( f"Expected `ms` and `fused` to have the same data type. Got ms: {preds.dtype} and fused: {target.dtype}." ) - _check_same_shape(preds, target) if len(preds.shape) != 4: raise ValueError( f"Expected `preds` and `target` to have BxCxHxW shape. Got preds: {preds.shape} and target: {target.shape}." ) + if preds.shape[:2] != target.shape[:2]: + raise ValueError( + "Expected `preds` and `target` to have same batch and channel sizes." + f"Got preds: {preds.shape} and target: {target.shape}." + ) return preds, target @@ -69,13 +72,29 @@ def _spectral_distortion_index_compute( tensor(0.0234) """ length = preds.shape[1] - m1 = torch.zeros((length, length)) - m2 = torch.zeros((length, length)) + + m1 = torch.zeros((length, length), device=preds.device) + m2 = torch.zeros((length, length), device=preds.device) for k in range(length): - for r in range(k, length): - m1[k, r] = m1[r, k] = universal_image_quality_index(target[:, k : k + 1, :, :], target[:, r : r + 1, :, :]) - m2[k, r] = m2[r, k] = universal_image_quality_index(preds[:, k : k + 1, :, :], preds[:, r : r + 1, :, :]) + num = length - (k + 1) + if num == 0: + continue + stack1 = target[:, k : k + 1, :, :].repeat(num, 1, 1, 1) + stack2 = torch.cat([target[:, r : r + 1, :, :] for r in range(k + 1, length)], dim=0) + score = [ + s.mean() for s in universal_image_quality_index(stack1, stack2, reduction="none").split(preds.shape[0]) + ] + m1[k, k + 1 :] = torch.stack(score, 0) + + stack1 = preds[:, k : k + 1, :, :].repeat(num, 1, 1, 1) + stack2 = torch.cat([preds[:, r : r + 1, :, :] for r in range(k + 1, length)], dim=0) + score = [ + s.mean() for s in universal_image_quality_index(stack1, stack2, reduction="none").split(preds.shape[0]) + ] + m2[k, k + 1 :] = torch.stack(score, 0) + m1 = m1 + m1.T + m2 = m2 + m2.T diff = torch.pow(torch.abs(m1 - m2), p) # Special case: when number of channels (L) is 1, there will be only one element in M1 and M2. Hence no need to sum. diff --git a/src/torchmetrics/functional/image/uqi.py b/src/torchmetrics/functional/image/uqi.py index b742bb08511..67f8750b3ef 100644 --- a/src/torchmetrics/functional/image/uqi.py +++ b/src/torchmetrics/functional/image/uqi.py @@ -112,7 +112,7 @@ def _uqi_compute( sigma_pred_target = output_list[4] - mu_pred_target upper = 2 * sigma_pred_target - lower = sigma_pred_sq + sigma_target_sq + lower = sigma_pred_sq + sigma_target_sq + torch.finfo(sigma_pred_sq.dtype).eps uqi_idx = ((2 * mu_pred_target) * upper) / ((mu_pred_sq + mu_target_sq) * lower) uqi_idx = uqi_idx[..., pad_h:-pad_h, pad_w:-pad_w] diff --git a/tests/unittests/image/test_d_lambda.py b/tests/unittests/image/test_d_lambda.py index c4a8af6f939..32a9c277b43 100644 --- a/tests/unittests/image/test_d_lambda.py +++ b/tests/unittests/image/test_d_lambda.py @@ -118,12 +118,6 @@ def test_d_lambda_functional(self, preds, target, p): metric_args={"p": p}, ) - # SpectralDistortionIndex half + cpu does not work due to missing support in torch.log - @pytest.mark.xfail(reason="Spectral Distortion Index metric does not support cpu + half precision") - def test_d_lambda_half_cpu(self, preds, target, p): - """Test dtype support of the metric on CPU.""" - self.run_precision_test_cpu(preds, target, SpectralDistortionIndex, spectral_distortion_index, {"p": p}) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") def test_d_lambda_half_gpu(self, preds, target, p): """Test dtype support of the metric on GPU.""" @@ -152,3 +146,11 @@ def test_d_lambda_invalid_type(): target_t = torch.rand((1, 1, 16, 16), dtype=torch.float64) with pytest.raises(TypeError, match="Expected `ms` and `fused` to have the same data type.*"): spectral_distortion_index(preds_t, target_t, p=1) + + +def test_d_lambda_different_sizes(): + """Since d lambda is reference free, it can accept different number of targets and preds.""" + preds = torch.rand(1, 1, 32, 32) + target = torch.rand(1, 1, 16, 16) + out = spectral_distortion_index(preds, target, p=1) + assert isinstance(out, torch.Tensor)