Skip to content

Commit a92619b

Browse files
authored
Merge pull request #3084 from plonerma/pluggable_trainer
Proposal: Pluggable `ModelTrainer` train function
2 parents ea46696 + 640c694 commit a92619b

19 files changed

+454
-833
lines changed

flair/trainers/plugins/__init__.py

+1-18
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
AmpPlugin,
33
CheckpointPlugin,
44
ModelCardPlugin,
5-
RegularLoggingPlugin,
65
SchedulerPlugin,
76
SWAPlugin,
87
WeightExtractorPlugin,
@@ -13,51 +12,35 @@
1312
MetricHistoryPlugin,
1413
TensorboardLogger,
1514
)
16-
from flair.trainers.plugins.metrics import (
17-
BasicEvaluationPlugin,
18-
MetricBasePlugin,
19-
MetricName,
20-
MetricRecord,
21-
TrainingBehaviorPlugin,
22-
)
2315

2416
from .base import BasePlugin, Pluggable, TrainerPlugin, TrainingInterrupt
17+
from .metric_records import MetricName, MetricRecord
2518

2619
default_plugins = [
27-
BasicEvaluationPlugin,
28-
TrainingBehaviorPlugin,
29-
RegularLoggingPlugin,
30-
AmpPlugin,
3120
CheckpointPlugin,
3221
ModelCardPlugin,
3322
SchedulerPlugin,
34-
SWAPlugin,
3523
WeightExtractorPlugin,
3624
LossFilePlugin,
3725
MetricHistoryPlugin,
38-
TensorboardLogger,
3926
LogFilePlugin,
4027
]
4128

4229
__all__ = [
4330
"AmpPlugin",
4431
"CheckpointPlugin",
4532
"ModelCardPlugin",
46-
"RegularLoggingPlugin",
4733
"SchedulerPlugin",
4834
"SWAPlugin",
4935
"WeightExtractorPlugin",
5036
"LogFilePlugin",
5137
"LossFilePlugin",
5238
"MetricHistoryPlugin",
5339
"TensorboardLogger",
54-
"BasicEvaluationPlugin",
55-
"TrainingBehaviorPlugin",
5640
"BasePlugin",
5741
"Pluggable",
5842
"TrainerPlugin",
5943
"TrainingInterrupt",
60-
"MetricBasePlugin",
6144
"MetricName",
6245
"MetricRecord",
6346
]

flair/trainers/plugins/base.py

