Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 63 additions & 61 deletions tests/ignite/contrib/handlers/test_base_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,52 +169,52 @@ def test_opt_params_handler_on_non_torch_optimizers():
assert "lr/group_0" in res and res["lr/group_0"] == 0.1234


def test_attach():
@pytest.mark.parametrize(
"event, n_calls, kwargs",
[
(Events.ITERATION_STARTED, 50 * 5, {"a": 0}),
(Events.ITERATION_COMPLETED, 50 * 5, {}),
(Events.EPOCH_STARTED, 5, {}),
(Events.EPOCH_COMPLETED, 5, {}),
(Events.STARTED, 1, {}),
(Events.COMPLETED, 1, {}),
(Events.ITERATION_STARTED(every=10), 50 // 10 * 5, {}),
(Events.STARTED | Events.COMPLETED, 2, {}),
],
)
def test_attach(event, n_calls, kwargs):

n_epochs = 5
data = list(range(50))

def _test(event, n_calls, kwargs={}):

losses = torch.rand(n_epochs * len(data))
losses_iter = iter(losses)
losses = torch.rand(n_epochs * len(data))
losses_iter = iter(losses)

def update_fn(engine, batch):
return next(losses_iter)
def update_fn(engine, batch):
return next(losses_iter)

trainer = Engine(update_fn)
trainer = Engine(update_fn)

logger = DummyLogger()

mock_log_handler = MagicMock()

logger.attach(trainer, log_handler=mock_log_handler, event_name=event, **kwargs)
logger = DummyLogger()

trainer.run(data, max_epochs=n_epochs)
mock_log_handler = MagicMock()

if isinstance(event, EventsList):
events = [e for e in event]
else:
events = [event]
logger.attach(trainer, log_handler=mock_log_handler, event_name=event, **kwargs)

if len(kwargs) > 0:
calls = [call(trainer, logger, e, **kwargs) for e in events]
else:
calls = [call(trainer, logger, e) for e in events]
trainer.run(data, max_epochs=n_epochs)

mock_log_handler.assert_has_calls(calls)
assert mock_log_handler.call_count == n_calls
if isinstance(event, EventsList):
events = [e for e in event]
else:
events = [event]

_test(Events.ITERATION_STARTED, len(data) * n_epochs, kwargs={"a": 0})
_test(Events.ITERATION_COMPLETED, len(data) * n_epochs)
_test(Events.EPOCH_STARTED, n_epochs)
_test(Events.EPOCH_COMPLETED, n_epochs)
_test(Events.STARTED, 1)
_test(Events.COMPLETED, 1)
if len(kwargs) > 0:
calls = [call(trainer, logger, e, **kwargs) for e in events]
else:
calls = [call(trainer, logger, e) for e in events]

_test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs)

_test(Events.STARTED | Events.COMPLETED, 2)
mock_log_handler.assert_has_calls(calls)
assert mock_log_handler.call_count == n_calls


def test_attach_wrong_event_name():
Expand Down Expand Up @@ -260,7 +260,19 @@ def update_fn(engine, batch):
assert mock_log_handler.call_count == n_calls


def test_as_context_manager():
@pytest.mark.parametrize(
"event, n_calls",
[
(Events.ITERATION_STARTED, 50 * 5),
(Events.ITERATION_COMPLETED, 50 * 5),
(Events.EPOCH_STARTED, 5),
(Events.EPOCH_COMPLETED, 5),
(Events.STARTED, 1),
(Events.COMPLETED, 1),
(Events.ITERATION_STARTED(every=10), 50 // 10 * 5),
],
)
def test_as_context_manager(event, n_calls):

n_epochs = 5
data = list(range(50))
Expand All @@ -272,42 +284,32 @@ def __init__(self, writer):
def close(self):
self.writer.close()

def _test(event, n_calls):
global close_counter
close_counter = 0

losses = torch.rand(n_epochs * len(data))
losses_iter = iter(losses)

def update_fn(engine, batch):
return next(losses_iter)
global close_counter
close_counter = 0

writer = MagicMock()
writer.close = MagicMock()
losses = torch.rand(n_epochs * len(data))
losses_iter = iter(losses)

with _DummyLogger(writer) as logger:
assert isinstance(logger, _DummyLogger)
def update_fn(engine, batch):
return next(losses_iter)

trainer = Engine(update_fn)
mock_log_handler = MagicMock()
writer = MagicMock()
writer.close = MagicMock()

logger.attach(trainer, log_handler=mock_log_handler, event_name=event)
with _DummyLogger(writer) as logger:
assert isinstance(logger, _DummyLogger)

trainer.run(data, max_epochs=n_epochs)
trainer = Engine(update_fn)
mock_log_handler = MagicMock()

mock_log_handler.assert_called_with(trainer, logger, event)
assert mock_log_handler.call_count == n_calls
logger.attach(trainer, log_handler=mock_log_handler, event_name=event)

writer.close.assert_called_once_with()
trainer.run(data, max_epochs=n_epochs)

_test(Events.ITERATION_STARTED, len(data) * n_epochs)
_test(Events.ITERATION_COMPLETED, len(data) * n_epochs)
_test(Events.EPOCH_STARTED, n_epochs)
_test(Events.EPOCH_COMPLETED, n_epochs)
_test(Events.STARTED, 1)
_test(Events.COMPLETED, 1)
mock_log_handler.assert_called_with(trainer, logger, event)
assert mock_log_handler.call_count == n_calls

_test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs)
writer.close.assert_called_once_with()


def test_base_weights_handler_wrong_setup():
Expand Down