Skip to content

Commit 727150e

Browse files
nmcguire101nmcguire101
andauthored
Improving various tests with parametrize (#2552)
* Added: Improved test tests\ignite\contrib\metrics\regression\test__base.py using pytest parametrize * Undid test__base.py changes * Added dummy val to test_root_mean_squared_error.py * Generates unique values in test_root_mean_squared_error.py * Removed _test from test_root_mean_squared_error.py * Removed for loop from test_root_mean_squared_error.py * Cleaned up test_root_mean_squared_error.py * Improved test_accuracy.py * Improved multiclass test in test_accuracy.py * Replaced list with range * Delete (N Co-authored-by: nmcguire101 <nmcguire101@gmail.com>
1 parent 0d40173 commit 727150e

File tree

1 file changed

+121
-139
lines changed

1 file changed

+121
-139
lines changed

tests/ignite/metrics/test_accuracy.py

Lines changed: 121 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -62,57 +62,50 @@ def test_binary_wrong_inputs():
6262
acc.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5, 6)).long()))
6363

6464

65-
def test_binary_input():
66-
65+
@pytest.fixture(params=range(12))
66+
def test_data_binary(request):
67+
return [
68+
# Binary accuracy on input of shape (N, 1) or (N, )
69+
(torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10,)).long(), 1),
70+
(torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 1)).long(), 1),
71+
# updated batches
72+
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16),
73+
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16),
74+
# Binary accuracy on input of shape (N, L)
75+
(torch.randint(0, 2, size=(10, 5)).long(), torch.randint(0, 2, size=(10, 5)).long(), 1),
76+
(torch.randint(0, 2, size=(10, 8)).long(), torch.randint(0, 2, size=(10, 8)).long(), 1),
77+
# updated batches
78+
(torch.randint(0, 2, size=(50, 5)).long(), torch.randint(0, 2, size=(50, 5)).long(), 16),
79+
(torch.randint(0, 2, size=(50, 8)).long(), torch.randint(0, 2, size=(50, 8)).long(), 16),
80+
# Binary accuracy on input of shape (N, H, W, ...)
81+
(torch.randint(0, 2, size=(4, 1, 12, 10)).long(), torch.randint(0, 2, size=(4, 1, 12, 10)).long(), 1),
82+
(torch.randint(0, 2, size=(15, 1, 20, 10)).long(), torch.randint(0, 2, size=(15, 1, 20, 10)).long(), 1),
83+
# updated batches
84+
(torch.randint(0, 2, size=(50, 1, 12, 10)).long(), torch.randint(0, 2, size=(50, 1, 12, 10)).long(), 16),
85+
(torch.randint(0, 2, size=(50, 1, 20, 10)).long(), torch.randint(0, 2, size=(50, 1, 20, 10)).long(), 16),
86+
][request.param]
87+
88+
89+
@pytest.mark.parametrize("n_times", range(5))
90+
def test_binary_input(n_times, test_data_binary):
6791
acc = Accuracy()
6892