+17-44
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import (
77
Callable,
88
Dict,
9-
Iterable,
109
Iterator,
1110
List,
1211
NewType,
@@ -36,7 +35,7 @@ class Pluggable:
3635

3736
valid_events: Optional[Set[EventIdenifier]] = None
3837

39-
def __init__(self, *, plugins: Sequence[PluginArgument] = None):
38+
def __init__(self, *, plugins: Sequence[PluginArgument] = []):
4039
"""Initialize a `Pluggable`.
4140
4241
:param plugins: Plugins which should be attached to this `Pluggable`.
@@ -50,14 +49,13 @@ def __init__(self, *, plugins: Sequence[PluginArgument] = None):
5049
self._event_queue: Queue = Queue()
5150
self._processing_events = False
5251

53-
if plugins is not None:
54-
for plugin in plugins:
55-
if isclass(plugin):
56-
# instantiate plugin
57-
plugin = plugin()
52+
for plugin in plugins:
53+
if isclass(plugin):
54+
# instantiate plugin
55+
plugin = plugin()
5856

59-
plugin = cast("BasePlugin", plugin)
60-
plugin.attach_to(self)
57+
plugin = cast("BasePlugin", plugin)
58+
plugin.attach_to(self)
6159

6260
@property
6361
def plugins(self):
@@ -72,7 +70,7 @@ def validate_event(self, *events: EventIdenifier):
7270

7371
if self.valid_events is not None:
7472
if event not in self.valid_events:
75-
raise RuntimeError(f"Event '{event}' not recognized (available {self.valid_events})")
73+
raise RuntimeError(f"Event '{event}' not recognized. Available: {', '.join(self.valid_events)}")
7674
return event
7775

7876
def register_hook(self, func: Callable, *events: EventIdenifier):
@@ -92,31 +90,23 @@ def register_hook(self, func: Callable, *events: EventIdenifier):
9290
self._hook_handles[event][handle.id] = handle
9391
return handle
9492

95-
def dispatch(self, event: EventIdenifier, *args, **kwargs) -> dict:
93+
def dispatch(self, event: EventIdenifier, *args, **kwargs) -> None:
9694
"""Call all functions hooked to a certain event."""
9795
self.validate_event(event)
9896

99-
events_return_value: dict = {}
100-
self._event_queue.put((event, args, kwargs, events_return_value))
97+
self._event_queue.put((event, args, kwargs))
10198

10299
if not self._processing_events:
103100
self._processing_events = True
104101

105102
while not self._event_queue.empty():
106-
event, args, kwargs, combined_return_values = self._event_queue.get()
103+
event, args, kwargs = self._event_queue.get()
107104

108105
for hook in self._hook_handles[event].values():
109-
returned = hook(*args, **kwargs)
110-
111-
if returned is not None:
112-
combined_return_values.update(returned)
106+
hook(*args, **kwargs)
113107

114108
self._processing_events = False
115109

116-
# this dict may be empty and will be complete once all events have been
117-
# processed
118-
return events_return_value
119-
120110
def remove_hook(self, handle: "HookHandle"):
121111
"""Remove a hook handle from this instance."""
122112
for event in handle.events:
@@ -146,6 +136,10 @@ def id(self) -> HookHandleId:
146136
"""Return the id of this `HookHandle`."""
147137
return self._id
148138

139+
@property
140+
def func_name(self):
141+
return self._func.__qualname__
142+
149143
@property
150144
def events(self) -> Iterator[EventIdenifier]:
151145
"""Return iterator of events this `HookHandle` is registered for."""
@@ -165,7 +159,7 @@ def __call__(self, *args, **kw):
165159
for name in kw.keys():
166160
if name not in sig.parameters:
167161
raise TypeError(
168-
f"Hook callback {self._func.__qualname__}() does not accept keyword argument '{name}'"
162+
f"Hook callback {self.func_name}() does not accept keyword argument '{name}'"
169163
) from err
170164

171165
raise err
@@ -174,10 +168,6 @@ def __call__(self, *args, **kw):
174168
class BasePlugin:
175169
"""Base class for all plugins."""
176170

177-
provided_events: Optional[Set[EventIdenifier]] = None
178-
179-
dependencies: Iterable[Type["BasePlugin"]] = ()
180-
181171
def __init__(self):
182172
"""Initialize the base plugin."""
183173
self._hook_handles: List[HookHandle] = []
@@ -190,23 +180,6 @@ def attach_to(self, pluggable: Pluggable):
190180

191181
self._pluggable = pluggable
192182

193-
for dep in self.dependencies:
194-
dep_satisfied = False
195-
196-
for plugin in pluggable.plugins:
197-
if isinstance(plugin, dep):
198-
# there is already a plugin which satisfies this dependency
199-
dep_satisfied = True
200-
break
201-
202-
if not dep_satisfied:
203-
# create a plugin of this type and attach it to the trainer
204-
dep_plugin = dep()
205-
dep_plugin.attach_to(pluggable)
206-
207-
if self.provided_events is not None and pluggable.valid_events is not None:
208-
pluggable.valid_events = pluggable.valid_events | self.provided_events
209-
210183
pluggable.append_plugin(self)
211184

212185
# go through all attributes

flair/trainers/plugins/functional/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from .amp import AmpPlugin
22
from .checkpoints import CheckpointPlugin
3-
from .logging import RegularLoggingPlugin
43
from .model_card import ModelCardPlugin
54
from .scheduler import SchedulerPlugin
65
from .swa import SWAPlugin
@@ -13,5 +12,4 @@
1312
"SchedulerPlugin",
1413
"SWAPlugin",
1514
"WeightExtractorPlugin",
16-
"RegularLoggingPlugin",
1715
]

flair/trainers/plugins/functional/amp.py

+25-29
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,6 @@
22

33
from flair.trainers.plugins.base import TrainerPlugin
44

5-
try:
6-
from apex import amp
7-
except ImportError:
8-
amp = None
9-
105

116
class AmpPlugin(TrainerPlugin):
127
"""
@@ -15,52 +10,53 @@ class AmpPlugin(TrainerPlugin):
1510

1611
def __init__(self):
1712
super().__init__()
18-
self.use = None # TODO: can be removed
19-
self.opt_level = None # TODO: I think this also since only used in 1 place
13+
2014
self.wrapped_backward = None
15+
self.amp = None
2116

2217
@TrainerPlugin.hook
23-
def before_training_setup(self, use_amp, amp_opt_level, **kw):
24-
self.use = use_amp # TODO: can be removed
25-
self.opt_level = amp_opt_level
26-
27-
if self.use:
28-
if sys.version_info < (3, 0):
29-
raise RuntimeError("Apex currently only supports Python 3. Aborting.")
30-
if amp is None:
31-
raise RuntimeError(
32-
"Failed to import apex. Please install apex from "
33-
"https://www.github.com/nvidia/apex "
34-
"to enable mixed-precision training."
35-
)
18+
def before_training_setup(self, **kw):
19+
if sys.version_info < (3, 0):
20+
raise RuntimeError("Apex currently only supports Python 3. Aborting.")
21+
22+
try:
23+
from apex import amp
24+
25+
self.amp = amp
26+
except ImportError as exc:
27+
raise RuntimeError(
28+
"Failed to import apex. Please install apex from "
29+
"https://www.github.com/nvidia/apex "
30+
"to enable mixed-precision training."
31+
) from exc
3632

3733
def detach(self, *args, **kwargs):
3834
# TODO: what does this do?
3935
super().detach(*args, **kwargs)
4036

37+
# unwrap trainer backward function
4138
self.trainer.backward = self.wrapped_backward
4239
self.wrapped_backward = None
4340

4441
def backward(self, loss):
42+
assert self.amp is not None
4543
optimizer = self.trainer.optimizer
4644

47-
if self.use: # TODO: can be removed
48-
with amp.scale_loss(loss, optimizer) as scaled_loss:
49-
scaled_loss.backward()
45+
with self.amp.scale_loss(loss, optimizer) as scaled_loss:
46+
scaled_loss.backward()
5047

5148
@TrainerPlugin.hook
52-
def after_optimizer_setup(self, **kw):
49+
def after_optimizer_setup(self, amp_opt_level, **kw):
5350
"""
5451
Wraps with AMP
5552
:param kw:
5653
:return:
5754
"""
5855
optimizer = self.trainer.optimizer
5956

60-
if self.use: # TODO: can be removed
61-
self.trainer.model, self.trainer.optimizer = amp.initialize(self.model, optimizer, opt_level=self.opt_level)
57+
self.trainer.model, self.trainer.optimizer = self.amp.initialize(self.model, optimizer, opt_level=amp_opt_level)
6258

63-
# replace trainers backward function
64-
self.wrapped_backward = self.trainer.backward
59+
# replace trainers backward function
60+
self.wrapped_backward = self.trainer.backward
6561

66-
self.trainer.backward = self.backward
62+
self.trainer.backward = self.backward

flair/trainers/plugins/functional/best_model.py

-80
This file was deleted.

0 commit comments

Comments
 (0)