Commit 7c8d801
Refactor PSNR and SSIM (#2797)
* Refactor PSNR
* Parametrize on available devices
* Refactor ssim
* Refactor SSIM
* Fix a redundant

Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent e8e42c2 commit 7c8d801

File tree

5 files changed: +135 -307 lines changed


ignite/metrics/psnr.py

Lines changed: 2 additions & 2 deletions

@@ -118,7 +118,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
         self._num_examples += y.shape[0]
 
     @sync_all_reduce("_sum_of_batchwise_psnr", "_num_examples")
-    def compute(self) -> torch.Tensor:
+    def compute(self) -> float:
         if self._num_examples == 0:
             raise NotComputableError("PSNR must have at least one example before it can be computed.")
-        return self._sum_of_batchwise_psnr / self._num_examples
+        return (self._sum_of_batchwise_psnr / self._num_examples).item()
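With this change, PSNR.compute() collapses the accumulated tensor to a plain Python float via .item() instead of returning a torch.Tensor. A minimal usage sketch of the new behaviour; the tensors, shapes, and data_range below are illustrative and not taken from the commit:

import torch
from ignite.metrics import PSNR

# data_range=1.0 assumes float images scaled to [0, 1]
psnr = PSNR(data_range=1.0)
y_pred = torch.rand(4, 3, 16, 16)
y = y_pred * 0.9
psnr.update((y_pred, y))
value = psnr.compute()  # now a Python float rather than a 0-dim tensor
assert isinstance(value, float)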

ignite/metrics/ssim.py

Lines changed: 2 additions & 2 deletions

@@ -180,7 +180,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
         self._num_examples += y.shape[0]
 
     @sync_all_reduce("_sum_of_ssim", "_num_examples")
-    def compute(self) -> torch.Tensor:
+    def compute(self) -> float:
         if self._num_examples == 0:
             raise NotComputableError("SSIM must have at least one example before it can be computed.")
-        return self._sum_of_ssim / self._num_examples
+        return (self._sum_of_ssim / self._num_examples).item()
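SSIM.compute() gets the same treatment and now returns the batch-averaged SSIM as a Python float. A sketch under the same illustrative assumptions (random 4D float tensors in [0, 1]):

import torch
from ignite.metrics import SSIM

ssim = SSIM(data_range=1.0)
y_pred = torch.rand(4, 3, 32, 32)
y = y_pred * 0.9
ssim.update((y_pred, y))
assert isinstance(ssim.compute(), float)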

tests/ignite/conftest.py

Lines changed: 10 additions & 0 deletions

@@ -13,6 +13,16 @@
 import ignite.distributed as idist
 
 
+@pytest.fixture(
+    params=[
+        "cpu",
+        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no CUDA support")),
+    ]
+)
+def available_device(request):
+    return request.param
+
+
 @pytest.fixture()
 def dirname():
     path = Path(tempfile.mkdtemp())
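Any test that accepts an available_device argument is now collected once per parameter: it always runs on "cpu", and additionally on "cuda" when a GPU is present (otherwise the CUDA case is reported as skipped). A hypothetical test illustrating the pattern; it is not part of this commit:

import torch

def test_runs_on_each_available_device(available_device):
    # the fixture yields a device string, either "cpu" or "cuda"
    x = torch.ones(3, device=available_device)
    assert x.device.type == torch.device(available_device).type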

tests/ignite/metrics/test_psnr.py

Lines changed: 52 additions & 155 deletions

@@ -1,5 +1,3 @@
-import os
-
 import numpy as np
 import pytest
 import torch
@@ -11,8 +9,6 @@
 from ignite.metrics import PSNR
 from ignite.utils import manual_seed
 
-from tests.ignite import cpu_and_maybe_cuda
-
 
 def test_zero_div():
     psnr = PSNR(1.0)
@@ -32,9 +28,32 @@ def test_invalid_psnr():
         psnr.update((y_pred, y.squeeze(dim=0)))
 
 
-def _test_psnr(y_pred, y, data_range, device):
-    psnr = PSNR(data_range=data_range, device=device)
-    psnr.update((y_pred, y))
+@pytest.fixture(params=["float", "YCbCr", "uint8", "NHW shape"])
+def test_data(request, available_device):
+    manual_seed(42)
+    if request.param == "float":
+        y_pred = torch.rand(8, 3, 28, 28, device=available_device)
+        y = y_pred * 0.8
+    elif request.param == "YCbCr":
+        y_pred = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=available_device)
+        y = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=available_device)
+    elif request.param == "uint8":
+        y_pred = torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8, device=available_device)
+        y = (y_pred * 0.8).to(torch.uint8)
+    elif request.param == "NHW shape":
+        y_pred = torch.rand(8, 28, 28, device=available_device)
+        y = y_pred * 0.8
+    else:
+        raise ValueError(f"Wrong fixture parameter, given {request.param}")
+    return (y_pred, y)
+
+
+def test_psnr(test_data, available_device):
+    y_pred, y = test_data
+    data_range = (y.max() - y.min()).cpu().item()
+
+    psnr = PSNR(data_range=data_range, device=available_device)
+    psnr.update(test_data)
     psnr_compute = psnr.compute()
 
     np_y_pred = y_pred.cpu().numpy()
@@ -43,43 +62,9 @@ def _test_psnr(y_pred, y, data_range, device):
     for np_y_pred_, np_y_ in zip(np_y_pred, np_y):
         np_psnr += ski_psnr(np_y_, np_y_pred_, data_range=data_range)
 
-    assert torch.gt(psnr_compute, 0.0)
-    assert isinstance(psnr_compute, torch.Tensor)
-    assert psnr_compute.dtype == torch.float64
-    assert psnr_compute.device.type == torch.device(device).type
-    assert np.allclose(psnr_compute.cpu().numpy(), np_psnr / np_y.shape[0])
-
-
-@pytest.mark.parametrize("device", cpu_and_maybe_cuda())
-def test_psnr(device):
-
-    # test for float
-    manual_seed(42)
-    y_pred = torch.rand(8, 3, 28, 28, device=device)
-    y = y_pred * 0.8
-    data_range = (y.max() - y.min()).cpu().item()
-    _test_psnr(y_pred, y, data_range, device)
-
-    # test for YCbCr
-    manual_seed(42)
-    y_pred = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=device)
-    y = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=device)
-    data_range = (y.max() - y.min()).cpu().item()
-    _test_psnr(y_pred, y, data_range, device)
-
-    # test for uint8
-    manual_seed(42)
-    y_pred = torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8, device=device)
-    y = (y_pred * 0.8).to(torch.uint8)
-    data_range = (y.max() - y.min()).cpu().item()
-    _test_psnr(y_pred, y, data_range, device)
-
-    # test with NHW shape
-    manual_seed(42)
-    y_pred = torch.rand(8, 28, 28, device=device)
-    y = y_pred * 0.8
-    data_range = (y.max() - y.min()).cpu().item()
-    _test_psnr(y_pred, y, data_range, device)
+    assert psnr_compute > 0.0
+    assert isinstance(psnr_compute, float)
+    assert np.allclose(psnr_compute, np_psnr / np_y.shape[0])
 
 
 def _test(
@@ -109,9 +94,9 @@ def update(engine, i):
     y = idist.all_gather(y)
    y_pred = idist.all_gather(y_pred)
 
+    assert "psnr" in engine.state.metrics
     result = engine.state.metrics["psnr"]
     assert result > 0.0
-    assert "psnr" in engine.state.metrics
 
     if compute_y_channel:
         np_y_pred = y_pred[:, 0, ...].cpu().numpy()
@@ -127,7 +112,9 @@ def update(engine, i):
     assert np.allclose(result, np_psnr / np_y.shape[0], atol=atol)
 
 
-def _test_distrib_input_float(device, atol=1e-8):
+def test_distrib_input_float(distributed):
+    device = idist.device()
+
     def get_test_cases():
 
         y_pred = torch.rand(n_iters * batch_size, 2, 2, device=device)
@@ -143,12 +130,14 @@ def get_test_cases():
         # check multiple random inputs as random exact occurencies are rare
         torch.manual_seed(42 + rank + i)
         y_pred, y = get_test_cases()
-        _test(y_pred, y, 1, "cpu", n_iters, batch_size, atol=atol)
+        _test(y_pred, y, 1, "cpu", n_iters, batch_size, atol=1e-8)
         if device.type != "xla":
-            _test(y_pred, y, 1, idist.device(), n_iters, batch_size, atol=atol)
+            _test(y_pred, y, 1, idist.device(), n_iters, batch_size, atol=1e-8)
+
 
+def test_distrib_multilabel_input_YCbCr(distributed):
+    device = idist.device()
 
-def _test_distrib_multilabel_input_YCbCr(device, atol=1e-8):
     def get_test_cases():
 
         y_pred = torch.randint(16, 236, (n_iters * batch_size, 1, 12, 12), dtype=torch.uint8, device=device)
@@ -171,13 +160,15 @@ def out_fn(x):
         # check multiple random inputs as random exact occurencies are rare
         torch.manual_seed(42 + rank + i)
         y_pred, y = get_test_cases()
-        _test(y_pred, y, 220, "cpu", n_iters, batch_size, atol, output_transform=out_fn, compute_y_channel=True)
+        _test(y_pred, y, 220, "cpu", n_iters, batch_size, atol=1e-8, output_transform=out_fn, compute_y_channel=True)
         if device.type != "xla":
             dev = idist.device()
-            _test(y_pred, y, 220, dev, n_iters, batch_size, atol, output_transform=out_fn, compute_y_channel=True)
+            _test(y_pred, y, 220, dev, n_iters, batch_size, atol=1e-8, output_transform=out_fn, compute_y_channel=True)
 
 
-def _test_distrib_multilabel_input_uint8(device, atol=1e-8):
+def test_distrib_multilabel_input_uint8(distributed):
+    device = idist.device()
+
     def get_test_cases():
 
         y_pred = torch.randint(0, 256, (n_iters * batch_size, 3, 16, 16), device=device, dtype=torch.uint8)
@@ -193,12 +184,14 @@ def get_test_cases():
         # check multiple random inputs as random exact occurencies are rare
         torch.manual_seed(42 + rank + i)
         y_pred, y = get_test_cases()
-        _test(y_pred, y, 100, "cpu", n_iters, batch_size, atol)
+        _test(y_pred, y, 100, "cpu", n_iters, batch_size, atol=1e-8)
         if device.type != "xla":
-            _test(y_pred, y, 100, idist.device(), n_iters, batch_size, atol)
+            _test(y_pred, y, 100, idist.device(), n_iters, batch_size, atol=1e-8)
 
 
-def _test_distrib_multilabel_input_NHW(device, atol=1e-8):
+def test_distrib_multilabel_input_NHW(distributed):
+    device = idist.device()
+
     def get_test_cases():
 
         y_pred = torch.rand(n_iters * batch_size, 28, 28, device=device)
@@ -214,13 +207,13 @@ def get_test_cases():
         # check multiple random inputs as random exact occurencies are rare
        torch.manual_seed(42 + rank + i)
         y_pred, y = get_test_cases()
-        _test(y_pred, y, 10, "cpu", n_iters, batch_size, atol)
+        _test(y_pred, y, 10, "cpu", n_iters, batch_size, atol=1e-8)
         if device.type != "xla":
-            _test(y_pred, y, 10, idist.device(), n_iters, batch_size, atol)
+            _test(y_pred, y, 10, idist.device(), n_iters, batch_size, atol=1e-8)
 
 
-def _test_distrib_accumulator_device(device):
-
+def test_distrib_accumulator_device(distributed):
+    device = idist.device()
     metric_devices = [torch.device("cpu")]
     if torch.device(device).type != "xla":
         metric_devices.append(idist.device())
@@ -235,99 +228,3 @@ def _test_distrib_accumulator_device(device):
         psnr.update((y_pred, y))
         dev = psnr._sum_of_batchwise_psnr.device
         assert dev == metric_device, f"{dev} vs {metric_device}"
-
-
-@pytest.mark.distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
-
-    device = idist.device()
-    _test_distrib_input_float(device)
-    _test_distrib_multilabel_input_YCbCr(device)
-    _test_distrib_multilabel_input_uint8(device)
-    _test_distrib_multilabel_input_NHW(device)
-    _test_distrib_accumulator_device(device)
-
-
-@pytest.mark.distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
-
-    device = idist.device()
-    _test_distrib_input_float(device)
-    _test_distrib_multilabel_input_YCbCr(device)
-    _test_distrib_multilabel_input_uint8(device)
-    _test_distrib_multilabel_input_NHW(device)
-    _test_distrib_accumulator_device(device)
-
-
-@pytest.mark.multinode_distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
-def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
-
-    device = idist.device()
-    _test_distrib_input_float(device)
-    _test_distrib_multilabel_input_YCbCr(device)
-    _test_distrib_multilabel_input_uint8(device)
-    _test_distrib_multilabel_input_NHW(device)
-    _test_distrib_accumulator_device(device)
-
-
-@pytest.mark.multinode_distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
-def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
-
-    device = idist.device()
-    _test_distrib_input_float(device)
-    _test_distrib_multilabel_input_YCbCr(device)
-    _test_distrib_multilabel_input_uint8(device)
-    _test_distrib_multilabel_input_NHW(device)
-    _test_distrib_accumulator_device(device)
-
-
-@pytest.mark.tpu
-@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
-@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
-def test_distrib_single_device_xla():
-
-    device = idist.device()
-    _test_distrib_input_float(device)
-    _test_distrib_multilabel_input_YCbCr(device)
-    _test_distrib_multilabel_input_uint8(device)
-    _test_distrib_multilabel_input_NHW(device)
-    _test_distrib_accumulator_device(device)
-
-
-def _test_distrib_xla_nprocs(index):
-    device = idist.device()
-    _test_distrib_input_float(device)
-    _test_distrib_multilabel_input_YCbCr(device)
-    _test_distrib_multilabel_input_uint8(device)
-    _test_distrib_multilabel_input_NHW(device)
-    _test_distrib_accumulator_device(device)
-
-
-@pytest.mark.tpu
-@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
-@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
-def test_distrib_xla_nprocs(xmp_executor):
-    n = int(os.environ["NUM_TPU_WORKERS"])
-    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
-
-
-@pytest.mark.distributed
-@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
-@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
-def test_distrib_hvd(gloo_hvd_executor):
-
-    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
-    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
-
-    gloo_hvd_executor(_test_distrib_input_float, (device,), np=nproc, do_init=True)
-    gloo_hvd_executor(_test_distrib_multilabel_input_YCbCr, (device,), np=nproc, do_init=True)
-    gloo_hvd_executor(_test_distrib_multilabel_input_uint8, (device,), np=nproc, do_init=True)
-    gloo_hvd_executor(_test_distrib_multilabel_input_NHW, (device,), np=nproc, do_init=True)
-    gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True)
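A note on the removed block above: the old file paired private _test_distrib_* helpers with one launcher test per backend (single-node gloo and nccl, multi-node gloo and nccl, Horovod, and single-device or multi-process XLA). After the refactor each helper is itself a test_distrib_* test that accepts a distributed fixture and reads the target device from idist.device(), so backend selection and process spawning are expected to be handled by that shared fixture in the test suite's conftest rather than by per-backend wrappers repeated in every metric test file.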
