@@ -62,57 +62,50 @@ def test_binary_wrong_inputs():
62
62
acc .update ((torch .randint (0 , 2 , size = (10 ,)).long (), torch .randint (0 , 2 , size = (10 , 5 , 6 )).long ()))
63
63
64
64
65
- def test_binary_input ():
66
-
65
+ @pytest .fixture (params = range (12 ))
66
+ def test_data_binary (request ):
67
+ return [
68
+ # Binary accuracy on input of shape (N, 1) or (N, )
69
+ (torch .randint (0 , 2 , size = (10 ,)).long (), torch .randint (0 , 2 , size = (10 ,)).long (), 1 ),
70
+ (torch .randint (0 , 2 , size = (10 , 1 )).long (), torch .randint (0 , 2 , size = (10 , 1 )).long (), 1 ),
71
+ # updated batches
72
+ (torch .randint (0 , 2 , size = (50 ,)).long (), torch .randint (0 , 2 , size = (50 ,)).long (), 16 ),
73
+ (torch .randint (0 , 2 , size = (50 , 1 )).long (), torch .randint (0 , 2 , size = (50 , 1 )).long (), 16 ),
74
+ # Binary accuracy on input of shape (N, L)
75
+ (torch .randint (0 , 2 , size = (10 , 5 )).long (), torch .randint (0 , 2 , size = (10 , 5 )).long (), 1 ),
76
+ (torch .randint (0 , 2 , size = (10 , 8 )).long (), torch .randint (0 , 2 , size = (10 , 8 )).long (), 1 ),
77
+ # updated batches
78
+ (torch .randint (0 , 2 , size = (50 , 5 )).long (), torch .randint (0 , 2 , size = (50 , 5 )).long (), 16 ),
79
+ (torch .randint (0 , 2 , size = (50 , 8 )).long (), torch .randint (0 , 2 , size = (50 , 8 )).long (), 16 ),
80
+ # Binary accuracy on input of shape (N, H, W, ...)
81
+ (torch .randint (0 , 2 , size = (4 , 1 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (4 , 1 , 12 , 10 )).long (), 1 ),
82
+ (torch .randint (0 , 2 , size = (15 , 1 , 20 , 10 )).long (), torch .randint (0 , 2 , size = (15 , 1 , 20 , 10 )).long (), 1 ),
83
+ # updated batches
84
+ (torch .randint (0 , 2 , size = (50 , 1 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 1 , 12 , 10 )).long (), 16 ),
85
+ (torch .randint (0 , 2 , size = (50 , 1 , 20 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 1 , 20 , 10 )).long (), 16 ),
86
+ ][request .param ]
87
+
88
+
89
+ @pytest .mark .parametrize ("n_times" , range (5 ))
90
+ def test_binary_input (n_times , test_data_binary ):
67
91
acc = Accuracy ()
68
92
69
- def _test (y_pred , y , batch_size ):
70
- acc .reset ()
71
- if batch_size > 1 :
72
- n_iters = y .shape [0 ] // batch_size + 1
73
- for i in range (n_iters ):
74
- idx = i * batch_size
75
- acc .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
76
- else :
77
- acc .update ((y_pred , y ))
78
-
79
- np_y = y .numpy ().ravel ()
80
- np_y_pred = y_pred .numpy ().ravel ()
81
-
82
- assert acc ._type == "binary"
83
- assert isinstance (acc .compute (), float )
84
- assert accuracy_score (np_y , np_y_pred ) == pytest .approx (acc .compute ())
85
-
86
- def get_test_cases ():
87
-
88
- test_cases = [
89
- # Binary accuracy on input of shape (N, 1) or (N, )
90
- (torch .randint (0 , 2 , size = (10 ,)).long (), torch .randint (0 , 2 , size = (10 ,)).long (), 1 ),
91
- (torch .randint (0 , 2 , size = (10 , 1 )).long (), torch .randint (0 , 2 , size = (10 , 1 )).long (), 1 ),
92
- # updated batches
93
- (torch .randint (0 , 2 , size = (50 ,)).long (), torch .randint (0 , 2 , size = (50 ,)).long (), 16 ),
94
- (torch .randint (0 , 2 , size = (50 , 1 )).long (), torch .randint (0 , 2 , size = (50 , 1 )).long (), 16 ),
95
- # Binary accuracy on input of shape (N, L)
96
- (torch .randint (0 , 2 , size = (10 , 5 )).long (), torch .randint (0 , 2 , size = (10 , 5 )).long (), 1 ),
97
- (torch .randint (0 , 2 , size = (10 , 8 )).long (), torch .randint (0 , 2 , size = (10 , 8 )).long (), 1 ),
98
- # updated batches
99
- (torch .randint (0 , 2 , size = (50 , 5 )).long (), torch .randint (0 , 2 , size = (50 , 5 )).long (), 16 ),
100
- (torch .randint (0 , 2 , size = (50 , 8 )).long (), torch .randint (0 , 2 , size = (50 , 8 )).long (), 16 ),
101
- # Binary accuracy on input of shape (N, H, W, ...)
102
- (torch .randint (0 , 2 , size = (4 , 1 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (4 , 1 , 12 , 10 )).long (), 1 ),
103
- (torch .randint (0 , 2 , size = (15 , 1 , 20 , 10 )).long (), torch .randint (0 , 2 , size = (15 , 1 , 20 , 10 )).long (), 1 ),
104
- # updated batches
105
- (torch .randint (0 , 2 , size = (50 , 1 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 1 , 12 , 10 )).long (), 16 ),
106
- (torch .randint (0 , 2 , size = (50 , 1 , 20 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 1 , 20 , 10 )).long (), 16 ),
107
- ]
108
-
109
- return test_cases
110
-
111
- for _ in range (5 ):
112
- # check multiple random inputs as random exact occurencies are rare
113
- test_cases = get_test_cases ()
114
- for y_pred , y , n_iters in test_cases :
115
- _test (y_pred , y , n_iters )
93
+ y_pred , y , batch_size = test_data_binary
94
+ acc .reset ()
95
+ if batch_size > 1 :
96
+ n_iters = y .shape [0 ] // batch_size + 1
97
+ for i in range (n_iters ):
98
+ idx = i * batch_size
99
+ acc .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
100
+ else :
101
+ acc .update ((y_pred , y ))
102
+
103
+ np_y = y .numpy ().ravel ()
104
+ np_y_pred = y_pred .numpy ().ravel ()
105
+
106
+ assert acc ._type == "binary"
107
+ assert isinstance (acc .compute (), float )
108
+ assert accuracy_score (np_y , np_y_pred ) == pytest .approx (acc .compute ())
116
109
117
110
118
111
def test_multiclass_wrong_inputs ():
@@ -131,53 +124,48 @@ def test_multiclass_wrong_inputs():
131
124
acc .update ((torch .rand (10 ), torch .randint (0 , 5 , size = (10 , 5 , 6 )).long ()))
132
125
133
126
134
- def test_multiclass_input ():
127
+ @pytest .fixture (params = range (11 ))
128
+ def test_data_multiclass (request ):
129
+ return [
130
+ # Multiclass input data of shape (N, ) and (N, C)
131
+ (torch .rand (10 , 4 ), torch .randint (0 , 4 , size = (10 ,)).long (), 1 ),
132
+ (torch .rand (10 , 10 , 1 ), torch .randint (0 , 18 , size = (10 , 1 )).long (), 1 ),
133
+ (torch .rand (10 , 18 ), torch .randint (0 , 18 , size = (10 ,)).long (), 1 ),
134
+ (torch .rand (4 , 10 ), torch .randint (0 , 10 , size = (4 ,)).long (), 1 ),
135
+ # 2-classes
136
+ (torch .rand (4 , 2 ), torch .randint (0 , 2 , size = (4 ,)).long (), 1 ),
137
+ (torch .rand (100 , 5 ), torch .randint (0 , 5 , size = (100 ,)).long (), 16 ),
138
+ # Multiclass input data of shape (N, L) and (N, C, L)
139
+ (torch .rand (10 , 4 , 5 ), torch .randint (0 , 4 , size = (10 , 5 )).long (), 1 ),
140
+ (torch .rand (4 , 10 , 5 ), torch .randint (0 , 10 , size = (4 , 5 )).long (), 1 ),
141
+ (torch .rand (100 , 9 , 7 ), torch .randint (0 , 9 , size = (100 , 7 )).long (), 16 ),
142
+ # Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
143
+ (torch .rand (4 , 5 , 12 , 10 ), torch .randint (0 , 5 , size = (4 , 12 , 10 )).long (), 1 ),
144
+ (torch .rand (100 , 3 , 8 , 8 ), torch .randint (0 , 3 , size = (100 , 8 , 8 )).long (), 16 ),
145
+ ][request .param ]
146
+
147
+
148
+ @pytest .mark .parametrize ("n_times" , range (5 ))
149
+ def test_multiclass_input (n_times , test_data_multiclass ):
135
150
acc = Accuracy ()
136
151
137
- def _test (y_pred , y , batch_size ):
138
- acc .reset ()
139
- if batch_size > 1 :
140
- # Batched Updates
141
- n_iters = y .shape [0 ] // batch_size + 1
142
- for i in range (n_iters ):
143
- idx = i * batch_size
144
- acc .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
145
- else :
146
- acc .update ((y_pred , y ))
147
-
148
- np_y_pred = y_pred .numpy ().argmax (axis = 1 ).ravel ()
149
- np_y = y .numpy ().ravel ()
150
-
151
- assert acc ._type == "multiclass"
152
- assert isinstance (acc .compute (), float )
153
- assert accuracy_score (np_y , np_y_pred ) == pytest .approx (acc .compute ())
154
-
155
- def get_test_cases ():
156
-
157
- test_cases = [
158
- # Multiclass input data of shape (N, ) and (N, C)
159
- (torch .rand (10 , 4 ), torch .randint (0 , 4 , size = (10 ,)).long (), 1 ),
160
- (torch .rand (10 , 10 , 1 ), torch .randint (0 , 18 , size = (10 , 1 )).long (), 1 ),
161
- (torch .rand (10 , 18 ), torch .randint (0 , 18 , size = (10 ,)).long (), 1 ),
162
- (torch .rand (4 , 10 ), torch .randint (0 , 10 , size = (4 ,)).long (), 1 ),
163
- # 2-classes
164
- (torch .rand (4 , 2 ), torch .randint (0 , 2 , size = (4 ,)).long (), 1 ),
165
- (torch .rand (100 , 5 ), torch .randint (0 , 5 , size = (100 ,)).long (), 16 ),
166
- # Multiclass input data of shape (N, L) and (N, C, L)
167
- (torch .rand (10 , 4 , 5 ), torch .randint (0 , 4 , size = (10 , 5 )).long (), 1 ),
168
- (torch .rand (4 , 10 , 5 ), torch .randint (0 , 10 , size = (4 , 5 )).long (), 1 ),
169
- (torch .rand (100 , 9 , 7 ), torch .randint (0 , 9 , size = (100 , 7 )).long (), 16 ),
170
- # Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
171
- (torch .rand (4 , 5 , 12 , 10 ), torch .randint (0 , 5 , size = (4 , 12 , 10 )).long (), 1 ),
172
- (torch .rand (100 , 3 , 8 , 8 ), torch .randint (0 , 3 , size = (100 , 8 , 8 )).long (), 16 ),
173
- ]
174
- return test_cases
175
-
176
- for _ in range (5 ):
177
- # check multiple random inputs as random exact occurencies are rare
178
- test_cases = get_test_cases ()
179
- for y_pred , y , batch_size in test_cases :
180
- _test (y_pred , y , batch_size )
152
+ y_pred , y , batch_size = test_data_multiclass
153
+ acc .reset ()
154
+ if batch_size > 1 :
155
+ # Batched Updates
156
+ n_iters = y .shape [0 ] // batch_size + 1
157
+ for i in range (n_iters ):
158
+ idx = i * batch_size
159
+ acc .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
160
+ else :
161
+ acc .update ((y_pred , y ))
162
+
163
+ np_y_pred = y_pred .numpy ().argmax (axis = 1 ).ravel ()
164
+ np_y = y .numpy ().ravel ()
165
+
166
+ assert acc ._type == "multiclass"
167
+ assert isinstance (acc .compute (), float )
168
+ assert accuracy_score (np_y , np_y_pred ) == pytest .approx (acc .compute ())
181
169
182
170
183
171
def to_numpy_multilabel (y ):
@@ -208,55 +196,49 @@ def test_multilabel_wrong_inputs():
208
196
acc .update ((torch .randint (0 , 2 , size = (10 , 1 )), torch .randint (0 , 2 , size = (10 , 1 )).long ()))
209
197
210
198
211
- def test_multilabel_input ():
199
+ @pytest .fixture (params = range (12 ))
200
+ def test_data_multilabel (request ):
201
+ return [
202
+ # Multilabel input data of shape (N, C) and (N, C)
203
+ (torch .randint (0 , 2 , size = (10 , 4 )).long (), torch .randint (0 , 2 , size = (10 , 4 )).long (), 1 ),
204
+ (torch .randint (0 , 2 , size = (10 , 7 )).long (), torch .randint (0 , 2 , size = (10 , 7 )).long (), 1 ),
205
+ # updated batches
206
+ (torch .randint (0 , 2 , size = (50 , 4 )).long (), torch .randint (0 , 2 , size = (50 , 4 )).long (), 16 ),
207
+ (torch .randint (0 , 2 , size = (50 , 7 )).long (), torch .randint (0 , 2 , size = (50 , 7 )).long (), 16 ),
208
+ # Multilabel input data of shape (N, H, W)
209
+ (torch .randint (0 , 2 , size = (10 , 5 , 10 )).long (), torch .randint (0 , 2 , size = (10 , 5 , 10 )).long (), 1 ),
210
+ (torch .randint (0 , 2 , size = (10 , 4 , 10 )).long (), torch .randint (0 , 2 , size = (10 , 4 , 10 )).long (), 1 ),
211
+ # updated batches
212
+ (torch .randint (0 , 2 , size = (50 , 5 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 5 , 10 )).long (), 16 ),
213
+ (torch .randint (0 , 2 , size = (50 , 4 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 4 , 10 )).long (), 16 ),
214
+ # Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
215
+ (torch .randint (0 , 2 , size = (4 , 5 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (4 , 5 , 12 , 10 )).long (), 1 ),
216
+ (torch .randint (0 , 2 , size = (4 , 10 , 12 , 8 )).long (), torch .randint (0 , 2 , size = (4 , 10 , 12 , 8 )).long (), 1 ),
217
+ # updated batches
218
+ (torch .randint (0 , 2 , size = (50 , 5 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 5 , 12 , 10 )).long (), 16 ),
219
+ (torch .randint (0 , 2 , size = (50 , 10 , 12 , 8 )).long (), torch .randint (0 , 2 , size = (50 , 10 , 12 , 8 )).long (), 16 ),
220
+ ][request .param ]
221
+
222
+
223
+ @pytest .mark .parametrize ("n_times" , range (5 ))
224
+ def test_multilabel_input (n_times , test_data_multilabel ):
212
225
acc = Accuracy (is_multilabel = True )
213
226
214
- def _test (y_pred , y , batch_size ):
215
- acc .reset ()
216
- if batch_size > 1 :
217
- n_iters = y .shape [0 ] // batch_size + 1
218
- for i in range (n_iters ):
219
- idx = i * batch_size
220
- acc .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
221
- else :
222
- acc .update ((y_pred , y ))
227
+ y_pred , y , batch_size = test_data_multilabel
228
+ if batch_size > 1 :
229
+ n_iters = y .shape [0 ] // batch_size + 1
230
+ for i in range (n_iters ):
231
+ idx = i * batch_size
232
+ acc .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
233
+ else :
234
+ acc .update ((y_pred , y ))
223
235
224
- np_y_pred = to_numpy_multilabel (y_pred )
225
- np_y = to_numpy_multilabel (y )
236
+ np_y_pred = to_numpy_multilabel (y_pred )
237
+ np_y = to_numpy_multilabel (y )
226
238
227
- assert acc ._type == "multilabel"
228
- assert isinstance (acc .compute (), float )
229
- assert accuracy_score (np_y , np_y_pred ) == pytest .approx (acc .compute ())
230
-
231
- def get_test_cases ():
232
-
233
- test_cases = [
234
- # Multilabel input data of shape (N, C) and (N, C)
235
- (torch .randint (0 , 2 , size = (10 , 4 )).long (), torch .randint (0 , 2 , size = (10 , 4 )).long (), 1 ),
236
- (torch .randint (0 , 2 , size = (10 , 7 )).long (), torch .randint (0 , 2 , size = (10 , 7 )).long (), 1 ),
237
- # updated batches
238
- (torch .randint (0 , 2 , size = (50 , 4 )).long (), torch .randint (0 , 2 , size = (50 , 4 )).long (), 16 ),
239
- (torch .randint (0 , 2 , size = (50 , 7 )).long (), torch .randint (0 , 2 , size = (50 , 7 )).long (), 16 ),
240
- # Multilabel input data of shape (N, H, W)
241
- (torch .randint (0 , 2 , size = (10 , 5 , 10 )).long (), torch .randint (0 , 2 , size = (10 , 5 , 10 )).long (), 1 ),
242
- (torch .randint (0 , 2 , size = (10 , 4 , 10 )).long (), torch .randint (0 , 2 , size = (10 , 4 , 10 )).long (), 1 ),
243
- # updated batches
244
- (torch .randint (0 , 2 , size = (50 , 5 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 5 , 10 )).long (), 16 ),
245
- (torch .randint (0 , 2 , size = (50 , 4 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 4 , 10 )).long (), 16 ),
246
- # Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
247
- (torch .randint (0 , 2 , size = (4 , 5 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (4 , 5 , 12 , 10 )).long (), 1 ),
248
- (torch .randint (0 , 2 , size = (4 , 10 , 12 , 8 )).long (), torch .randint (0 , 2 , size = (4 , 10 , 12 , 8 )).long (), 1 ),
249
- # updated batches
250
- (torch .randint (0 , 2 , size = (50 , 5 , 12 , 10 )).long (), torch .randint (0 , 2 , size = (50 , 5 , 12 , 10 )).long (), 16 ),
251
- (torch .randint (0 , 2 , size = (50 , 10 , 12 , 8 )).long (), torch .randint (0 , 2 , size = (50 , 10 , 12 , 8 )).long (), 16 ),
252
- ]
253
- return test_cases
254
-
255
- for _ in range (5 ):
256
- # check multiple random inputs as random exact occurencies are rare
257
- test_cases = get_test_cases ()
258
- for y_pred , y , batch_size in test_cases :
259
- _test (y_pred , y , batch_size )
239
+ assert acc ._type == "multilabel"
240
+ assert isinstance (acc .compute (), float )
241
+ assert accuracy_score (np_y , np_y_pred ) == pytest .approx (acc .compute ())
260
242
261
243
262
244
def test_incorrect_type ():
0 commit comments