Skip to content

Commit 15bb7b6

Browse files
Parametrized tests for tests/ignite/contrib/hadlers/test_abse_logger.py (#2617)
1 parent a944081 commit 15bb7b6

File tree

1 file changed

+63
-61
lines changed

1 file changed

+63
-61
lines changed

tests/ignite/contrib/handlers/test_base_logger.py

Lines changed: 63 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -169,52 +169,52 @@ def test_opt_params_handler_on_non_torch_optimizers():
169169
assert "lr/group_0" in res and res["lr/group_0"] == 0.1234
170170

171171

172-
def test_attach():
172+
@pytest.mark.parametrize(
173+
"event, n_calls, kwargs",
174+
[
175+
(Events.ITERATION_STARTED, 50 * 5, {"a": 0}),
176+
(Events.ITERATION_COMPLETED, 50 * 5, {}),
177+
(Events.EPOCH_STARTED, 5, {}),
178+
(Events.EPOCH_COMPLETED, 5, {}),
179+
(Events.STARTED, 1, {}),
180+
(Events.COMPLETED, 1, {}),
181+
(Events.ITERATION_STARTED(every=10), 50 // 10 * 5, {}),
182+
(Events.STARTED | Events.COMPLETED, 2, {}),
183+
],
184+
)
185+
def test_attach(event, n_calls, kwargs):
173186

174187
n_epochs = 5
175188
data = list(range(50))
176189

177-
def _test(event, n_calls, kwargs={}):
178-
179-
losses = torch.rand(n_epochs * len(data))
180-
losses_iter = iter(losses)
190+
losses = torch.rand(n_epochs * len(data))
191+
losses_iter = iter(losses)
181192

182-
def update_fn(engine, batch):
183-
return next(losses_iter)
193+
def update_fn(engine, batch):
194+
return next(losses_iter)
184195

185-
trainer = Engine(update_fn)
196+
trainer = Engine(update_fn)
186197

187-
logger = DummyLogger()
188-
189-
mock_log_handler = MagicMock()
190-
191-
logger.attach(trainer, log_handler=mock_log_handler, event_name=event, **kwargs)
198+
logger = DummyLogger()
192199

193-
trainer.run(data, max_epochs=n_epochs)
200+
mock_log_handler = MagicMock()
194201

195-
if isinstance(event, EventsList):
196-
events = [e for e in event]
197-
else:
198-
events = [event]
202+
logger.attach(trainer, log_handler=mock_log_handler, event_name=event, **kwargs)
199203

200-
if len(kwargs) > 0:
201-
calls = [call(trainer, logger, e, **kwargs) for e in events]
202-
else:
203-
calls = [call(trainer, logger, e) for e in events]
204+
trainer.run(data, max_epochs=n_epochs)
204205

205-
mock_log_handler.assert_has_calls(calls)
206-
assert mock_log_handler.call_count == n_calls
206+
if isinstance(event, EventsList):
207+
events = [e for e in event]
208+
else:
209+
events = [event]
207210

208-
_test(Events.ITERATION_STARTED, len(data) * n_epochs, kwargs={"a": 0})
209-
_test(Events.ITERATION_COMPLETED, len(data) * n_epochs)
210-
_test(Events.EPOCH_STARTED, n_epochs)
211-
_test(Events.EPOCH_COMPLETED, n_epochs)
212-
_test(Events.STARTED, 1)
213-
_test(Events.COMPLETED, 1)
211+
if len(kwargs) > 0:
212+
calls = [call(trainer, logger, e, **kwargs) for e in events]
213+
else:
214+
calls = [call(trainer, logger, e) for e in events]
214215

215-
_test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs)
216-
217-
_test(Events.STARTED | Events.COMPLETED, 2)
216+
mock_log_handler.assert_has_calls(calls)
217+
assert mock_log_handler.call_count == n_calls
218218

219219

220220
def test_attach_wrong_event_name():
@@ -260,7 +260,19 @@ def update_fn(engine, batch):
260260
assert mock_log_handler.call_count == n_calls
261261

262262

263-
def test_as_context_manager():
263+
@pytest.mark.parametrize(
264+
"event, n_calls",
265+
[
266+
(Events.ITERATION_STARTED, 50 * 5),
267+
(Events.ITERATION_COMPLETED, 50 * 5),
268+
(Events.EPOCH_STARTED, 5),
269+
(Events.EPOCH_COMPLETED, 5),
270+
(Events.STARTED, 1),
271+
(Events.COMPLETED, 1),
272+
(Events.ITERATION_STARTED(every=10), 50 // 10 * 5),
273+
],
274+
)
275+
def test_as_context_manager(event, n_calls):
264276

265277
n_epochs = 5
266278
data = list(range(50))
@@ -272,42 +284,32 @@ def __init__(self, writer):
272284
def close(self):
273285
self.writer.close()
274286

275-
def _test(event, n_calls):
276-
global close_counter
277-
close_counter = 0
278-
279-
losses = torch.rand(n_epochs * len(data))
280-
losses_iter = iter(losses)
281-
282-
def update_fn(engine, batch):
283-
return next(losses_iter)
287+
global close_counter
288+
close_counter = 0
284289

285-
writer = MagicMock()
286-
writer.close = MagicMock()
290+
losses = torch.rand(n_epochs * len(data))
291+
losses_iter = iter(losses)
287292

288-
with _DummyLogger(writer) as logger:
289-
assert isinstance(logger, _DummyLogger)
293+
def update_fn(engine, batch):
294+
return next(losses_iter)
290295

291-
trainer = Engine(update_fn)
292-
mock_log_handler = MagicMock()
296+
writer = MagicMock()
297+
writer.close = MagicMock()
293298

294-
logger.attach(trainer, log_handler=mock_log_handler, event_name=event)
299+
with _DummyLogger(writer) as logger:
300+
assert isinstance(logger, _DummyLogger)
295301

296-
trainer.run(data, max_epochs=n_epochs)
302+
trainer = Engine(update_fn)
303+
mock_log_handler = MagicMock()
297304

298-
mock_log_handler.assert_called_with(trainer, logger, event)
299-
assert mock_log_handler.call_count == n_calls
305+
logger.attach(trainer, log_handler=mock_log_handler, event_name=event)
300306

301-
writer.close.assert_called_once_with()
307+
trainer.run(data, max_epochs=n_epochs)
302308

303-
_test(Events.ITERATION_STARTED, len(data) * n_epochs)
304-
_test(Events.ITERATION_COMPLETED, len(data) * n_epochs)
305-
_test(Events.EPOCH_STARTED, n_epochs)
306-
_test(Events.EPOCH_COMPLETED, n_epochs)
307-
_test(Events.STARTED, 1)
308-
_test(Events.COMPLETED, 1)
309+
mock_log_handler.assert_called_with(trainer, logger, event)
310+
assert mock_log_handler.call_count == n_calls
309311

310-
_test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs)
312+
writer.close.assert_called_once_with()
311313

312314

313315
def test_base_weights_handler_wrong_setup():

0 commit comments

Comments
 (0)