@@ -62,7 +62,7 @@ def test_binary_wrong_inputs():
62
62
acc .update ((torch .randint (0 , 2 , size = (10 ,)).long (), torch .randint (0 , 2 , size = (10 , 5 , 6 )).long ()))
63
63
64
64
65
- @pytest .fixture (params = [item for item in range (11 )])
65
+ @pytest .fixture (params = [item for item in range (12 )])
66
66
def test_data_binary (request ):
67
67
return [
68
68
# Binary accuracy on input of shape (N, 1) or (N, )
@@ -124,53 +124,48 @@ def test_multiclass_wrong_inputs():
124
124
acc .update ((torch .rand (10 ), torch .randint (0 , 5 , size = (10 , 5 , 6 )).long ()))
125
125
126
126
127
- def test_multiclass_input ():
127
+ @pytest .fixture (params = [item for item in range (11 )])
128
+ def test_data_multiclass (request ):
129
+ return [
130
+ # Multiclass input data of shape (N, ) and (N, C)
131
+ (torch .rand (10 , 4 ), torch .randint (0 , 4 , size = (10 ,)).long (), 1 ),
132
+ (torch .rand (10 , 10 , 1 ), torch .randint (0 , 18 , size = (10 , 1 )).long (), 1 ),
133
+ (torch .rand (10 , 18 ), torch .randint (0 , 18 , size = (10 ,)).long (), 1 ),
134
+ (torch .rand (4 , 10 ), torch .randint (0 , 10 , size = (4 ,)).long (), 1 ),
135
+ # 2-classes
136
+ (torch .rand (4 , 2 ), torch .randint (0 , 2 , size = (4 ,)).long (), 1 ),
137
+ (torch .rand (100 , 5 ), torch .randint (0 , 5 , size = (100 ,)).long (), 16 ),
138
+ # Multiclass input data of shape (N, L) and (N, C, L)
139
+ (torch .rand (10 , 4 , 5 ), torch .randint (0 , 4 , size = (10 , 5 )).long (), 1 ),
140
+ (torch .rand (4 , 10 , 5 ), torch .randint (0 , 10 , size = (4 , 5 )).long (), 1 ),
141
+ (torch .rand (100 , 9 , 7 ), torch .randint (0 , 9 , size = (100 , 7 )).long (), 16 ),
142
+ # Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
143
+ (torch .rand (4 , 5 , 12 , 10 ), torch .randint (0 , 5 , size = (4 , 12 , 10 )).long (), 1 ),
144
+ (torch .rand (100 , 3 , 8 , 8 ), torch .randint (0 , 3 , size = (100 , 8 , 8 )).long (), 16 ),
145
+ ][request .param ]
146
+
147
+
148
+ @pytest .mark .parametrize ("n_times" , range (5 ))
149
+ def test_multiclass_input (n_times , test_data_multiclass ):
128
150
acc = Accuracy ()
129
151
130
- def _test (y_pred , y , batch_size ):
131
- acc .reset ()
132
- if batch_size > 1 :
133
- # Batched Updates
134
- n_iters = y .shape [0 ] // batch_size + 1
135
- for i in range (n_iters ):
136
- idx = i * batch_size
137
- acc .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
138
- else :
139
- acc .update ((y_pred , y ))
140
-
141
- np_y_pred = y_pred .numpy ().argmax (axis = 1 ).ravel ()
142
- np_y = y .numpy ().ravel ()
143
-
144
- assert acc ._type == "multiclass"
145
- assert isinstance (acc .compute (), float )
146
- assert accuracy_score (np_y , np_y_pred ) == pytest .approx (acc .compute ())
147
-
148
- def get_test_cases ():
149
-
150
- test_cases = [
151
- # Multiclass input data of shape (N, ) and (N, C)
152
- (torch .rand (10 , 4 ), torch .randint (0 , 4 , size = (10 ,)).long (), 1 ),
153
- (torch .rand (10 , 10 , 1 ), torch .randint (0 , 18 , size = (10 , 1 )).long (), 1 ),
154
- (torch .rand (10 , 18 ), torch .randint (0 , 18 , size = (10 ,)).long (), 1 ),
155
- (torch .rand (4 , 10 ), torch .randint (0 , 10 , size = (4 ,)).long (), 1 ),
156
- # 2-classes
157
- (torch .rand (4 , 2 ), torch .randint (0 , 2 , size = (4 ,)).long (), 1 ),
158
- (torch .rand (100 , 5 ), torch .randint (0 , 5 , size = (100 ,)).long (), 16 ),
159
- # Multiclass input data of shape (N, L) and (N, C, L)
160
- (torch .rand (10 , 4 , 5 ), torch .randint (0 , 4 , size = (10 , 5 )).long (), 1 ),
161
- (torch .rand (4 , 10 , 5 ), torch .randint (0 , 10 , size = (4 , 5 )).long (), 1 ),
162
- (torch .rand (100 , 9 , 7 ), torch .randint (0 , 9 , size = (100 , 7 )).long (), 16 ),
163
- # Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
164
- (torch .rand (4 , 5 , 12 , 10 ), torch .randint (0 , 5 , size = (4 , 12 , 10 )).long (), 1 ),
165
- (torch .rand (100 , 3 , 8 , 8 ), torch .randint (0 , 3 , size = (100 , 8 , 8 )).long (), 16 ),
166
- ]
167
- return test_cases
168
-
169
- for _ in range (5 ):
170
- # check multiple random inputs as random exact occurencies are rare
171
- test_cases = get_test_cases ()
172
- for y_pred , y , batch_size in test_cases :
173
- _test (y_pred , y , batch_size )
152
+ y_pred , y , batch_size = test_data_multiclass
153
+ acc .reset ()
154
+ if batch_size > 1 :
155
+ # Batched Updates
156
+ n_iters = y .shape [0 ] // batch_size + 1
157
+ for i in range (n_iters ):
158
+ idx = i * batch_size
159
+ acc .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
160
+ else :
161
+ acc .update ((y_pred , y ))
162
+
163
+ np_y_pred = y_pred .numpy ().argmax (axis = 1 ).ravel ()
164
+ np_y = y .numpy ().ravel ()
165
+
166
+ assert acc ._type == "multiclass"
167
+ assert isinstance (acc .compute (), float )
168
+ assert accuracy_score (np_y , np_y_pred ) == pytest .approx (acc .compute ())
174
169
175
170
176
171
def to_numpy_multilabel (y ):
@@ -202,7 +197,7 @@ def test_multilabel_wrong_inputs():
202
197
203
198
204
199
@pytest .fixture (params = [item for item in range (12 )])
205
- def test_data (request ):
200
+ def test_data_multilabel (request ):
206
201
return [
207
202
# Multilabel input data of shape (N, C) and (N, C)
208
203
(torch .randint (0 , 2 , size = (10 , 4 )).long (), torch .randint (0 , 2 , size = (10 , 4 )).long (), 1 ),
@@ -226,10 +221,10 @@ def test_data(request):
226
221
227
222
228
223
@pytest .mark .parametrize ("n_times" , range (5 ))
229
- def test_multilabel_input (n_times , test_data ):
224
+ def test_multilabel_input (n_times , test_data_multilabel ):
230
225
acc = Accuracy (is_multilabel = True )
231
226
232
- y_pred , y , batch_size = test_data
227
+ y_pred , y , batch_size = test_data_multilabel
233
228
if batch_size > 1 :
234
229
n_iters = y .shape [0 ] // batch_size + 1
235
230
for i in range (n_iters ):
0 commit comments