9
9
from ignite .metrics import PSNR
10
10
from ignite .utils import manual_seed
11
11
12
- from tests .ignite import cpu_and_maybe_cuda
13
-
14
12
15
13
def test_zero_div ():
16
14
psnr = PSNR (1.0 )
@@ -30,35 +28,31 @@ def test_invalid_psnr():
30
28
psnr .update ((y_pred , y .squeeze (dim = 0 )))
31
29
32
30
33
- @pytest .fixture
34
- def test_data (request ):
31
+ @pytest .fixture ( params = [ "float" , "YCbCr" , "uint8" , "NHW shape" ])
32
+ def test_data (request , available_device ):
35
33
manual_seed (42 )
36
- device = request .param .device
37
- if request .param .sample_type == "float" :
38
- y_pred = torch .rand (8 , 3 , 28 , 28 , device = device )
34
+ if request .param == "float" :
35
+ y_pred = torch .rand (8 , 3 , 28 , 28 , device = available_device )
39
36
y = y_pred * 0.8
40
- elif request .param . sample_type == "YCbCr" :
41
- y_pred = torch .randint (16 , 236 , (4 , 1 , 12 , 12 ), dtype = torch .uint8 , device = device )
42
- y = torch .randint (16 , 236 , (4 , 1 , 12 , 12 ), dtype = torch .uint8 , device = device )
43
- elif request .param . sample_type == "uint8" :
44
- y_pred = torch .randint (0 , 256 , (4 , 3 , 16 , 16 ), dtype = torch .uint8 , device = device )
37
+ elif request .param == "YCbCr" :
38
+ y_pred = torch .randint (16 , 236 , (4 , 1 , 12 , 12 ), dtype = torch .uint8 , device = available_device )
39
+ y = torch .randint (16 , 236 , (4 , 1 , 12 , 12 ), dtype = torch .uint8 , device = available_device )
40
+ elif request .param == "uint8" :
41
+ y_pred = torch .randint (0 , 256 , (4 , 3 , 16 , 16 ), dtype = torch .uint8 , device = available_device )
45
42
y = (y_pred * 0.8 ).to (torch .uint8 )
46
- elif request .param . sample_type == "NHW shape" :
47
- y_pred = torch .rand (8 , 28 , 28 , device = device )
43
+ elif request .param == "NHW shape" :
44
+ y_pred = torch .rand (8 , 28 , 28 , device = available_device )
48
45
y = y_pred * 0.8
49
46
else :
50
47
raise ValueError (f"Wrong fixture parameter, given { request .param } " )
51
48
return (y_pred , y )
52
49
53
50
54
- @pytest .mark .parametrize ("device" , cpu_and_maybe_cuda (), indirect = True )
55
- @pytest .mark .parametrize ("sample_type" , ["float" , "YCbCr" , "uint8" , "NHW shape" ], indirect = True )
56
- def test_psnr (test_data ):
51
+ def test_psnr (test_data , available_device ):
57
52
y_pred , y = test_data
58
- device = idist .device ()
59
53
data_range = (y .max () - y .min ()).cpu ().item ()
60
54
61
- psnr = PSNR (data_range = data_range , device = device )
55
+ psnr = PSNR (data_range = data_range , device = available_device )
62
56
psnr .update (test_data )
63
57
psnr_compute = psnr .compute ()
64
58
0 commit comments