Skip to content

Commit

Permalink
Fix number of issues in SpectralDistortionIndex (#1808)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored May 29, 2023
1 parent f3c6c64 commit 828c71b
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 14 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed lookup for punkt sources being downloaded in `RougeScore` ([#1789](https://github.com/Lightning-AI/torchmetrics/pull/1789))


- Fixed several bugs in `SpectralDistortionIndex` metric ([#1808](https://github.com/Lightning-AI/torchmetrics/pull/1808))


- Fixed bug for corner cases in `MatthewsCorrCoef` ([#1812](https://github.com/Lightning-AI/torchmetrics/pull/1812))


Expand Down
33 changes: 26 additions & 7 deletions src/torchmetrics/functional/image/d_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from typing_extensions import Literal

from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.distributed import reduce


Expand All @@ -34,11 +33,15 @@ def _spectral_distortion_index_update(preds: Tensor, target: Tensor) -> Tuple[Te
raise TypeError(
f"Expected `ms` and `fused` to have the same data type. Got ms: {preds.dtype} and fused: {target.dtype}."
)
_check_same_shape(preds, target)
if len(preds.shape) != 4:
raise ValueError(
f"Expected `preds` and `target` to have BxCxHxW shape. Got preds: {preds.shape} and target: {target.shape}."
)
if preds.shape[:2] != target.shape[:2]:
raise ValueError(
"Expected `preds` and `target` to have same batch and channel sizes."
f"Got preds: {preds.shape} and target: {target.shape}."
)
return preds, target


Expand Down Expand Up @@ -69,13 +72,29 @@ def _spectral_distortion_index_compute(
tensor(0.0234)
"""
length = preds.shape[1]
m1 = torch.zeros((length, length))
m2 = torch.zeros((length, length))

m1 = torch.zeros((length, length), device=preds.device)
m2 = torch.zeros((length, length), device=preds.device)

for k in range(length):
for r in range(k, length):
m1[k, r] = m1[r, k] = universal_image_quality_index(target[:, k : k + 1, :, :], target[:, r : r + 1, :, :])
m2[k, r] = m2[r, k] = universal_image_quality_index(preds[:, k : k + 1, :, :], preds[:, r : r + 1, :, :])
num = length - (k + 1)
if num == 0:
continue
stack1 = target[:, k : k + 1, :, :].repeat(num, 1, 1, 1)
stack2 = torch.cat([target[:, r : r + 1, :, :] for r in range(k + 1, length)], dim=0)
score = [
s.mean() for s in universal_image_quality_index(stack1, stack2, reduction="none").split(preds.shape[0])
]
m1[k, k + 1 :] = torch.stack(score, 0)

stack1 = preds[:, k : k + 1, :, :].repeat(num, 1, 1, 1)
stack2 = torch.cat([preds[:, r : r + 1, :, :] for r in range(k + 1, length)], dim=0)
score = [
s.mean() for s in universal_image_quality_index(stack1, stack2, reduction="none").split(preds.shape[0])
]
m2[k, k + 1 :] = torch.stack(score, 0)
m1 = m1 + m1.T
m2 = m2 + m2.T

diff = torch.pow(torch.abs(m1 - m2), p)
# Special case: when number of channels (L) is 1, there will be only one element in M1 and M2. Hence no need to sum.
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/uqi.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _uqi_compute(
sigma_pred_target = output_list[4] - mu_pred_target

upper = 2 * sigma_pred_target
lower = sigma_pred_sq + sigma_target_sq
lower = sigma_pred_sq + sigma_target_sq + torch.finfo(sigma_pred_sq.dtype).eps

uqi_idx = ((2 * mu_pred_target) * upper) / ((mu_pred_sq + mu_target_sq) * lower)
uqi_idx = uqi_idx[..., pad_h:-pad_h, pad_w:-pad_w]
Expand Down
14 changes: 8 additions & 6 deletions tests/unittests/image/test_d_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,6 @@ def test_d_lambda_functional(self, preds, target, p):
metric_args={"p": p},
)

# SpectralDistortionIndex half + cpu does not work due to missing support in torch.log
@pytest.mark.xfail(reason="Spectral Distortion Index metric does not support cpu + half precision")
def test_d_lambda_half_cpu(self, preds, target, p):
"""Test dtype support of the metric on CPU."""
self.run_precision_test_cpu(preds, target, SpectralDistortionIndex, spectral_distortion_index, {"p": p})

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_d_lambda_half_gpu(self, preds, target, p):
"""Test dtype support of the metric on GPU."""
Expand Down Expand Up @@ -152,3 +146,11 @@ def test_d_lambda_invalid_type():
target_t = torch.rand((1, 1, 16, 16), dtype=torch.float64)
with pytest.raises(TypeError, match="Expected `ms` and `fused` to have the same data type.*"):
spectral_distortion_index(preds_t, target_t, p=1)


def test_d_lambda_different_sizes():
"""Since d lambda is reference free, it can accept different number of targets and preds."""
preds = torch.rand(1, 1, 32, 32)
target = torch.rand(1, 1, 16, 16)
out = spectral_distortion_index(preds, target, p=1)
assert isinstance(out, torch.Tensor)

0 comments on commit 828c71b

Please sign in to comment.