
Enabling engine to run single epochs #1371

@alxlampe


🚀 Feature

Problem

I am using multiple engines in a nested way: a child engine is attached to, e.g., the main engine's Events.EPOCH_COMPLETED event and should run exactly one epoch each time that event fires. One solution would be to run the child engine with engine.run(max_epochs=1), but then the engine fires setup and teardown events such as Events.STARTED and Events.COMPLETED on every call to engine.run(max_epochs=1), even though, as far as I understand, those events are meant to be fired only once.
Since my child engine must set up and tear down things, I could attach those event handlers to the main engine instead, but the handlers I want to attach do not know that a main engine exists, and they shouldn't have any access to it.
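To make the problem concrete, here is a minimal pure-Python sketch. The `run` function below is a hypothetical stand-in for Engine.run (not ignite's implementation) that records the same kinds of events; calling it with max_epochs=1 from a parent loop fires STARTED and COMPLETED once per parent epoch instead of once overall.

```python
from typing import List

events: List[str] = []


def run(max_epochs: int) -> None:
    """Hypothetical stand-in for Engine.run: every call is a full run with setup and teardown."""
    events.append("STARTED")
    for _ in range(max_epochs):
        events.append("EPOCH_STARTED")
        events.append("EPOCH_COMPLETED")
    events.append("COMPLETED")


# The parent loop triggers one child epoch after each of its own 3 epochs.
for parent_epoch in range(3):
    run(max_epochs=1)

print(events.count("STARTED"), events.count("COMPLETED"))  # 3 3
```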

Solution

I need some functionality that lets the engine do the following (this is just an example of a clumsy but possible way to implement it):

engine.run_epoch(max_epochs=3)  # runs setup and first epoch, fires events from `STARTED` to `EPOCH_COMPLETED`
engine.run_epoch(max_epochs=3)  # runs second epoch, fires events from `EPOCH_STARTED` to `EPOCH_COMPLETED`
engine.run_epoch(max_epochs=3)  # runs last epoch and teardown, fires events from `EPOCH_STARTED` to `COMPLETED`
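A minimal sketch of how such a `run_epoch` method could keep track of where it stopped, assuming the engine only needs to remember its current epoch between calls. `EpochStepper` and `run_epoch` are hypothetical names for illustration, not part of ignite's Engine API, and events are recorded through a plain callback instead of real handlers.

```python
from typing import Callable, List


class EpochStepper:
    """Hypothetical sketch of a `run_epoch` method; not part of ignite's Engine API."""

    def __init__(self, fire: Callable[[str], None]) -> None:
        self.fire = fire
        self.epoch = 0

    def run_epoch(self, max_epochs: int) -> bool:
        """Run one epoch; return True once the whole run (including teardown) is done."""
        if self.epoch == 0:
            self.fire("STARTED")  # setup only before the first epoch
        self.epoch += 1
        self.fire("EPOCH_STARTED")
        self.fire("EPOCH_COMPLETED")
        if self.epoch >= max_epochs:
            self.fire("COMPLETED")  # teardown only after the last epoch
            self.epoch = 0  # the next call starts a fresh run
            return True
        return False


events: List[str] = []
stepper = EpochStepper(events.append)
for _ in range(3):
    done = stepper.run_epoch(max_epochs=3)

print(done)  # True
```

The explicit `epoch` counter is exactly the bookkeeping that a generator would otherwise keep implicitly in its paused frame.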

Instead of calling a separate method, engine.run could return an iterable object, which gives the same behavior in a nicer way:

epoch_iterator = iterable_engine.run(max_epochs=3)
next(epoch_iterator)  # runs setup and first epoch, fires events from `STARTED` to `EPOCH_COMPLETED`
next(epoch_iterator)  # runs second epoch, fires events from `EPOCH_STARTED` to `EPOCH_COMPLETED`
next(epoch_iterator)  # runs last epoch, fires events from `EPOCH_STARTED` to `EPOCH_COMPLETED`
next(epoch_iterator, None)  # runs teardown, fires `COMPLETED`; the iterator is then exhausted
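A generator provides this pausing behavior for free. The sketch below is a pure-Python stand-in (the function `run_epochs` and the event strings are hypothetical, not ignite code) that yields after each EPOCH_COMPLETED, like the IterableEngine code in this issue. It also shows two details of generator semantics: nothing runs until the first next(), and after the last epoch the generator is paused right after EPOCH_COMPLETED, so COMPLETED only fires when it is advanced once more.

```python
from typing import Callable, Iterator, List


def run_epochs(max_epochs: int, fire: Callable[[str], None]) -> Iterator[int]:
    """Hypothetical generator version of a training loop that yields after each epoch."""
    fire("STARTED")
    for epoch in range(1, max_epochs + 1):
        fire("EPOCH_STARTED")
        fire("EPOCH_COMPLETED")
        yield epoch  # pause here; the next next() resumes with the following epoch
    fire("COMPLETED")  # runs only when the generator is advanced past the last yield


events: List[str] = []
epoch_iterator = run_epochs(3, events.append)
print(events)  # []: calling the generator function does not run its body
print(next(epoch_iterator))  # 1
print(next(epoch_iterator))  # 2
print(next(epoch_iterator))  # 3
print(events[-1])  # EPOCH_COMPLETED: teardown has not run yet
print(next(epoch_iterator, None))  # None: fires COMPLETED, then the generator is exhausted
print(events[-1])  # COMPLETED
```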

Or one can use a loop:

iterable_engine = IterableEngine(lambda x, y: 0.)
iterable_engine.add_event_handler(Events.STARTED, lambda x: print("started"))
iterable_engine.add_event_handler(Events.EPOCH_STARTED, lambda x: print("epoch started"))
iterable_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda x: print("epoch completed"))
iterable_engine.add_event_handler(Events.COMPLETED, lambda x: print("completed"))

epoch_iterator = iterable_engine.run([1], max_epochs=3)
for state in epoch_iterator:
    print("This is outside engine.run")

The output is:

started
epoch started
epoch completed
This is outside engine.run
epoch started
epoch completed
This is outside engine.run
epoch started
epoch completed
This is outside engine.run
completed
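For the nested use case from the problem description, the parent can advance the child's epoch iterator once per parent epoch and drain it when the parent finishes. Below is a hedged pure-Python sketch of that wiring; `child_run` and `parent_run` are hypothetical stand-ins for the two engines, not ignite APIs.

```python
from typing import Callable, Iterator, List


def child_run(max_epochs: int, fire: Callable[[str], None]) -> Iterator[int]:
    """Hypothetical child engine run, yielding after each epoch."""
    fire("child STARTED")
    for epoch in range(1, max_epochs + 1):
        fire("child EPOCH_COMPLETED")
        yield epoch
    fire("child COMPLETED")


def parent_run(max_epochs: int, fire: Callable[[str], None]) -> None:
    """Hypothetical parent engine run that steps the child once per parent epoch."""
    child_epochs = child_run(max_epochs, fire)
    for _ in range(max_epochs):
        fire("parent EPOCH_COMPLETED")
        next(child_epochs)  # what a handler on the parent's EPOCH_COMPLETED would do
    # Exhaust the child so its teardown (COMPLETED) runs exactly once.
    for _ in child_epochs:
        pass


events: List[str] = []
parent_run(3, events.append)
print(events.count("child STARTED"), events.count("child COMPLETED"))  # 1 1
```

Only the wiring code knows about both engines; handlers attached to the child still see the child's STARTED and COMPLETED exactly once.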

The code at the bottom subclasses Engine and overrides the _internal_run method with a copy of the original method plus one added line, the yield statement. You can execute it, and it prints the output shown above.
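To switch between the current behavior and this one, engine.run could take an additional argument, e.g. engine.run(max_epochs=3, return_generator=True), or the engine could expose a flag that enables this functionality. Note that putting yield into an if statement would not be enough: any function that contains yield is a generator function, so the flag would have to choose between a generator-based code path and the regular one.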
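One way to add such a switch, sketched under the assumption that it is a keyword argument (the name `return_generator` comes from the suggestion above; `run` and `_run_generator` are hypothetical stand-ins, not ignite code): keep a single generator implementation, return it as-is when the flag is set, and otherwise drain it so the call behaves like today's blocking run. The `guarded` function at the end shows why an `if` around `yield` does not help.

```python
import inspect
from typing import Callable, Iterator, List, Optional


def _run_generator(max_epochs: int, fire: Callable[[str], None]) -> Iterator[int]:
    """Single implementation of the run loop, pausing after each epoch."""
    fire("STARTED")
    for epoch in range(1, max_epochs + 1):
        fire("EPOCH_COMPLETED")
        yield epoch
    fire("COMPLETED")


def run(max_epochs: int, fire: Callable[[str], None],
        return_generator: bool = False) -> Optional[Iterator[int]]:
    """Hypothetical dispatcher: lazy iterator when requested, otherwise a blocking run."""
    epochs = _run_generator(max_epochs, fire)
    if return_generator:
        return epochs  # the caller drives the epochs with next() or a for loop
    for _ in epochs:  # drain: setup, all epochs and teardown run right now
        pass
    return None


def guarded(flag: bool):
    if flag:
        yield 1  # never reached when flag is False...
    return 0


print(inspect.isgeneratorfunction(guarded))  # True: ...but guarded() is still a generator

blocking_events: List[str] = []
run(3, blocking_events.append)
print(blocking_events[-1])  # COMPLETED

lazy_events: List[str] = []
lazy = run(3, lazy_events.append, return_generator=True)
print(lazy_events)  # []: nothing runs until the first next()
```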

What do you think?

Code:

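import time
from typing import Generator

from ignite._utils import _to_hours_mins_secs
from ignite.engine import Engine
from ignite.engine import Events
from ignite.engine import State


class IterableEngine(Engine):
    # Because this method contains `yield`, calling it returns a generator; Engine.run()
    # passes that generator through, so each next() call advances the run by one epoch.
    def _internal_run(self) -> Generator[State, None, State]: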
        self.should_terminate = self.should_terminate_single_epoch = False
        self._init_timers(self.state)
        try:
            start_time = time.time()
            self._fire_event(Events.STARTED)
            while self.state.epoch < self.state.max_epochs and not self.should_terminate:
                self.state.epoch += 1
                self._fire_event(Events.EPOCH_STARTED)

                if self._dataloader_iter is None:
                    self._setup_engine()

                time_taken = self._run_once_on_dataset()
                # time is available for handlers but must be updated after fire
                self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
                handlers_start_time = time.time()
                if self.should_terminate:
                    self._fire_event(Events.TERMINATE)
                else:
                    self._fire_event(Events.EPOCH_COMPLETED)
                time_taken += time.time() - handlers_start_time
                # update time wrt handlers
                self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
                hours, mins, secs = _to_hours_mins_secs(time_taken)
                self.logger.info(
                    "Epoch[%s] Complete. Time taken: %02d:%02d:%02d" % (self.state.epoch, hours, mins, secs)
                )
                if self.should_terminate:
                    break
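                # hand the state back to the caller after each completed epoch;
                # the following next() call resumes here and starts the next epoch
                yield self.state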

            time_taken = time.time() - start_time
            # time is available for handlers but must be updated after fire
            self.state.times[Events.COMPLETED.name] = time_taken
            handlers_start_time = time.time()
            self._fire_event(Events.COMPLETED)
            time_taken += time.time() - handlers_start_time
            # update time wrt handlers
            self.state.times[Events.COMPLETED.name] = time_taken
            hours, mins, secs = _to_hours_mins_secs(time_taken)
            self.logger.info("Engine run complete. Time taken: %02d:%02d:%02d" % (hours, mins, secs))

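        except GeneratorExit:
            # the caller stopped iterating early (e.g. `break`); not an engine error
            self._dataloader_iter = None
            raise
        except BaseException as e:
            self._dataloader_iter = None
            self.logger.error("Engine run is terminating due to exception: %s.", str(e))
            self._handle_exception(e)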

        self._dataloader_iter = None
        return self.state


if __name__ == '__main__':
    iterable_engine = IterableEngine(lambda x, y: 0.)
    iterable_engine.add_event_handler(Events.STARTED, lambda x: print("started"))
    iterable_engine.add_event_handler(Events.EPOCH_STARTED, lambda x: print("epoch started"))
    iterable_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda x: print("epoch completed"))
    iterable_engine.add_event_handler(Events.COMPLETED, lambda x: print("completed"))

    epoch_iterator = iterable_engine.run([1], max_epochs=3)
    for state in epoch_iterator:
        print("This is outside engine.run")
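One caveat of placing yield inside `try ... except BaseException`: if the caller stops iterating early (a `break`, or the generator being garbage-collected), Python calls close() on it, which raises GeneratorExit at the paused yield. GeneratorExit derives from BaseException, so the broad handler in `_internal_run` would treat an early stop as "Engine run is terminating due to exception". The pure-Python sketch below (hypothetical `epochs` and `epochs_fixed` functions, not ignite code) shows the effect and one possible fix: let GeneratorExit pass through before the broad handler.

```python
from typing import Iterator, List

caught: List[str] = []
caught_fixed: List[str] = []


def epochs(max_epochs: int) -> Iterator[int]:
    """Hypothetical loop with the same broad handler as _internal_run."""
    try:
        for epoch in range(1, max_epochs + 1):
            yield epoch
    except BaseException as e:  # also catches GeneratorExit
        caught.append(type(e).__name__)
        raise


def epochs_fixed(max_epochs: int) -> Iterator[int]:
    """Same loop, but lets an early stop pass through untouched."""
    try:
        for epoch in range(1, max_epochs + 1):
            yield epoch
    except GeneratorExit:
        raise  # the caller stopped iterating; not an error
    except BaseException as e:
        caught_fixed.append(type(e).__name__)
        raise


for make in (epochs, epochs_fixed):
    gen = make(3)
    for epoch in gen:
        if epoch == 2:
            break  # stop early
    gen.close()  # Python also calls close() when a suspended generator is collected

print(caught, caught_fixed)  # ['GeneratorExit'] []
```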