69-
def _test(y_pred, y, batch_size):
70-
acc.reset()
71-
if batch_size > 1:
72-
n_iters = y.shape[0] // batch_size + 1
73-
for i in range(n_iters):
74-
idx = i * batch_size
75-
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
76-
else:
77-
acc.update((y_pred, y))
78-
79-
np_y = y.numpy().ravel()
80-
np_y_pred = y_pred.numpy().ravel()
81-
82-
assert acc._type == "binary"
83-
assert isinstance(acc.compute(), float)
84-
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
85-
86-
def get_test_cases():
87-
88-
test_cases = [
89-
# Binary accuracy on input of shape (N, 1) or (N, )
90-
(torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10,)).long(), 1),
91-
(torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 1)).long(), 1),
92-
# updated batches
93-
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16),
94-
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16),
95-
# Binary accuracy on input of shape (N, L)
96-
(torch.randint(0, 2, size=(10, 5)).long(), torch.randint(0, 2, size=(10, 5)).long(), 1),
97-
(torch.randint(0, 2, size=(10, 8)).long(), torch.randint(0, 2, size=(10, 8)).long(), 1),
98-
# updated batches
99-
(torch.randint(0, 2, size=(50, 5)).long(), torch.randint(0, 2, size=(50, 5)).long(), 16),
100-
(torch.randint(0, 2, size=(50, 8)).long(), torch.randint(0, 2, size=(50, 8)).long(), 16),
101-
# Binary accuracy on input of shape (N, H, W, ...)
102-
(torch.randint(0, 2, size=(4, 1, 12, 10)).long(), torch.randint(0, 2, size=(4, 1, 12, 10)).long(), 1),
103-
(torch.randint(0, 2, size=(15, 1, 20, 10)).long(), torch.randint(0, 2, size=(15, 1, 20, 10)).long(), 1),
104-
# updated batches
105-
(torch.randint(0, 2, size=(50, 1, 12, 10)).long(), torch.randint(0, 2, size=(50, 1, 12, 10)).long(), 16),
106-
(torch.randint(0, 2, size=(50, 1, 20, 10)).long(), torch.randint(0, 2, size=(50, 1, 20, 10)).long(), 16),
107-
]
108-
109-
return test_cases
110-
111-
for _ in range(5):
112-
# check multiple random inputs as random exact occurencies are rare
113-
test_cases = get_test_cases()
114-
for y_pred, y, n_iters in test_cases:
115-
_test(y_pred, y, n_iters)
93+
y_pred, y, batch_size = test_data_binary
94+
acc.reset()
95+
if batch_size > 1:
96+
n_iters = y.shape[0] // batch_size + 1
97+
for i in range(n_iters):
98+
idx = i * batch_size
99+
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
100+
else:
101+
acc.update((y_pred, y))
102+
103+
np_y = y.numpy().ravel()
104+
np_y_pred = y_pred.numpy().ravel()
105+
106+
assert acc._type == "binary"
107+
assert isinstance(acc.compute(), float)
108+
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
116109

117110

118111
def test_multiclass_wrong_inputs():
@@ -131,53 +124,48 @@ def test_multiclass_wrong_inputs():
131124
acc.update((torch.rand(10), torch.randint(0, 5, size=(10, 5, 6)).long()))
132125

133126

134-
def test_multiclass_input():
127+
@pytest.fixture(params=range(11))
128+
def test_data_multiclass(request):
129+
return [
130+
# Multiclass input data of shape (N, ) and (N, C)
131+
(torch.rand(10, 4), torch.randint(0, 4, size=(10,)).long(), 1),
132+
(torch.rand(10, 10, 1), torch.randint(0, 18, size=(10, 1)).long(), 1),
133+
(torch.rand(10, 18), torch.randint(0, 18, size=(10,)).long(), 1),
134+
(torch.rand(4, 10), torch.randint(0, 10, size=(4,)).long(), 1),
135+
# 2-classes
136+
(torch.rand(4, 2), torch.randint(0, 2, size=(4,)).long(), 1),
137+
(torch.rand(100, 5), torch.randint(0, 5, size=(100,)).long(), 16),
138+
# Multiclass input data of shape (N, L) and (N, C, L)
139+
(torch.rand(10, 4, 5), torch.randint(0, 4, size=(10, 5)).long(), 1),
140+
(torch.rand(4, 10, 5), torch.randint(0, 10, size=(4, 5)).long(), 1),
141+
(torch.rand(100, 9, 7), torch.randint(0, 9, size=(100, 7)).long(), 16),
142+
# Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
143+
(torch.rand(4, 5, 12, 10), torch.randint(0, 5, size=(4, 12, 10)).long(), 1),
144+
(torch.rand(100, 3, 8, 8), torch.randint(0, 3, size=(100, 8, 8)).long(), 16),
145+
][request.param]
146+
147+
148+
@pytest.mark.parametrize("n_times", range(5))
149+
def test_multiclass_input(n_times, test_data_multiclass):
135150
acc = Accuracy()
136151

