Skip to content

Commit 368b170

Browse files
Add test for detach and epoch_bound
1 parent 390dd79 commit 368b170

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

tests/ignite/metrics/test_running_average.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from functools import partial
23
from itertools import accumulate
34

@@ -27,6 +28,23 @@ def test_wrong_input_args():
2728
with pytest.raises(ValueError, match=r"Argument device should be None if src is a Metric"):
2829
RunningAverage(Accuracy(), device="cpu")
2930

31+
with pytest.warns(UserWarning, match=r"`epoch_bound` is deprecated and will be removed in the future."):
32+
m = RunningAverage(Accuracy(), epoch_bound=True)
33+
e = Engine(lambda _, __: None)
34+
m.attach(e, "")
35+
36+
37+
@pytest.mark.filterwarnings("ignore")
38+
@pytest.mark.parametrize("epoch_bound, usage", [(False, RunningBatchWise()), (True, SingleEpochRunningBatchWise())])
39+
def test_epoch_bound(epoch_bound, usage):
40+
metric = RunningAverage(output_transform=lambda _: _, epoch_bound=epoch_bound)
41+
e1 = Engine(lambda _, __: None)
42+
e2 = Engine(lambda _, __: None)
43+
metric.attach(e1, "")
44+
metric.epoch_bound = None
45+
metric.attach(e2, "", usage)
46+
e1._event_handlers == e2._event_handlers
47+
3048

3149
@pytest.mark.parametrize("usage", [RunningBatchWise(), SingleEpochRunningBatchWise()])
3250
def test_integration_batchwise(usage):
@@ -181,6 +199,22 @@ def check_values(engine):
181199
trainer.run(data)
182200

183201

202+
@pytest.mark.filterwarnings("ignore")
203+
@pytest.mark.parametrize("epoch_bound", [True, False, None])
204+
@pytest.mark.parametrize("src", [Accuracy(), None])
205+
@pytest.mark.parametrize("usage", [RunningBatchWise(), SingleEpochRunningBatchWise(), RunningEpochWise()])
206+
def test_detach(epoch_bound, src, usage):
207+
m = RunningAverage(src, output_transform=(lambda _: _) if src is None else None, epoch_bound=epoch_bound)
208+
e = Engine(lambda _, __: None)
209+
with warnings.catch_warnings():
210+
m.attach(e, "m", usage)
211+
for event_handlers in e._event_handlers.values():
212+
assert len(event_handlers) != 0
213+
m.detach(e, usage)
214+
for event_handlers in e._event_handlers.values():
215+
assert len(event_handlers) == 0
216+
217+
184218
def test_output_is_tensor():
185219
m = RunningAverage(output_transform=lambda x: x)
186220
m.update(torch.rand(10, requires_grad=True).mean())

0 commit comments

Comments
 (0)