Skip to content

Commit 81705c4

Browse files
Parametrized tests for test_mean_pairwise_distance.py (#2628)
1 parent d437aef commit 81705c4

File tree

1 file changed

+29
-33
lines changed

1 file changed

+29
-33
lines changed

tests/ignite/metrics/test_mean_pairwise_distance.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,42 +17,38 @@ def test_zero_sample():
1717
mpd.compute()
1818

1919

20-
def test_compute():
20+
@pytest.fixture(params=[item for item in range(4)])
21+
def test_case(request):
22+
23+
return [
24+
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 1),
25+
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 1),
26+
# updated batches
27+
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 16),
28+
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 16),
29+
][request.param]
30+
31+
32+
@pytest.mark.parametrize("n_times", range(5))
33+
def test_compute(n_times, test_case):
2134

2235
mpd = MeanPairwiseDistance()
2336

24-
def _test(y_pred, y, batch_size):
25-
mpd.reset()
26-
if batch_size > 1:
27-
n_iters = y.shape[0] // batch_size + 1
28-
for i in range(n_iters):
29-
idx = i * batch_size
30-
mpd.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
31-
else:
32-
mpd.update((y_pred, y))
33-
34-
np_res = np.mean(torch.pairwise_distance(y_pred, y, p=mpd._p, eps=mpd._eps).numpy())
35-
36-
assert isinstance(mpd.compute(), float)
37-
assert pytest.approx(mpd.compute()) == np_res
38-
39-
def get_test_cases():
40-
41-
test_cases = [
42-
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 1),
43-
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 1),
44-
# updated batches
45-
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 16),
46-
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 16),
47-
]
48-
49-
return test_cases
50-
51-
for _ in range(5):
52-
# check multiple random inputs as random exact occurencies are rare
53-
test_cases = get_test_cases()
54-
for y_pred, y, batch_size in test_cases:
55-
_test(y_pred, y, batch_size)
37+
y_pred, y, batch_size = test_case
38+
39+
mpd.reset()
40+
if batch_size > 1:
41+
n_iters = y.shape[0] // batch_size + 1
42+
for i in range(n_iters):
43+
idx = i * batch_size
44+
mpd.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
45+
else:
46+
mpd.update((y_pred, y))
47+
48+
np_res = np.mean(torch.pairwise_distance(y_pred, y, p=mpd._p, eps=mpd._eps).numpy())
49+
50+
assert isinstance(mpd.compute(), float)
51+
assert pytest.approx(mpd.compute()) == np_res
5652

5753

5854
def _test_distrib_integration(device):

0 commit comments

Comments
 (0)