@@ -17,44 +17,40 @@ def test_no_update():
17
17
mae .compute ()
18
18
19
19
20
- def test_compute ():
20
+ @pytest .fixture (params = [item for item in range (4 )])
21
+ def test_case (request ):
22
+
23
+ return [
24
+ (torch .randint (0 , 10 , size = (100 , 1 )), torch .randint (0 , 10 , size = (100 , 1 )), 1 ),
25
+ (torch .randint (- 10 , 10 , size = (100 , 5 )), torch .randint (- 10 , 10 , size = (100 , 5 )), 1 ),
26
+ # updated batches
27
+ (torch .randint (0 , 10 , size = (100 , 1 )), torch .randint (0 , 10 , size = (100 , 1 )), 16 ),
28
+ (torch .randint (- 20 , 20 , size = (100 , 5 )), torch .randint (- 20 , 20 , size = (100 , 5 )), 16 ),
29
+ ][request .param ]
30
+
31
+
32
+ @pytest .mark .parametrize ("n_times" , range (5 ))
33
+ def test_compute (n_times , test_case ):
21
34
22
35
mae = MeanAbsoluteError ()
23
36
24
- def _test (y_pred , y , batch_size ):
25
- mae .reset ()
26
- if batch_size > 1 :
27
- n_iters = y .shape [0 ] // batch_size + 1
28
- for i in range (n_iters ):
29
- idx = i * batch_size
30
- mae .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
31
- else :
32
- mae .update ((y_pred , y , batch_size ))
33
-
34
- np_y = y .numpy ()
35
- np_y_pred = y_pred .numpy ()
36
-
37
- np_res = (np .abs (np_y_pred - np_y )).sum () / np_y .shape [0 ]
38
- assert isinstance (mae .compute (), float )
39
- assert mae .compute () == np_res
40
-
41
- def get_test_cases ():
42
-
43
- test_cases = [
44
- (torch .randint (0 , 10 , size = (100 , 1 )), torch .randint (0 , 10 , size = (100 , 1 )), 1 ),
45
- (torch .randint (- 10 , 10 , size = (100 , 5 )), torch .randint (- 10 , 10 , size = (100 , 5 )), 1 ),
46
- # updated batches
47
- (torch .randint (0 , 10 , size = (100 , 1 )), torch .randint (0 , 10 , size = (100 , 1 )), 16 ),
48
- (torch .randint (- 20 , 20 , size = (100 , 5 )), torch .randint (- 20 , 20 , size = (100 , 5 )), 16 ),
49
- ]
50
-
51
- return test_cases
52
-
53
- for _ in range (5 ):
54
- # check multiple random inputs as random exact occurencies are rare
55
- test_cases = get_test_cases ()
56
- for y_pred , y , batch_size in test_cases :
57
- _test (y_pred , y , batch_size )
37
+ y_pred , y , batch_size = test_case
38
+
39
+ mae .reset ()
40
+ if batch_size > 1 :
41
+ n_iters = y .shape [0 ] // batch_size + 1
42
+ for i in range (n_iters ):
43
+ idx = i * batch_size
44
+ mae .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
45
+ else :
46
+ mae .update ((y_pred , y , batch_size ))
47
+
48
+ np_y = y .numpy ()
49
+ np_y_pred = y_pred .numpy ()
50
+
51
+ np_res = (np .abs (np_y_pred - np_y )).sum () / np_y .shape [0 ]
52
+ assert isinstance (mae .compute (), float )
53
+ assert mae .compute () == np_res
58
54
59
55
60
56
def _test_distrib_integration (device ):
0 commit comments