Skip to content

Commit 3beaf27

Browse files
author
nmcguire101
committed
Improved test_accuracy.py
1 parent 1321339 commit 3beaf27

File tree

2 files changed

+79
-92
lines changed

2 files changed

+79
-92
lines changed

(N

Whitespace-only changes.

tests/ignite/metrics/test_accuracy.py

Lines changed: 79 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -62,57 +62,50 @@ def test_binary_wrong_inputs():
6262
acc.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5, 6)).long()))
6363

6464

65-
def test_binary_input():
66-
65+
@pytest.fixture(params=[item for item in range(11)])
66+
def test_data_binary(request):
67+
return [
68+
# Binary accuracy on input of shape (N, 1) or (N, )
69+
(torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10,)).long(), 1),
70+
(torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 1)).long(), 1),
71+
# updated batches
72+
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16),
73+
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16),
74+
# Binary accuracy on input of shape (N, L)
75+
(torch.randint(0, 2, size=(10, 5)).long(), torch.randint(0, 2, size=(10, 5)).long(), 1),
76+
(torch.randint(0, 2, size=(10, 8)).long(), torch.randint(0, 2, size=(10, 8)).long(), 1),
77+
# updated batches
78+
(torch.randint(0, 2, size=(50, 5)).long(), torch.randint(0, 2, size=(50, 5)).long(), 16),
79+
(torch.randint(0, 2, size=(50, 8)).long(), torch.randint(0, 2, size=(50, 8)).long(), 16),
80+
# Binary accuracy on input of shape (N, H, W, ...)
81+
(torch.randint(0, 2, size=(4, 1, 12, 10)).long(), torch.randint(0, 2, size=(4, 1, 12, 10)).long(), 1),
82+
(torch.randint(0, 2, size=(15, 1, 20, 10)).long(), torch.randint(0, 2, size=(15, 1, 20, 10)).long(), 1),
83+
# updated batches
84+
(torch.randint(0, 2, size=(50, 1, 12, 10)).long(), torch.randint(0, 2, size=(50, 1, 12, 10)).long(), 16),
85+
(torch.randint(0, 2, size=(50, 1, 20, 10)).long(), torch.randint(0, 2, size=(50, 1, 20, 10)).long(), 16),
86+
][request.param]
87+
88+
89+
@pytest.mark.parametrize("n_times", range(5))
90+
def test_binary_input(n_times, test_data_binary):
6791
acc = Accuracy()
6892

69-
def _test(y_pred, y, batch_size):
70-
acc.reset()
71-
if batch_size > 1:
72-
n_iters = y.shape[0] // batch_size + 1
73-
for i in range(n_iters):
74-
idx = i * batch_size
75-
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
76-
else:
77-
acc.update((y_pred, y))
78-
79-
np_y = y.numpy().ravel()
80-
np_y_pred = y_pred.numpy().ravel()
81-
82-
assert acc._type == "binary"
83-
assert isinstance(acc.compute(), float)
84-
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
85-
86-
def get_test_cases():
87-
88-
test_cases = [
89-
# Binary accuracy on input of shape (N, 1) or (N, )
90-
(torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10,)).long(), 1),
91-
(torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 1)).long(), 1),
92-
# updated batches
93-
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16),
94-
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16),
95-
# Binary accuracy on input of shape (N, L)
96-
(torch.randint(0, 2, size=(10, 5)).long(), torch.randint(0, 2, size=(10, 5)).long(), 1),
97-
(torch.randint(0, 2, size=(10, 8)).long(), torch.randint(0, 2, size=(10, 8)).long(), 1),
98-
# updated batches
99-
(torch.randint(0, 2, size=(50, 5)).long(), torch.randint(0, 2, size=(50, 5)).long(), 16),
100-
(torch.randint(0, 2, size=(50, 8)).long(), torch.randint(0, 2, size=(50, 8)).long(), 16),
101-
# Binary accuracy on input of shape (N, H, W, ...)
102-
(torch.randint(0, 2, size=(4, 1, 12, 10)).long(), torch.randint(0, 2, size=(4, 1, 12, 10)).long(), 1),
103-
(torch.randint(0, 2, size=(15, 1, 20, 10)).long(), torch.randint(0, 2, size=(15, 1, 20, 10)).long(), 1),
104-
# updated batches
105-
(torch.randint(0, 2, size=(50, 1, 12, 10)).long(), torch.randint(0, 2, size=(50, 1, 12, 10)).long(), 16),
106-
(torch.randint(0, 2, size=(50, 1, 20, 10)).long(), torch.randint(0, 2, size=(50, 1, 20, 10)).long(), 16),
107-
]
93+
y_pred, y, batch_size = test_data_binary
94+
acc.reset()
95+
if batch_size > 1:
96+
n_iters = y.shape[0] // batch_size + 1
97+
for i in range(n_iters):
98+
idx = i * batch_size
99+
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
100+
else:
101+
acc.update((y_pred, y))
108102

109-
return test_cases
103+
np_y = y.numpy().ravel()
104+
np_y_pred = y_pred.numpy().ravel()
110105

111-
for _ in range(5):
112-
# check multiple random inputs as random exact occurencies are rare
113-
test_cases = get_test_cases()
114-
for y_pred, y, n_iters in test_cases:
115-
_test(y_pred, y, n_iters)
106+
assert acc._type == "binary"
107+
assert isinstance(acc.compute(), float)
108+
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
116109

117110

