@@ -62,57 +62,50 @@ def test_binary_wrong_inputs():
62
62
acc .update ((torch .randint (0 , 2 , size = (10 ,)).long (), torch .randint (0 , 2 , size = (10 , 5 , 6 )).long ()))
63
63
64
64
65
- def test_binary_input ():
66
-
65
+ @pytest .fixture (params = [item for item in range (11 )])
66
+ def test_data_binary (request ):
67
+ return [
68
+ # Binary accuracy on input of shape (N, 1) or (N, )
69
+ (torch .randint (0 , 2 , size = (10 ,)).long (), torch .randint (0 , 2 , size = (10 ,)).long (), 1 ),
70
+ (torch .randint (0 , 2 , size = (10 , 1 )).long (), torch .randint (0 , 2 , size = (10 , 1 )).long (), 1 ),
71
+ # updated batches
72
+ (torch .randint (0 , 2 , size = (50 ,)).long (), torch .randint (0 , 2 , size = (50 ,)).long (), 16 ),
73
+ (torch .randint (0 , 2 , size = (50 , 1 )).long (), torch .randint (0 , 2 , size = (50 , 1 )).long (), 16 ),
74
+ # Binary accuracy on input of shape (N, L)
75
+ (torch .randint (0 , 2 , size = (10 , 5 )).long (), torch .randint (0 , 2 , size = (10 , 5 )).long (), 1 ),
76
+ (torch .randint (0 , 2 , size = (10 , 8 )).long (), torch .randint (0 , 2 , size = (10 , 8 )).long (), 1 ),
77
+ # updated batches
78
+ (torch .randint (0 , 2 , size = (50 , 5 )).long (), torch .randint (0 , 2 , size = (50 , 5 )).long (), 16 ),
79
+ (torch .randint (0 , 2 , size = (50 , 8 )).long (), torch .randint (0 , 2 , size = (50 , 8 )).long (), 16 ),
80
+ # Binary accuracy on input of shape (N, H, W, ...)
81
+ (torch .randint (0 , 2 , size = (4 , 1 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (4 , 1 , 12 , 10 )).long (), 1 ),
82
+ (torch .randint (0 , 2 , size = (15 , 1 , 20 , 10 )).long (), torch .randint (0 , 2 , size = (15 , 1 , 20 , 10 )).long (), 1 ),
83
+ # updated batches
84
+ (torch .randint (0 , 2 , size = (50 , 1 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 1 , 12 , 10 )).long (), 16 ),
85
+ (torch .randint (0 , 2 , size = (50 , 1 , 20 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 1 , 20 , 10 )).long (), 16 ),
86
+ ][request .param ]
87
+
88
+
89
+ @pytest .mark .parametrize ("n_times" , range (5 ))
90
+ def test_binary_input (n_times , test_data_binary ):
67
91
acc = Accuracy ()
68
92
69
- def _test (y_pred , y , batch_size ):
70
- acc .reset ()
71
- if batch_size > 1 :
72
- n_iters = y .shape [0 ] // batch_size + 1
73
- for i in range (n_iters ):
74
- idx = i * batch_size
75
- acc .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
76
- else :
77
- acc .update ((y_pred , y ))
78
-
79
- np_y = y .numpy ().ravel ()
80
- np_y_pred = y_pred .numpy ().ravel ()
81
-
82
- assert acc ._type == "binary"
83
- assert isinstance (acc .compute (), float )
84
- assert accuracy_score (np_y , np_y_pred ) == pytest .approx (acc .compute ())
85
-
86
- def get_test_cases ():
87
-
88
- test_cases = [
89
- # Binary accuracy on input of shape (N, 1) or (N, )
90
- (torch .randint (0 , 2 , size = (10 ,)).long (), torch .randint (0 , 2 , size = (10 ,)).long (), 1 ),
91
- (torch .randint (0 , 2 , size = (10 , 1 )).long (), torch .randint (0 , 2 , size = (10 , 1 )).long (), 1 ),
92
- # updated batches
93
- (torch .randint (0 , 2 , size = (50 ,)).long (), torch .randint (0 , 2 , size = (50 ,)).long (), 16 ),
94
- (torch .randint (0 , 2 , size = (50 , 1 )).long (), torch .randint (0 , 2 , size = (50 , 1 )).long (), 16 ),
95
- # Binary accuracy on input of shape (N, L)
96
- (torch .randint (0 , 2 , size = (10 , 5 )).long (), torch .randint (0 , 2 , size = (10 , 5 )).long (), 1 ),
97
- (torch .randint (0 , 2 , size = (10 , 8 )).long (), torch .randint (0 , 2 , size = (10 , 8 )).long (), 1 ),
98
- # updated batches
99
- (torch .randint (0 , 2 , size = (50 , 5 )).long (), torch .randint (0 , 2 , size = (50 , 5 )).long (), 16 ),
100
- (torch .randint (0 , 2 , size = (50 , 8 )).long (), torch .randint (0 , 2 , size = (50 , 8 )).long (), 16 ),
101
- # Binary accuracy on input of shape (N, H, W, ...)
102
- (torch .randint (0 , 2 , size = (4 , 1 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (4 , 1 , 12 , 10 )).long (), 1 ),
103
- (torch .randint (0 , 2 , size = (15 , 1 , 20 , 10 )).long (), torch .randint (0 , 2 , size = (15 , 1 , 20 , 10 )).long (), 1 ),
104
- # updated batches
105
- (torch .randint (0 , 2 , size = (50 , 1 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 1 , 12 , 10 )).long (), 16 ),
106
- (torch .randint (0 , 2 , size = (50 , 1 , 20 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 1 , 20 , 10 )).long (), 16 ),
107
- ]
93
+ y_pred , y , batch_size = test_data_binary
94
+ acc .reset ()
95
+ if batch_size > 1 :
96
+ n_iters = y .shape [0 ] // batch_size + 1
97
+ for i in range (n_iters ):
98
+ idx = i * batch_size
99
+ acc .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
100
+ else :
101
+ acc .update ((y_pred , y ))
108
102
109
- return test_cases
103
+ np_y = y .numpy ().ravel ()
104
+ np_y_pred = y_pred .numpy ().ravel ()
110
105
111
- for _ in range (5 ):
112
- # check multiple random inputs as random exact occurencies are rare
113
- test_cases = get_test_cases ()
114
- for y_pred , y , n_iters in test_cases :
115
- _test (y_pred , y , n_iters )
106
+ assert acc ._type == "binary"
107
+ assert isinstance (acc .compute (), float )
108
+ assert accuracy_score (np_y , np_y_pred ) == pytest .approx (acc .compute ())
116
109
117
110
118
111
def test_multiclass_wrong_inputs ():
@@ -208,55 +201,49 @@ def test_multilabel_wrong_inputs():
208
201
acc .update ((torch .randint (0 , 2 , size = (10 , 1 )), torch .randint (0 , 2 , size = (10 , 1 )).long ()))
209
202
210
203
211
- def test_multilabel_input ():
204
+ @pytest .fixture (params = [item for item in range (12 )])
205
+ def test_data (request ):
206
+ return [
207
+ # Multilabel input data of shape (N, C) and (N, C)
208
+ (torch .randint (0 , 2 , size = (10 , 4 )).long (), torch .randint (0 , 2 , size = (10 , 4 )).long (), 1 ),
209
+ (torch .randint (0 , 2 , size = (10 , 7 )).long (), torch .randint (0 , 2 , size = (10 , 7 )).long (), 1 ),
210
+ # updated batches
211
+ (torch .randint (0 , 2 , size = (50 , 4 )).long (), torch .randint (0 , 2 , size = (50 , 4 )).long (), 16 ),
212
+ (torch .randint (0 , 2 , size = (50 , 7 )).long (), torch .randint (0 , 2 , size = (50 , 7 )).long (), 16 ),
213
+ # Multilabel input data of shape (N, H, W)
214
+ (torch .randint (0 , 2 , size = (10 , 5 , 10 )).long (), torch .randint (0 , 2 , size = (10 , 5 , 10 )).long (), 1 ),
215
+ (torch .randint (0 , 2 , size = (10 , 4 , 10 )).long (), torch .randint (0 , 2 , size = (10 , 4 , 10 )).long (), 1 ),
216
+ # updated batches
217
+ (torch .randint (0 , 2 , size = (50 , 5 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 5 , 10 )).long (), 16 ),
218
+ (torch .randint (0 , 2 , size = (50 , 4 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 4 , 10 )).long (), 16 ),
219
+ # Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
220
+ (torch .randint (0 , 2 , size = (4 , 5 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (4 , 5 , 12 , 10 )).long (), 1 ),
221
+ (torch .randint (0 , 2 , size = (4 , 10 , 12 , 8 )).long (), torch .randint (0 , 2 , size = (4 , 10 , 12 , 8 )).long (), 1 ),
222
+ # updated batches
223
+ (torch .randint (0 , 2 , size = (50 , 5 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 5 , 12 , 10 )).long (), 16 ),
224
+ (torch .randint (0 , 2 , size = (50 , 10 , 12 , 8 )).long (), torch .randint (0 , 2 , size = (50 , 10 , 12 , 8 )).long (), 16 ),
225
+ ][request .param ]
226
+
227
+
228
+ @pytest .mark .parametrize ("n_times" , range (5 ))
229
+ def test_multilabel_input (n_times , test_data ):
212
230
acc = Accuracy (is_multilabel = True )
213
231
214
- def _test (y_pred , y , batch_size ):
215
- acc .reset ()
216
- if batch_size > 1 :
217
- n_iters = y .shape [0 ] // batch_size + 1
218
- for i in range (n_iters ):
219
- idx = i * batch_size
220
- acc .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
221
- else :
222
- acc .update ((y_pred , y ))
223
-
224
- np_y_pred = to_numpy_multilabel (y_pred )
225
- np_y = to_numpy_multilabel (y )
226
-
227
- assert acc ._type == "multilabel"
228
- assert isinstance (acc .compute (), float )
229
- assert accuracy_score (np_y , np_y_pred ) == pytest .approx (acc .compute ())
230
-
231
- def get_test_cases ():
232
+ y_pred , y , batch_size = test_data
233
+ if batch_size > 1 :
234
+ n_iters = y .shape [0 ] // batch_size + 1
235
+ for i in range (n_iters ):
236
+ idx = i * batch_size
237
+ acc .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
238
+ else :
239
+ acc .update ((y_pred , y ))
232
240
233
- test_cases = [
234
- # Multilabel input data of shape (N, C) and (N, C)
235
- (torch .randint (0 , 2 , size = (10 , 4 )).long (), torch .randint (0 , 2 , size = (10 , 4 )).long (), 1 ),
236
- (torch .randint (0 , 2 , size = (10 , 7 )).long (), torch .randint (0 , 2 , size = (10 , 7 )).long (), 1 ),
237
- # updated batches
238
- (torch .randint (0 , 2 , size = (50 , 4 )).long (), torch .randint (0 , 2 , size = (50 , 4 )).long (), 16 ),
239
- (torch .randint (0 , 2 , size = (50 , 7 )).long (), torch .randint (0 , 2 , size = (50 , 7 )).long (), 16 ),
240
- # Multilabel input data of shape (N, H, W)
241
- (torch .randint (0 , 2 , size = (10 , 5 , 10 )).long (), torch .randint (0 , 2 , size = (10 , 5 , 10 )).long (), 1 ),
242
- (torch .randint (0 , 2 , size = (10 , 4 , 10 )).long (), torch .randint (0 , 2 , size = (10 , 4 , 10 )).long (), 1 ),
243
- # updated batches
244
- (torch .randint (0 , 2 , size = (50 , 5 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 5 , 10 )).long (), 16 ),
245
- (torch .randint (0 , 2 , size = (50 , 4 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 4 , 10 )).long (), 16 ),
246
- # Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
247
- (torch .randint (0 , 2 , size = (4 , 5 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (4 , 5 , 12 , 10 )).long (), 1 ),
248
- (torch .randint (0 , 2 , size = (4 , 10 , 12 , 8 )).long (), torch .randint (0 , 2 , size = (4 , 10 , 12 , 8 )).long (), 1 ),
249
- # updated batches
250
- (torch .randint (0 , 2 , size = (50 , 5 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 5 , 12 , 10 )).long (), 16 ),
251
- (torch .randint (0 , 2 , size = (50 , 10 , 12 , 8 )).long (), torch .randint (0 , 2 , size = (50 , 10 , 12 , 8 )).long (), 16 ),
252
- ]
253
- return test_cases
241
+ np_y_pred = to_numpy_multilabel (y_pred )
242
+ np_y = to_numpy_multilabel (y )
254
243
255
- for _ in range (5 ):
256
- # check multiple random inputs as random exact occurencies are rare
257
- test_cases = get_test_cases ()
258
- for y_pred , y , batch_size in test_cases :
259
- _test (y_pred , y , batch_size )
244
+ assert acc ._type == "multilabel"
245
+ assert isinstance (acc .compute (), float )
246
+ assert accuracy_score (np_y , np_y_pred ) == pytest .approx (acc .compute ())
260
247
261
248
262
249
def test_incorrect_type ():
0 commit comments