From 0fd1f963c5682ca59047ceb8ebb67eead7d86f87 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Fri, 18 Oct 2024 15:55:13 +0200 Subject: [PATCH] Fix overflow for `PSNR` metric when used with `uint8` input (#2788) --- CHANGELOG.md | 3 +++ src/torchmetrics/functional/image/psnr.py | 5 +++++ tests/unittests/image/test_psnr.py | 13 +++++++++++++ 3 files changed, 21 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b4033b60d9..3569aead873 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed corner case in `Iou` metric for single empty prediction tensors ([#2780](https://github.com/Lightning-AI/torchmetrics/pull/2780)) +- Fixed `PSNR` calculation for integer type input images ([#2788](https://github.com/Lightning-AI/torchmetrics/pull/2788)) + + ## [1.4.3] - 2024-10-10 ### Fixed diff --git a/src/torchmetrics/functional/image/psnr.py b/src/torchmetrics/functional/image/psnr.py index 8f6a3f20dba..adb80e2bc77 100644 --- a/src/torchmetrics/functional/image/psnr.py +++ b/src/torchmetrics/functional/image/psnr.py @@ -69,6 +69,11 @@ def _psnr_update( Default is None meaning scores will be reduced across all dimensions. """ + if not preds.is_floating_point(): + preds = preds.to(torch.float32) + if not target.is_floating_point(): + target = target.to(torch.float32) + if dim is None: sum_squared_error = torch.sum(torch.pow(preds - target, 2)) num_obs = tensor(target.numel(), device=target.device) diff --git a/tests/unittests/image/test_psnr.py b/tests/unittests/image/test_psnr.py index 66724af1f4c..cdb0e58aac4 100644 --- a/tests/unittests/image/test_psnr.py +++ b/tests/unittests/image/test_psnr.py @@ -167,3 +167,16 @@ def test_missing_data_range(): with pytest.raises(ValueError, match="The `data_range` must be given when `dim` is not None."): peak_signal_noise_ratio(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0) + + +def test_psnr_uint_dtype(): + """Check that automatic casting to float is done for uint dtype. + + See issue: https://github.com/Lightning-AI/torchmetrics/issues/2787 + + """ + preds = torch.randint(0, 255, _input_size, dtype=torch.uint8) + target = torch.randint(0, 255, _input_size, dtype=torch.uint8) + psnr = peak_signal_noise_ratio(preds, target) + prnr2 = peak_signal_noise_ratio(preds.float(), target.float()) + assert torch.allclose(psnr, prnr2)