Skip to content

Commit 281a047

Browse files
author
nmcguire101
committed
Improved multiclass test in test_accuracy.py
1 parent 3beaf27 commit 281a047

File tree

1 file changed

+44
-49
lines changed

1 file changed

+44
-49
lines changed

tests/ignite/metrics/test_accuracy.py

Lines changed: 44 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_binary_wrong_inputs():
6262
acc.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5, 6)).long()))
6363

6464

65-
@pytest.fixture(params=[item for item in range(11)])
65+
@pytest.fixture(params=[item for item in range(12)])
6666
def test_data_binary(request):
6767
return [
6868
# Binary accuracy on input of shape (N, 1) or (N, )
@@ -124,53 +124,48 @@ def test_multiclass_wrong_inputs():
124124
acc.update((torch.rand(10), torch.randint(0, 5, size=(10, 5, 6)).long()))
125125

126126

127-
def test_multiclass_input():
127+
@pytest.fixture(params=[item for item in range(11)])
128+
def test_data_multiclass(request):
129+
return [
130+
# Multiclass input data of shape (N, ) and (N, C)
131+
(torch.rand(10, 4), torch.randint(0, 4, size=(10,)).long(), 1),
132+
(torch.rand(10, 10, 1), torch.randint(0, 18, size=(10, 1)).long(), 1),
133+
(torch.rand(10, 18), torch.randint(0, 18, size=(10,)).long(), 1),
134+
(torch.rand(4, 10), torch.randint(0, 10, size=(4,)).long(), 1),
135+
# 2-classes
136+
(torch.rand(4, 2), torch.randint(0, 2, size=(4,)).long(), 1),
137+
(torch.rand(100, 5), torch.randint(0, 5, size=(100,)).long(), 16),
138+
# Multiclass input data of shape (N, L) and (N, C, L)
139+
(torch.rand(10, 4, 5), torch.randint(0, 4, size=(10, 5)).long(), 1),
140+
(torch.rand(4, 10, 5), torch.randint(0, 10, size=(4, 5)).long(), 1),
141+
(torch.rand(100, 9, 7), torch.randint(0, 9, size=(100, 7)).long(), 16),
142+
# Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
143+
(torch.rand(4, 5, 12, 10), torch.randint(0, 5, size=(4, 12, 10)).long(), 1),
144+
(torch.rand(100, 3, 8, 8), torch.randint(0, 3, size=(100, 8, 8)).long(), 16),
145+
][request.param]
146+
147+
148+
@pytest.mark.parametrize("n_times", range(5))
149+
def test_multiclass_input(n_times, test_data_multiclass):
128150
acc = Accuracy()
129151

130-
def _test(y_pred, y, batch_size):
131-
acc.reset()
132-
if batch_size > 1:
133-
# Batched Updates
134-
n_iters = y.shape[0] // batch_size + 1
135-
for i in range(n_iters):
136-
idx = i * batch_size
137-
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
138-
else:
139-
acc.update((y_pred, y))
140-
141-
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
142-
np_y = y.numpy().ravel()
143-
144-
assert acc._type == "multiclass"
145-
assert isinstance(acc.compute(), float)
146-
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
147-
148-
def get_test_cases():
149-
150-
test_cases = [
151-
# Multiclass input data of shape (N, ) and (N, C)
152-
(torch.rand(10, 4), torch.randint(0, 4, size=(10,)).long(), 1),
153-
(torch.rand(10, 10, 1), torch.randint(0, 18, size=(10, 1)).long(), 1),
154-
(torch.rand(10, 18), torch.randint(0, 18, size=(10,)).long(), 1),
155-
(torch.rand(4, 10), torch.randint(0, 10, size=(4,)).long(), 1),
156-
# 2-classes
157-
(torch.rand(4, 2), torch.randint(0, 2, size=(4,)).long(), 1),
158-
(torch.rand(100, 5), torch.randint(0, 5, size=(100,)).long(), 16),
159-
# Multiclass input data of shape (N, L) and (N, C, L)
160-
(torch.rand(10, 4, 5), torch.randint(0, 4, size=(10, 5)).long(), 1),
161-
(torch.rand(4, 10, 5), torch.randint(0, 10, size=(4, 5)).long(), 1),
162-
(torch.rand(100, 9, 7), torch.randint(0, 9, size=(100, 7)).long(), 16),
163-
# Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
164-
(torch.rand(4, 5, 12, 10), torch.randint(0, 5, size=(4, 12, 10)).long(), 1),
165-
(torch.rand(100, 3, 8, 8), torch.randint(0, 3, size=(100, 8, 8)).long(), 16),
166-
]
167-
return test_cases
168-
169-
for _ in range(5):
170-
# check multiple random inputs as random exact occurencies are rare
171-
test_cases = get_test_cases()
172-
for y_pred, y, batch_size in test_cases:
173-
_test(y_pred, y, batch_size)
152+
y_pred, y, batch_size = test_data_multiclass
153+
acc.reset()
154+
if batch_size > 1:
155+
# Batched Updates
156+
n_iters = y.shape[0] // batch_size + 1
157+
for i in range(n_iters):
158+
idx = i * batch_size
159+
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
160+
else:
161+
acc.update((y_pred, y))
162+
163+
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
164+
np_y = y.numpy().ravel()
165+
166+
assert acc._type == "multiclass"
167+
assert isinstance(acc.compute(), float)
168+
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
174169

175170

176171
def to_numpy_multilabel(y):
@@ -202,7 +197,7 @@ def test_multilabel_wrong_inputs():
202197

203198

204199
@pytest.fixture(params=[item for item in range(12)])
205-
def test_data(request):
200+
def test_data_multilabel(request):
206201
return [
207202
# Multilabel input data of shape (N, C) and (N, C)
208203
(torch.randint(0, 2, size=(10, 4)).long(), torch.randint(0, 2, size=(10, 4)).long(), 1),
@@ -226,10 +221,10 @@ def test_data(request):
226221

227222

228223
@pytest.mark.parametrize("n_times", range(5))
229-
def test_multilabel_input(n_times, test_data):
224+
def test_multilabel_input(n_times, test_data_multilabel):
230225
acc = Accuracy(is_multilabel=True)
231226

232-
y_pred, y, batch_size = test_data
227+
y_pred, y, batch_size = test_data_multilabel
233228
if batch_size > 1:
234229
n_iters = y.shape[0] // batch_size + 1
235230
for i in range(n_iters):

0 commit comments

Comments
 (0)