137-
def _test(y_pred, y, batch_size):
138-
acc.reset()
139-
if batch_size > 1:
140-
# Batched Updates
141-
n_iters = y.shape[0] // batch_size + 1
142-
for i in range(n_iters):
143-
idx = i * batch_size
144-
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
145-
else:
146-
acc.update((y_pred, y))
147-
148-
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
149-
np_y = y.numpy().ravel()
150-
151-
assert acc._type == "multiclass"
152-
assert isinstance(acc.compute(), float)
153-
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
154-
155-
def get_test_cases():
156-
157-
test_cases = [
158-
# Multiclass input data of shape (N, ) and (N, C)
159-
(torch.rand(10, 4), torch.randint(0, 4, size=(10,)).long(), 1),
160-
(torch.rand(10, 10, 1), torch.randint(0, 18, size=(10, 1)).long(), 1),
161-
(torch.rand(10, 18), torch.randint(0, 18, size=(10,)).long(), 1),
162-
(torch.rand(4, 10), torch.randint(0, 10, size=(4,)).long(), 1),
163-
# 2-classes
164-
(torch.rand(4, 2), torch.randint(0, 2, size=(4,)).long(), 1),
165-
(torch.rand(100, 5), torch.randint(0, 5, size=(100,)).long(), 16),
166-
# Multiclass input data of shape (N, L) and (N, C, L)
167-
(torch.rand(10, 4, 5), torch.randint(0, 4, size=(10, 5)).long(), 1),
168-
(torch.rand(4, 10, 5), torch.randint(0, 10, size=(4, 5)).long(), 1),
169-
(torch.rand(100, 9, 7), torch.randint(0, 9, size=(100, 7)).long(), 16),
170-
# Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
171-
(torch.rand(4, 5, 12, 10), torch.randint(0, 5, size=(4, 12, 10)).long(), 1),
172-
(torch.rand(100, 3, 8, 8), torch.randint(0, 3, size=(100, 8, 8)).long(), 16),
173-
]
174-
return test_cases
175-
176-
for _ in range(5):
177-
# check multiple random inputs as random exact occurencies are rare
178-
test_cases = get_test_cases()
179-
for y_pred, y, batch_size in test_cases:
180-
_test(y_pred, y, batch_size)
152+
y_pred, y, batch_size = test_data_multiclass
153+
acc.reset()
154+
if batch_size > 1:
155+
# Batched Updates
156+
n_iters = y.shape[0] // batch_size + 1
157+
for i in range(n_iters):
158+
idx = i * batch_size
159+
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
160+
else:
161+
acc.update((y_pred, y))
162+
163+
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
164+
np_y = y.numpy().ravel()
165+
166+
assert acc._type == "multiclass"
167+
assert isinstance(acc.compute(), float)
168+
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
181169

182170

183171
def to_numpy_multilabel(y):
@@ -208,55 +196,49 @@ def test_multilabel_wrong_inputs():
208196
acc.update((torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)).long()))
209197

210198

211-
def test_multilabel_input():
199+
@pytest.fixture(params=range(12))
200+
def test_data_multilabel(request):
201+
return [
202+
# Multilabel input data of shape (N, C) and (N, C)
203+
(torch.randint(0, 2, size=(10, 4)).long(), torch.randint(0, 2, size=(10, 4)).long(), 1),
204+
(torch.randint(0, 2, size=(10, 7)).long(), torch.randint(0, 2, size=(10, 7)).long(), 1),
205+
# updated batches
206+
(torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16),
207+
(torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16),
208+
# Multilabel input data of shape (N, H, W)
209+
(torch.randint(0, 2, size=(10, 5, 10)).long(), torch.randint(0, 2, size=(10, 5, 10)).long(), 1),
210+
(torch.randint(0, 2, size=(10, 4, 10)).long(), torch.randint(0, 2, size=(10, 4, 10)).long(), 1),
211+
# updated batches
212+
(torch.randint(0, 2, size=(50, 5, 10)).long(), torch.randint(0, 2, size=(50, 5, 10)).long(), 16),
213+
(torch.randint(0, 2, size=(50, 4, 10)).long(), torch.randint(0, 2, size=(50, 4, 10)).long(), 16),
214+
# Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
215+
(torch.randint(0, 2, size=(4, 5, 12, 10)).long(), torch.randint(0, 2, size=(4, 5, 12, 10)).long(), 1),
216+
(torch.randint(0, 2, size=(4, 10, 12, 8)).long(), torch.randint(0, 2, size=(4, 10, 12, 8)).long(), 1),
217+
# updated batches
218+
(torch.randint(0, 2, size=(50, 5, 12, 10)).long(), torch.randint(0, 2, size=(50, 5, 12, 10)).long(), 16),
219+
(torch.randint(0, 2, size=(50, 10, 12, 8)).long(), torch.randint(0, 2, size=(50, 10, 12, 8)).long(), 16),
220+
][request.param]
221+
222+
223+
@pytest.mark.parametrize("n_times", range(5))
224+
def test_multilabel_input(n_times, test_data_multilabel):
212225
acc = Accuracy(is_multilabel=True)
213226