118111
def test_multiclass_wrong_inputs():
@@ -208,55 +201,49 @@ def test_multilabel_wrong_inputs():
208201
acc.update((torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)).long()))
209202

210203

211-
def test_multilabel_input():
204+
@pytest.fixture(params=[item for item in range(12)])
205+
def test_data(request):
206+
return [
207+
# Multilabel input data of shape (N, C) and (N, C)
208+
(torch.randint(0, 2, size=(10, 4)).long(), torch.randint(0, 2, size=(10, 4)).long(), 1),
209+
(torch.randint(0, 2, size=(10, 7)).long(), torch.randint(0, 2, size=(10, 7)).long(), 1),
210+
# updated batches
211+
(torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16),
212+
(torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16),
213+
# Multilabel input data of shape (N, H, W)
214+
(torch.randint(0, 2, size=(10, 5, 10)).long(), torch.randint(0, 2, size=(10, 5, 10)).long(), 1),
215+
(torch.randint(0, 2, size=(10, 4, 10)).long(), torch.randint(0, 2, size=(10, 4, 10)).long(), 1),
216+
# updated batches
217+
(torch.randint(0, 2, size=(50, 5, 10)).long(), torch.randint(0, 2, size=(50, 5, 10)).long(), 16),
218+
(torch.randint(0, 2, size=(50, 4, 10)).long(), torch.randint(0, 2, size=(50, 4, 10)).long(), 16),
219+
# Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
220+
(torch.randint(0, 2, size=(4, 5, 12, 10)).long(), torch.randint(0, 2, size=(4, 5, 12, 10)).long(), 1),
221+
(torch.randint(0, 2, size=(4, 10, 12, 8)).long(), torch.randint(0, 2, size=(4, 10, 12, 8)).long(), 1),
222+
# updated batches
223+
(torch.randint(0, 2, size=(50, 5, 12, 10)).long(), torch.randint(0, 2, size=(50, 5, 12, 10)).long(), 16),
224+
(torch.randint(0, 2, size=(50, 10, 12, 8)).long(), torch.randint(0, 2, size=(50, 10, 12, 8)).long(), 16),
225+
][request.param]
226+
227+
228+
@pytest.mark.parametrize("n_times", range(5))
229+
def test_multilabel_input(n_times, test_data):
212230
acc = Accuracy(is_multilabel=True)
213231

214-
def _test(y_pred, y, batch_size):
215-
acc.reset()
216-
if batch_size > 1:
217-
n_iters = y.shape[0] // batch_size + 1
218-
for i in range(n_iters):
219-
idx = i * batch_size
220-
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
221-
else:
222-
acc.update((y_pred, y))
223-
224-
np_y_pred = to_numpy_multilabel(y_pred)
225-
np_y = to_numpy_multilabel(y)
226-
227-
assert acc._type == "multilabel"
228-
assert isinstance(acc.compute(), float)
229-
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
230-
231-
def get_test_cases():
232+
y_pred, y, batch_size = test_data
233+
if batch_size > 1:
234+
n_iters = y.shape[0] // batch_size + 1
235+
for i in range(n_iters):
236+
idx = i * batch_size
237+
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
238+
else:
239+
acc.update((y_pred, y))
232240

233-
test_cases = [
234-
# Multilabel input data of shape (N, C) and (N, C)
235-
(torch.randint(0, 2, size=(10, 4)).long(), torch.randint(0, 2, size=(10, 4)).long(), 1),
236-
(torch.randint(0, 2, size=(10, 7)).long(), torch.randint(0, 2, size=(10, 7)).long(), 1),
237-
# updated batches
238-
(torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16),
239-
(torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16),
240-
# Multilabel input data of shape (N, H, W)
241-
(torch.randint(0, 2, size=(10, 5, 10)).long(), torch.randint(0, 2, size=(10, 5, 10)).long(), 1),
242-
(torch.randint(0, 2, size=(10, 4, 10)).long(), torch.randint(0, 2, size=(10, 4, 10)).long(), 1),
243-
# updated batches
244-
(torch.randint(0, 2, size=(50, 5, 10)).long(), torch.randint(0, 2, size=(50, 5, 10)).long(), 16),
245-
(torch.randint(0, 2, size=(50, 4, 10)).long(), torch.randint(0, 2, size=(50, 4, 10)).long(), 16),
246-
# Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
247-
(torch.randint(0, 2, size=(4, 5, 12, 10)).long(), torch.randint(0, 2, size=(4, 5, 12, 10)).long(), 1),
248-
(torch.randint(0, 2, size=(4, 10, 12, 8)).long(), torch.randint(0, 2, size=(4, 10, 12, 8)).long(), 1),
249-
# updated batches
250-
(torch.randint(0, 2, size=(50, 5, 12, 10)).long(), torch.randint(0, 2, size=(50, 5, 12, 10)).long(), 16),
251-
(torch.randint(0, 2, size=(50, 10, 12, 8)).long(), torch.randint(0, 2, size=(50, 10, 12, 8)).long(), 16),
252-
]
253-
return test_cases
241+
np_y_pred = to_numpy_multilabel(y_pred)
242+
np_y = to_numpy_multilabel(y)
254243

255-
for _ in range(5):
256-
# check multiple random inputs as random exact occurencies are rare
257-
test_cases = get_test_cases()
258-
for y_pred, y, batch_size in test_cases:
259-
_test(y_pred, y, batch_size)
244+
assert acc._type == "multilabel"
245+
assert isinstance(acc.compute(), float)
246+
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
260247

261248

262249
def test_incorrect_type():

0 commit comments

Comments
 (0)