Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 31 additions & 36 deletions tests/ignite/metrics/test_mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,45 +17,40 @@ def test_zero_sample():
mse.compute()


def test_compute():
@pytest.fixture(params=[item for item in range(4)])
def test_case(request):
return [
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 1),
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 1),
# updated batches
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 16),
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 16),
][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_compute(n_times, test_case):

mse = MeanSquaredError()

def _test(y_pred, y, batch_size):
mse.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
mse.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
mse.update((y_pred, y))

np_y = y.numpy()
np_y_pred = y_pred.numpy()

np_res = np.power((np_y - np_y_pred), 2.0).sum() / np_y.shape[0]

assert isinstance(mse.compute(), float)
assert mse.compute() == np_res

def get_test_cases():

test_cases = [
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 1),
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 1),
# updated batches
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 16),
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 16),
]

return test_cases

for _ in range(5):
# check multiple random inputs as random exact occurencies are rare
test_cases = get_test_cases()
for y_pred, y, batch_size in test_cases:
_test(y_pred, y, batch_size)
y_pred, y, batch_size = test_case

mse.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
mse.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
mse.update((y_pred, y))

np_y = y.numpy()
np_y_pred = y_pred.numpy()

np_res = np.power((np_y - np_y_pred), 2.0).sum() / np_y.shape[0]

assert isinstance(mse.compute(), float)
assert mse.compute() == np_res


def _test_distrib_integration(device, tol=1e-6):
Expand Down