214-
def _test(y_pred, y, batch_size):
215-
acc.reset()
216-
if batch_size > 1:
217-
n_iters = y.shape[0] // batch_size + 1
218-
for i in range(n_iters):
219-
idx = i * batch_size
220-
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
221-
else:
222-
acc.update((y_pred, y))
227+
y_pred, y, batch_size = test_data_multilabel
228+
if batch_size > 1:
229+
n_iters = y.shape[0] // batch_size + 1
230+
for i in range(n_iters):
231+
idx = i * batch_size
232+
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
233+
else:
234+
acc.update((y_pred, y))
223235

224-
np_y_pred = to_numpy_multilabel(y_pred)
225-
np_y = to_numpy_multilabel(y)
236+
np_y_pred = to_numpy_multilabel(y_pred)
237+
np_y = to_numpy_multilabel(y)
226238

227-
assert acc._type == "multilabel"
228-
assert isinstance(acc.compute(), float)
229-
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
230-
231-
def get_test_cases():
232-
233-
test_cases = [
234-
# Multilabel input data of shape (N, C) and (N, C)
235-
(torch.randint(0, 2, size=(10, 4)).long(), torch.randint(0, 2, size=(10, 4)).long(), 1),
236-
(torch.randint(0, 2, size=(10, 7)).long(), torch.randint(0, 2, size=(10, 7)).long(), 1),
237-
# updated batches
238-
(torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16),
239-
(torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16),
240-
# Multilabel input data of shape (N, H, W)
241-
(torch.randint(0, 2, size=(10, 5, 10)).long(), torch.randint(0, 2, size=(10, 5, 10)).long(), 1),
242-
(torch.randint(0, 2, size=(10, 4, 10)).long(), torch.randint(0, 2, size=(10, 4, 10)).long(), 1),
243-
# updated batches
244-
(torch.randint(0, 2, size=(50, 5, 10)).long(), torch.randint(0, 2, size=(50, 5, 10)).long(), 16),
245-
(torch.randint(0, 2, size=(50, 4, 10)).long(), torch.randint(0, 2, size=(50, 4, 10)).long(), 16),
246-
# Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
247-
(torch.randint(0, 2, size=(4, 5, 12, 10)).long(), torch.randint(0, 2, size=(4, 5, 12, 10)).long(), 1),
248-
(torch.randint(0, 2, size=(4, 10, 12, 8)).long(), torch.randint(0, 2, size=(4, 10, 12, 8)).long(), 1),
249-
# updated batches
250-
(torch.randint(0, 2, size=(50, 5, 12, 10)).long(), torch.randint(0, 2, size=(50, 5, 12, 10)).long(), 16),
251-
(torch.randint(0, 2, size=(50, 10, 12, 8)).long(), torch.randint(0, 2, size=(50, 10, 12, 8)).long(), 16),
252-
]
253-
return test_cases
254-
255-
for _ in range(5):
256-
# check multiple random inputs as random exact occurencies are rare
257-
test_cases = get_test_cases()
258-
for y_pred, y, batch_size in test_cases:
259-
_test(y_pred, y, batch_size)
239+
assert acc._type == "multilabel"
240+
assert isinstance(acc.compute(), float)
241+
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
260242

261243

262244
def test_incorrect_type():

0 commit comments

Comments
 (0)