Skip to content

Commit

Permalink
Fix input validation for Spectral Distortion Index
Browse files Browse the repository at this point in the history
  • Loading branch information
rddunphy committed Feb 16, 2023
1 parent c4e19f7 commit ffcea15
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/torchmetrics/functional/image/d_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,14 @@ def _spectral_distortion_index_update(preds: Tensor, target: Tensor) -> Tuple[Te
raise TypeError(
f"Expected `ms` and `fused` to have the same data type. Got ms: {preds.dtype} and fused: {target.dtype}."
)
_check_same_shape(preds, target)
if len(preds.shape) != 4:
raise ValueError(
f"Expected `preds` and `target` to have BxCxHxW shape. Got preds: {preds.shape} and target: {target.shape}."
)
if preds.shape[:2] != target.shape[:2]:
raise ValueError(
f"Expected `preds` and `target` to have same batch and channel sizes. Got preds: {preds.shape} and target: {target.shape}."
)
return preds, target


Expand Down Expand Up @@ -70,11 +73,11 @@ def _spectral_distortion_index_compute(
tensor(0.0234)
"""
length = preds.shape[1]
m1 = torch.zeros((length, length))
m2 = torch.zeros((length, length))
m1 = torch.ones((length, length))
m2 = torch.ones((length, length))

for k in range(length):
for r in range(k, length):
for r in range(k+1, length):
m1[k, r] = m1[r, k] = universal_image_quality_index(target[:, k : k + 1, :, :], target[:, r : r + 1, :, :])
m2[k, r] = m2[r, k] = universal_image_quality_index(preds[:, k : k + 1, :, :], preds[:, r : r + 1, :, :])

Expand Down

0 comments on commit ffcea15

Please sign in to comment.