- import os
-
import numpy as np
import pytest
import torch
from ignite.metrics import PSNR
from ignite.utils import manual_seed

- from tests.ignite import cpu_and_maybe_cuda
-


def test_zero_div():
    psnr = PSNR(1.0)
@@ -32,9 +28,32 @@ def test_invalid_psnr():
    psnr.update((y_pred, y.squeeze(dim=0)))


- def _test_psnr(y_pred, y, data_range, device):
-     psnr = PSNR(data_range=data_range, device=device)
-     psnr.update((y_pred, y))
+ @pytest.fixture(params=["float", "YCbCr", "uint8", "NHW shape"])
+ def test_data(request, available_device):
+     manual_seed(42)
+     if request.param == "float":
+         y_pred = torch.rand(8, 3, 28, 28, device=available_device)
+         y = y_pred * 0.8
+     elif request.param == "YCbCr":
+         y_pred = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=available_device)
+         y = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=available_device)
+     elif request.param == "uint8":
+         y_pred = torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8, device=available_device)
+         y = (y_pred * 0.8).to(torch.uint8)
+     elif request.param == "NHW shape":
+         y_pred = torch.rand(8, 28, 28, device=available_device)
+         y = y_pred * 0.8
+     else:
+         raise ValueError(f"Wrong fixture parameter, given {request.param}")
+     return (y_pred, y)
+
+
+ def test_psnr(test_data, available_device):
+     y_pred, y = test_data
+     data_range = (y.max() - y.min()).cpu().item()
+
+     psnr = PSNR(data_range=data_range, device=available_device)
+     psnr.update(test_data)
    psnr_compute = psnr.compute()

    np_y_pred = y_pred.cpu().numpy()
@@ -43,43 +62,9 @@ def _test_psnr(y_pred, y, data_range, device):
    for np_y_pred_, np_y_ in zip(np_y_pred, np_y):
        np_psnr += ski_psnr(np_y_, np_y_pred_, data_range=data_range)

-     assert torch.gt(psnr_compute, 0.0)
-     assert isinstance(psnr_compute, torch.Tensor)
-     assert psnr_compute.dtype == torch.float64
-     assert psnr_compute.device.type == torch.device(device).type
-     assert np.allclose(psnr_compute.cpu().numpy(), np_psnr / np_y.shape[0])
-
-
- @pytest.mark.parametrize("device", cpu_and_maybe_cuda())
- def test_psnr(device):
-
-     # test for float
-     manual_seed(42)
-     y_pred = torch.rand(8, 3, 28, 28, device=device)
-     y = y_pred * 0.8
-     data_range = (y.max() - y.min()).cpu().item()
-     _test_psnr(y_pred, y, data_range, device)
-
-     # test for YCbCr
-     manual_seed(42)
-     y_pred = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=device)
-     y = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=device)
-     data_range = (y.max() - y.min()).cpu().item()
-     _test_psnr(y_pred, y, data_range, device)
-
-     # test for uint8
-     manual_seed(42)
-     y_pred = torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8, device=device)
-     y = (y_pred * 0.8).to(torch.uint8)
-     data_range = (y.max() - y.min()).cpu().item()
-     _test_psnr(y_pred, y, data_range, device)
-
-     # test with NHW shape
-     manual_seed(42)
-     y_pred = torch.rand(8, 28, 28, device=device)
-     y = y_pred * 0.8
-     data_range = (y.max() - y.min()).cpu().item()
-     _test_psnr(y_pred, y, data_range, device)
+     assert psnr_compute > 0.0
+     assert isinstance(psnr_compute, float)
+     assert np.allclose(psnr_compute, np_psnr / np_y.shape[0])
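Note: the new `test_data` fixture and `test_psnr` test rely on an `available_device` fixture that is not defined in this diff; it is presumably provided by the suite's shared conftest. A minimal sketch of such a fixture, written here only to make that assumption explicit, might look like:

```python
# Hypothetical conftest.py fixture; the real definition lives outside this diff.
import pytest
import torch


@pytest.fixture()
def available_device():
    # Prefer CUDA when a GPU is visible; otherwise fall back to CPU,
    # so the same parametrized tests run in both environments.
    return "cuda" if torch.cuda.is_available() else "cpu"
```

Returning the device as a string (or a `torch.device`) is enough here, since it is only passed through to `torch.rand`/`torch.randint` and to the `PSNR(device=...)` constructor.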


def _test(
@@ -109,9 +94,9 @@ def update(engine, i):
    y = idist.all_gather(y)
    y_pred = idist.all_gather(y_pred)

+     assert "psnr" in engine.state.metrics
    result = engine.state.metrics["psnr"]
    assert result > 0.0
-     assert "psnr" in engine.state.metrics

    if compute_y_channel:
        np_y_pred = y_pred[:, 0, ...].cpu().numpy()
@@ -127,7 +112,9 @@ def update(engine, i):
    assert np.allclose(result, np_psnr / np_y.shape[0], atol=atol)


- def _test_distrib_input_float(device, atol=1e-8):
+ def test_distrib_input_float(distributed):
+     device = idist.device()
+
    def get_test_cases():

        y_pred = torch.rand(n_iters * batch_size, 2, 2, device=device)
@@ -143,12 +130,14 @@ def get_test_cases():
        # check multiple random inputs as random exact occurencies are rare
        torch.manual_seed(42 + rank + i)
        y_pred, y = get_test_cases()
-         _test(y_pred, y, 1, "cpu", n_iters, batch_size, atol=atol)
+         _test(y_pred, y, 1, "cpu", n_iters, batch_size, atol=1e-8)
        if device.type != "xla":
-             _test(y_pred, y, 1, idist.device(), n_iters, batch_size, atol=atol)
+             _test(y_pred, y, 1, idist.device(), n_iters, batch_size, atol=1e-8)
+

+ def test_distrib_multilabel_input_YCbCr(distributed):
+     device = idist.device()

- def _test_distrib_multilabel_input_YCbCr(device, atol=1e-8):
    def get_test_cases():

        y_pred = torch.randint(16, 236, (n_iters * batch_size, 1, 12, 12), dtype=torch.uint8, device=device)
@@ -171,13 +160,15 @@ def out_fn(x):
        # check multiple random inputs as random exact occurencies are rare
        torch.manual_seed(42 + rank + i)
        y_pred, y = get_test_cases()
-         _test(y_pred, y, 220, "cpu", n_iters, batch_size, atol, output_transform=out_fn, compute_y_channel=True)
+         _test(y_pred, y, 220, "cpu", n_iters, batch_size, atol=1e-8, output_transform=out_fn, compute_y_channel=True)
        if device.type != "xla":
            dev = idist.device()
-             _test(y_pred, y, 220, dev, n_iters, batch_size, atol, output_transform=out_fn, compute_y_channel=True)
+             _test(y_pred, y, 220, dev, n_iters, batch_size, atol=1e-8, output_transform=out_fn, compute_y_channel=True)


- def _test_distrib_multilabel_input_uint8(device, atol=1e-8):
+ def test_distrib_multilabel_input_uint8(distributed):
+     device = idist.device()
+
    def get_test_cases():

        y_pred = torch.randint(0, 256, (n_iters * batch_size, 3, 16, 16), device=device, dtype=torch.uint8)
@@ -193,12 +184,14 @@ def get_test_cases():
        # check multiple random inputs as random exact occurencies are rare
        torch.manual_seed(42 + rank + i)
        y_pred, y = get_test_cases()
-         _test(y_pred, y, 100, "cpu", n_iters, batch_size, atol)
+         _test(y_pred, y, 100, "cpu", n_iters, batch_size, atol=1e-8)
        if device.type != "xla":
-             _test(y_pred, y, 100, idist.device(), n_iters, batch_size, atol)
+             _test(y_pred, y, 100, idist.device(), n_iters, batch_size, atol=1e-8)


- def _test_distrib_multilabel_input_NHW(device, atol=1e-8):
+ def test_distrib_multilabel_input_NHW(distributed):
+     device = idist.device()
+
    def get_test_cases():

        y_pred = torch.rand(n_iters * batch_size, 28, 28, device=device)
@@ -214,13 +207,13 @@ def get_test_cases():
        # check multiple random inputs as random exact occurencies are rare
        torch.manual_seed(42 + rank + i)
        y_pred, y = get_test_cases()
-         _test(y_pred, y, 10, "cpu", n_iters, batch_size, atol)
+         _test(y_pred, y, 10, "cpu", n_iters, batch_size, atol=1e-8)
        if device.type != "xla":
-             _test(y_pred, y, 10, idist.device(), n_iters, batch_size, atol)
+             _test(y_pred, y, 10, idist.device(), n_iters, batch_size, atol=1e-8)


- def _test_distrib_accumulator_device(device):
-
+ def test_distrib_accumulator_device(distributed):
+     device = idist.device()
    metric_devices = [torch.device("cpu")]
    if torch.device(device).type != "xla":
        metric_devices.append(idist.device())
@@ -235,99 +228,3 @@ def _test_distrib_accumulator_device(device):
        psnr.update((y_pred, y))
        dev = psnr._sum_of_batchwise_psnr.device
        assert dev == metric_device, f"{dev} vs {metric_device}"
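The launcher-style tests removed below are superseded by the `distributed` fixture argument that the renamed tests now take. That fixture is likewise not part of this diff; it is assumed to bring up and tear down the appropriate process group around each test. A rough sketch of the idea, under that assumption and simplified to a single gloo backend:

```python
# Hypothetical sketch of a `distributed` fixture; the fixture actually used by
# the suite (covering gloo/nccl/xla/horovod) is defined outside this diff.
import pytest
import ignite.distributed as idist


@pytest.fixture()
def distributed():
    # Assumes the usual torch.distributed environment variables are set by the launcher.
    # Bring up a process group for the duration of one test, then shut it down.
    idist.initialize("gloo")
    yield
    idist.finalize()
```

With such a fixture in place, each `test_distrib_*` function can simply call `idist.device()` and run under whatever backend the fixture configured, which is why the per-backend wrapper tests below are deleted.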
-
-
- @pytest.mark.distributed
- @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
- def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
-
-     device = idist.device()
-     _test_distrib_input_float(device)
-     _test_distrib_multilabel_input_YCbCr(device)
-     _test_distrib_multilabel_input_uint8(device)
-     _test_distrib_multilabel_input_NHW(device)
-     _test_distrib_accumulator_device(device)
-
-
- @pytest.mark.distributed
- @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
- @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
- def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
-
-     device = idist.device()
-     _test_distrib_input_float(device)
-     _test_distrib_multilabel_input_YCbCr(device)
-     _test_distrib_multilabel_input_uint8(device)
-     _test_distrib_multilabel_input_NHW(device)
-     _test_distrib_accumulator_device(device)
-
-
- @pytest.mark.multinode_distributed
- @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
- @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
- def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
-
-     device = idist.device()
-     _test_distrib_input_float(device)
-     _test_distrib_multilabel_input_YCbCr(device)
-     _test_distrib_multilabel_input_uint8(device)
-     _test_distrib_multilabel_input_NHW(device)
-     _test_distrib_accumulator_device(device)
-
-
- @pytest.mark.multinode_distributed
- @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
- @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
- def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
-
-     device = idist.device()
-     _test_distrib_input_float(device)
-     _test_distrib_multilabel_input_YCbCr(device)
-     _test_distrib_multilabel_input_uint8(device)
-     _test_distrib_multilabel_input_NHW(device)
-     _test_distrib_accumulator_device(device)
-
-
- @pytest.mark.tpu
- @pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
- @pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
- def test_distrib_single_device_xla():
-
-     device = idist.device()
-     _test_distrib_input_float(device)
-     _test_distrib_multilabel_input_YCbCr(device)
-     _test_distrib_multilabel_input_uint8(device)
-     _test_distrib_multilabel_input_NHW(device)
-     _test_distrib_accumulator_device(device)
-
-
- def _test_distrib_xla_nprocs(index):
-     device = idist.device()
-     _test_distrib_input_float(device)
-     _test_distrib_multilabel_input_YCbCr(device)
-     _test_distrib_multilabel_input_uint8(device)
-     _test_distrib_multilabel_input_NHW(device)
-     _test_distrib_accumulator_device(device)
-
-
- @pytest.mark.tpu
- @pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
- @pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
- def test_distrib_xla_nprocs(xmp_executor):
-     n = int(os.environ["NUM_TPU_WORKERS"])
-     xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
-
-
- @pytest.mark.distributed
- @pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
- @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
- def test_distrib_hvd(gloo_hvd_executor):
-
-     device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
-     nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
-
-     gloo_hvd_executor(_test_distrib_input_float, (device,), np=nproc, do_init=True)
-     gloo_hvd_executor(_test_distrib_multilabel_input_YCbCr, (device,), np=nproc, do_init=True)
-     gloo_hvd_executor(_test_distrib_multilabel_input_uint8, (device,), np=nproc, do_init=True)
-     gloo_hvd_executor(_test_distrib_multilabel_input_NHW, (device,), np=nproc, do_init=True)
-     gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True)