@@ -17,42 +17,38 @@ def test_zero_sample():
17
17
mpd .compute ()
18
18
19
19
20
- def test_compute ():
20
+ @pytest .fixture (params = [item for item in range (4 )])
21
+ def test_case (request ):
22
+
23
+ return [
24
+ (torch .randint (0 , 10 , size = (100 , 1 )), torch .randint (0 , 10 , size = (100 , 1 )), 1 ),
25
+ (torch .randint (- 20 , 20 , size = (100 , 5 )), torch .randint (- 20 , 20 , size = (100 , 5 )), 1 ),
26
+ # updated batches
27
+ (torch .randint (0 , 10 , size = (100 , 1 )), torch .randint (0 , 10 , size = (100 , 1 )), 16 ),
28
+ (torch .randint (- 20 , 20 , size = (100 , 5 )), torch .randint (- 20 , 20 , size = (100 , 5 )), 16 ),
29
+ ][request .param ]
30
+
31
+
32
+ @pytest .mark .parametrize ("n_times" , range (5 ))
33
+ def test_compute (n_times , test_case ):
21
34
22
35
mpd = MeanPairwiseDistance ()
23
36
24
- def _test (y_pred , y , batch_size ):
25
- mpd .reset ()
26
- if batch_size > 1 :
27
- n_iters = y .shape [0 ] // batch_size + 1
28
- for i in range (n_iters ):
29
- idx = i * batch_size
30
- mpd .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
31
- else :
32
- mpd .update ((y_pred , y ))
33
-
34
- np_res = np .mean (torch .pairwise_distance (y_pred , y , p = mpd ._p , eps = mpd ._eps ).numpy ())
35
-
36
- assert isinstance (mpd .compute (), float )
37
- assert pytest .approx (mpd .compute ()) == np_res
38
-
39
- def get_test_cases ():
40
-
41
- test_cases = [
42
- (torch .randint (0 , 10 , size = (100 , 1 )), torch .randint (0 , 10 , size = (100 , 1 )), 1 ),
43
- (torch .randint (- 20 , 20 , size = (100 , 5 )), torch .randint (- 20 , 20 , size = (100 , 5 )), 1 ),
44
- # updated batches
45
- (torch .randint (0 , 10 , size = (100 , 1 )), torch .randint (0 , 10 , size = (100 , 1 )), 16 ),
46
- (torch .randint (- 20 , 20 , size = (100 , 5 )), torch .randint (- 20 , 20 , size = (100 , 5 )), 16 ),
47
- ]
48
-
49
- return test_cases
50
-
51
- for _ in range (5 ):
52
- # check multiple random inputs as random exact occurencies are rare
53
- test_cases = get_test_cases ()
54
- for y_pred , y , batch_size in test_cases :
55
- _test (y_pred , y , batch_size )
37
+ y_pred , y , batch_size = test_case
38
+
39
+ mpd .reset ()
40
+ if batch_size > 1 :
41
+ n_iters = y .shape [0 ] // batch_size + 1
42
+ for i in range (n_iters ):
43
+ idx = i * batch_size
44
+ mpd .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
45
+ else :
46
+ mpd .update ((y_pred , y ))
47
+
48
+ np_res = np .mean (torch .pairwise_distance (y_pred , y , p = mpd ._p , eps = mpd ._eps ).numpy ())
49
+
50
+ assert isinstance (mpd .compute (), float )
51
+ assert pytest .approx (mpd .compute ()) == np_res
56
52
57
53
58
54
def _test_distrib_integration (device ):
0 commit comments