|
| 1 | +import warnings |
1 | 2 | from functools import partial
|
2 | 3 | from itertools import accumulate
|
3 | 4 |
|
@@ -27,6 +28,23 @@ def test_wrong_input_args():
|
27 | 28 | with pytest.raises(ValueError, match=r"Argument device should be None if src is a Metric"):
|
28 | 29 | RunningAverage(Accuracy(), device="cpu")
|
29 | 30 |
|
| 31 | + with pytest.warns(UserWarning, match=r"`epoch_bound` is deprecated and will be removed in the future."): |
| 32 | + m = RunningAverage(Accuracy(), epoch_bound=True) |
| 33 | + e = Engine(lambda _, __: None) |
| 34 | + m.attach(e, "") |
| 35 | + |
| 36 | + |
| 37 | +@pytest.mark.filterwarnings("ignore") |
| 38 | +@pytest.mark.parametrize("epoch_bound, usage", [(False, RunningBatchWise()), (True, SingleEpochRunningBatchWise())]) |
| 39 | +def test_epoch_bound(epoch_bound, usage): |
| 40 | + metric = RunningAverage(output_transform=lambda _: _, epoch_bound=epoch_bound) |
| 41 | + e1 = Engine(lambda _, __: None) |
| 42 | + e2 = Engine(lambda _, __: None) |
| 43 | + metric.attach(e1, "") |
| 44 | + metric.epoch_bound = None |
| 45 | + metric.attach(e2, "", usage) |
| 46 | + e1._event_handlers == e2._event_handlers |
| 47 | + |
30 | 48 |
|
31 | 49 | @pytest.mark.parametrize("usage", [RunningBatchWise(), SingleEpochRunningBatchWise()])
|
32 | 50 | def test_integration_batchwise(usage):
|
@@ -181,6 +199,22 @@ def check_values(engine):
|
181 | 199 | trainer.run(data)
|
182 | 200 |
|
183 | 201 |
|
| 202 | +@pytest.mark.filterwarnings("ignore") |
| 203 | +@pytest.mark.parametrize("epoch_bound", [True, False, None]) |
| 204 | +@pytest.mark.parametrize("src", [Accuracy(), None]) |
| 205 | +@pytest.mark.parametrize("usage", [RunningBatchWise(), SingleEpochRunningBatchWise(), RunningEpochWise()]) |
| 206 | +def test_detach(epoch_bound, src, usage): |
| 207 | + m = RunningAverage(src, output_transform=(lambda _: _) if src is None else None, epoch_bound=epoch_bound) |
| 208 | + e = Engine(lambda _, __: None) |
| 209 | + with warnings.catch_warnings(): |
| 210 | + m.attach(e, "m", usage) |
| 211 | + for event_handlers in e._event_handlers.values(): |
| 212 | + assert len(event_handlers) != 0 |
| 213 | + m.detach(e, usage) |
| 214 | + for event_handlers in e._event_handlers.values(): |
| 215 | + assert len(event_handlers) == 0 |
| 216 | + |
| 217 | + |
184 | 218 | def test_output_is_tensor():
|
185 | 219 | m = RunningAverage(output_transform=lambda x: x)
|
186 | 220 | m.update(torch.rand(10, requires_grad=True).mean())
|
|
0 commit comments