Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 33 additions & 29 deletions ignite/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,17 @@ def supervised_training_step(
Returns:
Callable: update function.

Example::
Examples:
.. code-block:: python

from ignite.engine import Engine, supervised_training_step
from ignite.engine import Engine, supervised_training_step

model = ...
optimizer = ...
loss_fn = ...
model = ...
optimizer = ...
loss_fn = ...

update_fn = supervised_training_step(model, optimizer, loss_fn, 'cuda')
trainer = Engine(update_fn)
update_fn = supervised_training_step(model, optimizer, loss_fn, 'cuda')
trainer = Engine(update_fn)

.. versionadded:: 0.4.5
"""
Expand Down Expand Up @@ -128,17 +129,18 @@ def supervised_training_step_amp(
Returns:
Callable: update function

Example::
Examples:
.. code-block:: python

from ignite.engine import Engine, supervised_training_step_amp
from ignite.engine import Engine, supervised_training_step_amp

model = ...
optimizer = ...
loss_fn = ...
scaler = torch.cuda.amp.GradScaler(2**10)
model = ...
optimizer = ...
loss_fn = ...
scaler = torch.cuda.amp.GradScaler(2**10)

update_fn = supervised_training_step_amp(model, optimizer, loss_fn, 'cuda', scaler=scaler)
trainer = Engine(update_fn)
update_fn = supervised_training_step_amp(model, optimizer, loss_fn, 'cuda', scaler=scaler)
trainer = Engine(update_fn)

.. versionadded:: 0.4.5
"""
Expand Down Expand Up @@ -195,16 +197,17 @@ def supervised_training_step_apex(
Returns:
Callable: update function.

Example::
Examples:
.. code-block:: python

from ignite.engine import Engine, supervised_training_step_apex
from ignite.engine import Engine, supervised_training_step_apex

model = ...
optimizer = ...
loss_fn = ...
model = ...
optimizer = ...
loss_fn = ...

update_fn = supervised_training_step_apex(model, optimizer, loss_fn, 'cuda')
trainer = Engine(update_fn)
update_fn = supervised_training_step_apex(model, optimizer, loss_fn, 'cuda')
trainer = Engine(update_fn)

.. versionadded:: 0.4.5
"""
Expand Down Expand Up @@ -256,16 +259,17 @@ def supervised_training_step_tpu(
Returns:
Callable: update function.

Example::
Examples:
.. code-block:: python

from ignite.engine import Engine, supervised_training_step_tpu
from ignite.engine import Engine, supervised_training_step_tpu

model = ...
optimizer = ...
loss_fn = ...
model = ...
optimizer = ...
loss_fn = ...

update_fn = supervised_training_step_tpu(model, optimizer, loss_fn, 'xla')
trainer = Engine(update_fn)
update_fn = supervised_training_step_tpu(model, optimizer, loss_fn, 'xla')
trainer = Engine(update_fn)

.. versionadded:: 0.4.5
"""
Expand Down
18 changes: 9 additions & 9 deletions ignite/engine/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,21 @@ class ReproducibleBatchSampler(BatchSampler):
"""Reproducible batch sampler. This class internally iterates and stores indices of the input batch sampler.
This helps to start providing data batches from an iteration in a deterministic way.

Example:
Args:
batch_sampler: batch sampler same as used with `torch.utils.data.DataLoader`.
start_iteration: optional start iteration.

Examples:
Setup dataloader with `ReproducibleBatchSampler` and start providing data batches from an iteration

.. code-block:: python
.. code-block:: python

from ignite.engine.deterministic import update_dataloader
from ignite.engine.deterministic import update_dataloader

dataloader = update_dataloader(dataloader, ReproducibleBatchSampler(dataloader.batch_sampler))
# rewind dataloader to a specific iteration:
dataloader.batch_sampler.start_iteration = start_iteration
dataloader = update_dataloader(dataloader, ReproducibleBatchSampler(dataloader.batch_sampler))
# rewind dataloader to a specific iteration:
dataloader.batch_sampler.start_iteration = start_iteration

Args:
batch_sampler: batch sampler same as used with `torch.utils.data.DataLoader`.
start_iteration: optional start iteration.
"""

def __init__(self, batch_sampler: BatchSampler, start_iteration: Optional[int] = None):
Expand Down
135 changes: 66 additions & 69 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class Engine(Serializable):
last_event_name: last event name triggered by the engine.

Examples:

Create a basic trainer

.. code-block:: python
Expand Down Expand Up @@ -158,63 +157,63 @@ def register_events(
or an object derived from :class:`~ignite.engine.events.EventEnum`. See example below.
event_to_attr: A dictionary to map an event to a state attribute.

Example usage:
Examples:
.. code-block:: python

.. code-block:: python
from ignite.engine import Engine, Events, EventEnum

from ignite.engine import Engine, Events, EventEnum
class CustomEvents(EventEnum):
FOO_EVENT = "foo_event"
BAR_EVENT = "bar_event"

class CustomEvents(EventEnum):
FOO_EVENT = "foo_event"
BAR_EVENT = "bar_event"
def process_function(e, batch):
# ...
trainer.fire_event("bwd_event")
loss.backward()
# ...
trainer.fire_event("opt_event")
optimizer.step()

def process_function(e, batch):
# ...
trainer.fire_event("bwd_event")
loss.backward()
# ...
trainer.fire_event("opt_event")
optimizer.step()
trainer = Engine(process_function)
trainer.register_events(*CustomEvents)
trainer.register_events("bwd_event", "opt_event")

trainer = Engine(process_function)
trainer.register_events(*CustomEvents)
trainer.register_events("bwd_event", "opt_event")
@trainer.on(Events.EPOCH_COMPLETED)
def trigger_custom_event():
if required(...):
trainer.fire_event(CustomEvents.FOO_EVENT)
else:
trainer.fire_event(CustomEvents.BAR_EVENT)

@trainer.on(Events.EPOCH_COMPLETED)
def trigger_custom_event():
if required(...):
trainer.fire_event(CustomEvents.FOO_EVENT)
else:
trainer.fire_event(CustomEvents.BAR_EVENT)
@trainer.on(CustomEvents.FOO_EVENT)
def do_foo_op():
# ...

@trainer.on(CustomEvents.FOO_EVENT)
def do_foo_op():
# ...
@trainer.on(CustomEvents.BAR_EVENT)
def do_bar_op():
# ...

@trainer.on(CustomEvents.BAR_EVENT)
def do_bar_op():
# ...
Example with State Attribute:

Example with State Attribute:

.. code-block:: python
.. code-block:: python

from enum import Enum
from ignite.engine import Engine, EventEnum
from enum import Enum
from ignite.engine import Engine, EventEnum

class TBPTT_Events(EventEnum):
TIME_ITERATION_STARTED = "time_iteration_started"
TIME_ITERATION_COMPLETED = "time_iteration_completed"
class TBPTT_Events(EventEnum):
TIME_ITERATION_STARTED = "time_iteration_started"
TIME_ITERATION_COMPLETED = "time_iteration_completed"

TBPTT_event_to_attr = {
TBPTT_Events.TIME_ITERATION_STARTED: 'time_iteration',
TBPTT_Events.TIME_ITERATION_COMPLETED: 'time_iteration'
}
TBPTT_event_to_attr = {
TBPTT_Events.TIME_ITERATION_STARTED: 'time_iteration',
TBPTT_Events.TIME_ITERATION_COMPLETED: 'time_iteration'
}

engine = Engine(process_function)
engine.register_events(*TBPTT_Events, event_to_attr=TBPTT_event_to_attr)
engine.run(data)
# engine.state contains an attribute time_iteration, which can be accessed using engine.state.time_iteration
engine = Engine(process_function)
engine.register_events(*TBPTT_Events, event_to_attr=TBPTT_event_to_attr)
engine.run(data)
# engine.state contains an attribute time_iteration, which can be accessed
# using engine.state.time_iteration
"""
if not (event_to_attr is None or isinstance(event_to_attr, dict)):
raise ValueError(f"Expected event_to_attr to be dictionary. Got {type(event_to_attr)}.")
Expand Down Expand Up @@ -266,24 +265,23 @@ def add_event_handler(self, event_name: Any, handler: Callable, *args: Any, **kw
Note that other arguments can be passed to the handler in addition to the `*args` and `**kwargs`
passed here, for example during :attr:`~ignite.engine.events.Events.EXCEPTION_RAISED`.

Example usage:

.. code-block:: python
Examples:
.. code-block:: python

engine = Engine(process_function)
engine = Engine(process_function)

def print_epoch(engine):
print(f"Epoch: {engine.state.epoch}")
def print_epoch(engine):
print(f"Epoch: {engine.state.epoch}")

engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)
engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)

events_list = Events.EPOCH_COMPLETED | Events.COMPLETED
events_list = Events.EPOCH_COMPLETED | Events.COMPLETED

def execute_something():
# do some thing not related to engine
pass
def execute_something():
# do some thing not related to engine
pass

engine.add_event_handler(events_list, execute_something)
engine.add_event_handler(events_list, execute_something)

Note:
Since v0.3.0, Events become more flexible and allow to pass an event filter to the Engine.
Expand Down Expand Up @@ -379,20 +377,19 @@ def on(self, event_name: Any, *args: Any, **kwargs: Any) -> Callable:
args: optional args to be passed to `handler`.
kwargs: optional keyword args to be passed to `handler`.

Example usage:

.. code-block:: python
Examples:
.. code-block:: python

engine = Engine(process_function)
engine = Engine(process_function)

@engine.on(Events.EPOCH_COMPLETED)
def print_epoch():
print(f"Epoch: {engine.state.epoch}")
@engine.on(Events.EPOCH_COMPLETED)
def print_epoch():
print(f"Epoch: {engine.state.epoch}")

@engine.on(Events.EPOCH_COMPLETED | Events.COMPLETED)
def execute_something():
# do some thing not related to engine
pass
@engine.on(Events.EPOCH_COMPLETED | Events.COMPLETED)
def execute_something():
# do some thing not related to engine
pass
"""

def decorator(f: Callable) -> Callable:
Expand Down Expand Up @@ -572,7 +569,7 @@ def set_data(self, data: Union[Iterable, DataLoader]) -> None:
Args:
data: Collection of batches allowing repeated iteration (e.g., list or `DataLoader`).

Example usage:
Examples:
User can switch data provider during the training:

.. code-block:: python
Expand Down
24 changes: 13 additions & 11 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,10 @@ def __or__(self, other: Any) -> "EventsList":

class EventEnum(CallableEventWithFilter, Enum): # type: ignore[misc]
"""Base class for all :class:`~ignite.engine.events.Events`. User defined custom events should also inherit
this class. For example, Custom events based on the loss calculation and backward pass can be created as follows:
this class.

Examples:
Custom events based on the loss calculation and backward pass can be created as follows:

.. code-block:: python

Expand Down Expand Up @@ -436,20 +439,19 @@ class RemovableEventHandle:
handler: Registered event handler, stored as weakref.
engine: Target engine, stored as weakref.

Example usage:

.. code-block:: python
Examples:
.. code-block:: python

engine = Engine()
engine = Engine()

def print_epoch(engine):
print(f"Epoch: {engine.state.epoch}")
def print_epoch(engine):
print(f"Epoch: {engine.state.epoch}")

with engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch):
# print_epoch handler registered for a single run
engine.run(data)
with engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch):
# print_epoch handler registered for a single run
engine.run(data)

# print_epoch handler is now unregistered
# print_epoch handler is now unregistered
"""

def __init__(
Expand Down
Loading