Skip to content

Commit d437aef

Browse files
Parametrized tests for test_mean_absolute_error.py (#2625)
1 parent 20c95b9 commit d437aef

File tree

1 file changed

+31
-35
lines changed

1 file changed

+31
-35
lines changed

tests/ignite/metrics/test_mean_absolute_error.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,44 +17,40 @@ def test_no_update():
1717
mae.compute()
1818

1919

20-
def test_compute():
20+
@pytest.fixture(params=[item for item in range(4)])
21+
def test_case(request):
22+
23+
return [
24+
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 1),
25+
(torch.randint(-10, 10, size=(100, 5)), torch.randint(-10, 10, size=(100, 5)), 1),
26+
# updated batches
27+
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 16),
28+
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 16),
29+
][request.param]
30+
31+
32+
@pytest.mark.parametrize("n_times", range(5))
33+
def test_compute(n_times, test_case):
2134

2235
mae = MeanAbsoluteError()
2336

24-
def _test(y_pred, y, batch_size):
25-
mae.reset()
26-
if batch_size > 1:
27-
n_iters = y.shape[0] // batch_size + 1
28-
for i in range(n_iters):
29-
idx = i * batch_size
30-
mae.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
31-
else:
32-
mae.update((y_pred, y, batch_size))
33-
34-
np_y = y.numpy()
35-
np_y_pred = y_pred.numpy()
36-
37-
np_res = (np.abs(np_y_pred - np_y)).sum() / np_y.shape[0]
38-
assert isinstance(mae.compute(), float)
39-
assert mae.compute() == np_res
40-
41-
def get_test_cases():
42-
43-
test_cases = [
44-
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 1),
45-
(torch.randint(-10, 10, size=(100, 5)), torch.randint(-10, 10, size=(100, 5)), 1),
46-
# updated batches
47-
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 16),
48-
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 16),
49-
]
50-
51-
return test_cases
52-
53-
for _ in range(5):
54-
# check multiple random inputs as random exact occurencies are rare
55-
test_cases = get_test_cases()
56-
for y_pred, y, batch_size in test_cases:
57-
_test(y_pred, y, batch_size)
37+
y_pred, y, batch_size = test_case
38+
39+
mae.reset()
40+
if batch_size > 1:
41+
n_iters = y.shape[0] // batch_size + 1
42+
for i in range(n_iters):
43+
idx = i * batch_size
44+
mae.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
45+
else:
46+
mae.update((y_pred, y, batch_size))
47+
48+
np_y = y.numpy()
49+
np_y_pred = y_pred.numpy()
50+
51+
np_res = (np.abs(np_y_pred - np_y)).sum() / np_y.shape[0]
52+
assert isinstance(mae.compute(), float)
53+
assert mae.compute() == np_res
5854

5955

6056
def _test_distrib_integration(device):

0 commit comments

Comments
 (0)