Skip to content

Commit e45846f

Browse files
Parametrized tests for tests/ignite/metrics/test_accumulation (#2620)
* parametrized tests for tests/ignite/metrics/test_accumulation - test_integration * parametrized size for test_integration * Improved the update function and the assert code * Added 2 separate parameters for shape
1 parent 9fca904 commit e45846f

File tree

1 file changed

+20
-30
lines changed

1 file changed

+20
-30
lines changed

tests/ignite/metrics/test_accumulation.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ def _geom_mean(t):
9797
return np.exp(np.mean(np.log(np_t), axis=0))
9898

9999

100+
def _mean(y_true):
101+
return y_true.mean(dim=0).numpy()
102+
103+
100104
def test_geom_average():
101105

102106
with pytest.raises(NotComputableError):
@@ -129,44 +133,30 @@ def test_geom_average():
129133
np.testing.assert_almost_equal(m.numpy(), _geom_mean(y_true.reshape(-1, 10)), decimal=5)
130134

131135

132-
def test_integration():
133-
def _test(metric_cls, true_result_fn):
134-
135-
size = 100
136-
custom_variable = 10.0 + 5.0 * torch.rand(size, 12)
137-
138-
def update_fn(engine, batch):
139-
return 0, custom_variable[engine.state.iteration - 1]
140-
141-
engine = Engine(update_fn)
142-
143-
custom_var_mean = metric_cls(output_transform=lambda output: output[1])
144-
custom_var_mean.attach(engine, "agg_custom_var")
136+
@pytest.mark.parametrize("metric_cls, true_result_fn", [(Average, _mean), (GeometricAverage, _geom_mean)])
137+
@pytest.mark.parametrize("shape", [[100, 12], [100]])
138+
def test_integration(metric_cls, true_result_fn, shape):
145139

146-
state = engine.run([0] * size)
147-
np.testing.assert_almost_equal(
148-
state.metrics["agg_custom_var"].numpy(), true_result_fn(custom_variable), decimal=5
149-
)
140+
assert len(shape) > 0 and len(shape) < 3
150141

151-
size = 100
152-
custom_variable = 10.0 + 5.0 * torch.rand(size)
142+
custom_variable = 10.0 + 5.0 * torch.rand(shape)
153143

154-
def update_fn(engine, batch):
155-
return 0, custom_variable[engine.state.iteration - 1].item()
144+
def update_fn(engine, batch):
156145

157-
engine = Engine(update_fn)
146+
output = custom_variable[engine.state.iteration - 1]
147+
output = output.item() if output.ndimension() < 1 else output
148+
return 0, output
158149

159-
custom_var_mean = metric_cls(output_transform=lambda output: output[1])
160-
custom_var_mean.attach(engine, "agg_custom_var")
150+
engine = Engine(update_fn)
161151

162-
state = engine.run([0] * size)
163-
assert state.metrics["agg_custom_var"] == pytest.approx(true_result_fn(custom_variable))
152+
custom_var_mean = metric_cls(output_transform=lambda output: output[1])
153+
custom_var_mean.attach(engine, "agg_custom_var")
164154

165-
def _mean(y_true):
166-
return y_true.mean(dim=0).numpy()
155+
state = engine.run([0] * shape[0])
167156

168-
_test(Average, _mean)
169-
_test(GeometricAverage, _geom_mean)
157+
np.testing.assert_almost_equal(
158+
np.array(state.metrics["agg_custom_var"]), true_result_fn(custom_variable), decimal=5
159+
)
170160

171161

172162
def test_compute_mean_std():

0 commit comments

Comments (0)