Skip to content

Commit 6f8ad2a

Browse files
Give the option to terminate the engine without firing Events.COMPLET… (#3309)
* Give the option to terminate the engine without firing Events.COMPLETED. The default behaviour is not changed. Note that even though Events.COMPLETED is not fired, its timer is updated. * Update ignite/engine/engine.py Co-authored-by: vfdev <vfdev.5@gmail.com> * Update ignite/engine/engine.py Co-authored-by: vfdev <vfdev.5@gmail.com> * Update ignite/engine/engine.py Co-authored-by: vfdev <vfdev.5@gmail.com> * Update ignite/engine/engine.py Co-authored-by: vfdev <vfdev.5@gmail.com> * Update ignite/engine/events.py Co-authored-by: vfdev <vfdev.5@gmail.com> * Argument `skip_event_completed` renamed to `skip_completed` * - Fixed docs broken links. - Do not update self.state.times[Events.COMPLETED.name] if terminated - Fixed unit test * Update ignite/engine/engine.py Co-authored-by: vfdev <vfdev.5@gmail.com> * Refactoring and patching. - Engine time logging moved out of the if clause. In the log message "completed" has been replaced with "finished" to avoid confusion. - Same changes applied to the method `_internal_run_legacy()` * Restored .gitignore Sorry for accidentally including it into the previous commit! * Update ignite/engine/events.py * Fixed typo in test_engine.py * Parametrized test for engine.terminate(skip_completed) * Update event table * Fixed documentation --------- Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 4f46210 commit 6f8ad2a

File tree

4 files changed

+76
-36
lines changed

4 files changed

+76
-36
lines changed

ignite/engine/engine.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
140140
self._process_function = process_function
141141
self.last_event_name: Optional[Events] = None
142142
self.should_terminate = False
143+
self.skip_completed_after_termination = False
143144
self.should_terminate_single_epoch = False
144145
self.should_interrupt = False
145146
self.state = State()
@@ -538,7 +539,7 @@ def call_interrupt():
538539
self.logger.info("interrupt signaled. Engine will interrupt the run after current iteration is finished.")
539540
self.should_interrupt = True
540541

541-
def terminate(self) -> None:
542+
def terminate(self, skip_completed: bool = False) -> None:
542543
"""Sends terminate signal to the engine, so that it terminates completely the run. The run is
543544
terminated after the event on which ``terminate`` method was called. The following events are triggered:
544545
@@ -547,6 +548,9 @@ def terminate(self) -> None:
547548
- :attr:`~ignite.engine.events.Events.TERMINATE`
548549
- :attr:`~ignite.engine.events.Events.COMPLETED`
549550
551+
Args:
552+
skip_completed: if True, the event :attr:`~ignite.engine.events.Events.COMPLETED` is not fired after
553+
:attr:`~ignite.engine.events.Events.TERMINATE`. Default is False.
550554
551555
Examples:
552556
.. testcode::
@@ -617,9 +621,12 @@ def terminate():
617621
.. versionchanged:: 0.4.10
618622
Behaviour changed, for details see https://github.com/pytorch/ignite/issues/2669
619623
624+
.. versionchanged:: 0.5.2
625+
Added `skip_completed` flag
620626
"""
621627
self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
622628
self.should_terminate = True
629+
self.skip_completed_after_termination = skip_completed
623630

624631
def terminate_epoch(self) -> None:
625632
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
@@ -993,13 +1000,17 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
9931000
time_taken = time.time() - start_time
9941001
# time is available for handlers but must be updated after fire
9951002
self.state.times[Events.COMPLETED.name] = time_taken
996-
handlers_start_time = time.time()
997-
self._fire_event(Events.COMPLETED)
998-
time_taken += time.time() - handlers_start_time
999-
# update time wrt handlers
1000-
self.state.times[Events.COMPLETED.name] = time_taken
1003+
1004+
# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
1005+
if not (self.should_terminate and self.skip_completed_after_termination):
1006+
handlers_start_time = time.time()
1007+
self._fire_event(Events.COMPLETED)
1008+
time_taken += time.time() - handlers_start_time
1009+
# update time wrt handlers
1010+
self.state.times[Events.COMPLETED.name] = time_taken
1011+
10011012
hours, mins, secs = _to_hours_mins_secs(time_taken)
1002-
self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
1013+
self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
10031014

10041015
except BaseException as e:
10051016
self._dataloader_iter = None
@@ -1174,13 +1185,17 @@ def _internal_run_legacy(self) -> State:
11741185
time_taken = time.time() - start_time
11751186
# time is available for handlers but must be updated after fire
11761187
self.state.times[Events.COMPLETED.name] = time_taken
1177-
handlers_start_time = time.time()
1178-
self._fire_event(Events.COMPLETED)
1179-
time_taken += time.time() - handlers_start_time
1180-
# update time wrt handlers
1181-
self.state.times[Events.COMPLETED.name] = time_taken
1188+
1189+
# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
1190+
if not (self.should_terminate and self.skip_completed_after_termination):
1191+
handlers_start_time = time.time()
1192+
self._fire_event(Events.COMPLETED)
1193+
time_taken += time.time() - handlers_start_time
1194+
# update time wrt handlers
1195+
self.state.times[Events.COMPLETED.name] = time_taken
1196+
11821197
hours, mins, secs = _to_hours_mins_secs(time_taken)
1183-
self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
1198+
self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
11841199

11851200
except BaseException as e:
11861201
self._dataloader_iter = None

ignite/engine/events.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -259,36 +259,47 @@ class Events(EventEnum):
259259
- TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch,
260260
after receiving a :meth:`~ignite.engine.engine.Engine.terminate_epoch()` or
261261
:meth:`~ignite.engine.engine.Engine.terminate()` call.
262+
- EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
263+
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
262264
263265
- TERMINATE : triggered when the run is about to end completely,
264266
after receiving :meth:`~ignite.engine.engine.Engine.terminate()` call.
265267
266-
- EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
267-
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
268-
- COMPLETED : triggered when engine's run is completed
268+
- COMPLETED : triggered when engine's run is completed or terminated with
269+
:meth:`~ignite.engine.engine.Engine.terminate()`, unless the flag
270+
`skip_completed` is set to True.
269271
270272
The table below illustrates which events are triggered when various termination methods are called.
271273
272274
.. list-table::
273-
:widths: 24 25 33 18
275+
:widths: 35 38 28 20 20
274276
:header-rows: 1
275277
276278
* - Method
277-
- EVENT_COMPLETED
278279
- TERMINATE_SINGLE_EPOCH
280+
- EPOCH_COMPLETED
279281
- TERMINATE
282+
- COMPLETED
280283
* - no termination
281-
- ✔
282284
- ✗
285+
- ✔
283286
- ✗
287+
- ✔
284288
* - :meth:`~ignite.engine.engine.Engine.terminate_epoch()`
285289
- ✔
286290
- ✔
287291
- ✗
292+
- ✔
288293
* - :meth:`~ignite.engine.engine.Engine.terminate()`
289294
- ✗
290295
- ✔
291296
- ✔
297+
- ✔
298+
* - :meth:`~ignite.engine.engine.Engine.terminate()` with `skip_completed=True`
299+
- ✗
300+
- ✔
301+
- ✔
302+
- ✗
292303
293304
Since v0.3.0, Events become more flexible and allow to pass an event filter to the Engine:
294305
@@ -357,7 +368,7 @@ class CustomEvents(EventEnum):
357368
STARTED = "started"
358369
"""triggered when engine's run is started."""
359370
COMPLETED = "completed"
360-
"""triggered when engine's run is completed"""
371+
"""triggered when engine's run is completed, or after receiving terminate() call."""
361372

362373
ITERATION_STARTED = "iteration_started"
363374
"""triggered when an iteration is started."""

tests/ignite/contrib/engines/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from torch.utils.data.distributed import DistributedSampler
99

1010
import ignite.distributed as idist
11-
1211
import ignite.handlers as handlers
1312
from ignite.contrib.engines.common import (
1413
_setup_logging,

tests/ignite/engine/test_engine.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,14 @@ class TestEngine:
4040
def set_interrupt_resume_enabled(self, interrupt_resume_enabled):
4141
Engine.interrupt_resume_enabled = interrupt_resume_enabled
4242

43-
def test_terminate(self):
43+
@pytest.mark.parametrize("skip_completed", [True, False])
44+
def test_terminate(self, skip_completed):
4445
engine = Engine(lambda e, b: 1)
4546
assert not engine.should_terminate
46-
engine.terminate()
47+
assert not engine.skip_completed_after_termination
48+
engine.terminate(skip_completed)
4749
assert engine.should_terminate
50+
assert engine.skip_completed_after_termination == skip_completed
4851

4952
def test_invalid_process_raises_with_invalid_signature(self):
5053
with pytest.raises(ValueError, match=r"Engine must be given a processing function in order to run"):
@@ -236,25 +239,32 @@ def check_iter_and_data():
236239
assert num_calls_check_iter_epoch == 1
237240

238241
@pytest.mark.parametrize(
239-
"terminate_event, e, i",
242+
"terminate_event, e, i, skip_completed",
240243
[
241-
(Events.STARTED, 0, 0),
242-
(Events.EPOCH_STARTED(once=2), 2, None),
243-
(Events.EPOCH_COMPLETED(once=2), 2, None),
244-
(Events.GET_BATCH_STARTED(once=12), None, 12),
245-
(Events.GET_BATCH_COMPLETED(once=12), None, 12),
246-
(Events.ITERATION_STARTED(once=14), None, 14),
247-
(Events.ITERATION_COMPLETED(once=14), None, 14),
244+
(Events.STARTED, 0, 0, True),
245+
(Events.EPOCH_STARTED(once=2), 2, None, True),
246+
(Events.EPOCH_COMPLETED(once=2), 2, None, True),
247+
(Events.GET_BATCH_STARTED(once=12), None, 12, True),
248+
(Events.GET_BATCH_COMPLETED(once=12), None, 12, False),
249+
(Events.ITERATION_STARTED(once=14), None, 14, True),
250+
(Events.ITERATION_COMPLETED(once=14), None, 14, True),
251+
(Events.STARTED, 0, 0, False),
252+
(Events.EPOCH_STARTED(once=2), 2, None, False),
253+
(Events.EPOCH_COMPLETED(once=2), 2, None, False),
254+
(Events.GET_BATCH_STARTED(once=12), None, 12, False),
255+
(Events.GET_BATCH_COMPLETED(once=12), None, 12, False),
256+
(Events.ITERATION_STARTED(once=14), None, 14, False),
257+
(Events.ITERATION_COMPLETED(once=14), None, 14, False),
248258
],
249259
)
250-
def test_terminate_events_sequence(self, terminate_event, e, i):
260+
def test_terminate_events_sequence(self, terminate_event, e, i, skip_completed):
251261
engine = RecordedEngine(MagicMock(return_value=1))
252262
data = range(10)
253263
max_epochs = 5
254264

255265
@engine.on(terminate_event)
256266
def call_terminate():
257-
engine.terminate()
267+
engine.terminate(skip_completed)
258268

259269
@engine.on(Events.EXCEPTION_RAISED)
260270
def assert_no_exceptions(ee):
@@ -271,10 +281,15 @@ def assert_no_exceptions(ee):
271281
if e is None:
272282
e = i // len(data) + 1
273283

284+
if skip_completed:
285+
assert engine.called_events[-1] == (e, i, Events.TERMINATE)
286+
assert engine.called_events[-2] == (e, i, terminate_event)
287+
else:
288+
assert engine.called_events[-1] == (e, i, Events.COMPLETED)
289+
assert engine.called_events[-2] == (e, i, Events.TERMINATE)
290+
assert engine.called_events[-3] == (e, i, terminate_event)
291+
274292
assert engine.called_events[0] == (0, 0, Events.STARTED)
275-
assert engine.called_events[-1] == (e, i, Events.COMPLETED)
276-
assert engine.called_events[-2] == (e, i, Events.TERMINATE)
277-
assert engine.called_events[-3] == (e, i, terminate_event)
278293
assert engine._dataloader_iter is None
279294

280295
@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])

0 commit comments

Comments
 (0)