1
+ import warnings
1
2
from copy import deepcopy
2
3
from typing import Optional , Union
3
4
4
5
import torch .nn as nn
5
6
6
7
from ignite .engine import CallableEventWithFilter , Engine , Events , EventsList
8
+ from ignite .handlers .param_scheduler import BaseParamScheduler
9
+ from ignite .handlers .state_param_scheduler import LambdaStateScheduler
7
10
8
11
__all__ = ["EMAHandler" ]
9
12
10
13
14
+ class EMAWarmUp :
15
+ def __init__ (self , momentum_warmup : float , warmup_iters : int , momentum : float ) -> None :
16
+ self .momentum_warmup = momentum_warmup
17
+ self .warmup_iters = warmup_iters
18
+ self .momentum = momentum
19
+
20
+ def __call__ (self , event_index : int ) -> float :
21
+ denominator = max (1 , self .warmup_iters - 1 )
22
+ curr_momentum = self .momentum_warmup + (self .momentum - self .momentum_warmup ) * (event_index - 1 ) / denominator
23
+ if self .momentum >= self .momentum_warmup :
24
+ return min (self .momentum , curr_momentum )
25
+ else :
26
+ return max (self .momentum , curr_momentum )
27
+
28
+
11
29
class EMAHandler :
12
30
r"""Exponential moving average (EMA) handler can be used to compute a smoothed version of model.
13
31
The EMA model is updated as follows:
14
32
15
33
.. math:: \theta_{\text{EMA}, t+1} = (1 - \lambda) \cdot \theta_{\text{EMA}, t} + \lambda \cdot \theta_{t}
16
34
17
35
where :math:`\theta_{\text{EMA}, t}` and :math:`\theta_{t}` are the EMA weights and online model weights at
18
- :math:`t`-th iteration, respectively; :math:`\lambda` is the update momentum. The handler allows for linearly
19
- warming up the momentum in the beginning when training process is not stable. Current momentum can be retrieved
36
+ :math:`t`-th iteration, respectively; :math:`\lambda` is the update momentum. Current momentum can be retrieved
20
37
from ``Engine.state.ema_momentum``.
21
38
22
39
Args:
23
40
model: the online model for which an EMA model will be computed. If ``model`` is ``DataParallel`` or
24
41
``DistributedDataParallel``, the EMA smoothing will be applied to ``model.module`` .
25
42
momentum: the update momentum after warmup phase, should be float in range :math:`\left(0, 1 \right)`.
26
- momentum_warmup: the initial update momentum during warmup phase, the value should be smaller than
27
- ``momentum``. Momentum will linearly increase from this value to ``momentum`` in ``warmup_iters``
28
- iterations. If ``None``, no warmup will be performed.
29
- warmup_iters: iterations of warmup. If ``None``, no warmup will be performed.
43
+ momentum_warmup: the initial update momentum during warmup phase.
44
+ warmup_iters: iterations of warmup.
30
45
31
46
Attributes:
32
47
ema_model: the exponential moving averaged model.
33
48
model: the online model that is tracked by EMAHandler. It is ``model.module`` if ``model`` in
34
49
the initialization method is an instance of ``DistributedDataParallel``.
35
- momentum: the update momentum after warmup phase.
36
- momentum_warmup: the initial update momentum.
37
- warmup_iters: number of warmup iterations.
50
+ momentum: the update momentum.
38
51
39
52
Note:
40
53
The EMA model is already in ``eval`` mode. If model in the arguments is an ``nn.Module`` or
@@ -56,8 +69,7 @@ class EMAHandler:
56
69
device = torch.device("cuda:0")
57
70
model = nn.Linear(2, 1).to(device)
58
71
# update the ema every 5 iterations
59
- ema_handler = EMAHandler(
60
- model, momentum=0.0002, momentum_warmup=0.0001, warmup_iters=10000)
72
+ ema_handler = EMAHandler(model, momentum=0.0002)
61
73
# get the ema model, which is an instance of nn.Module
62
74
ema_model = ema_handler.ema_model
63
75
trainer = Engine(train_step_fn)
@@ -89,6 +101,19 @@ def run_validation(engine):
89
101
90
102
trainer.run(...)
91
103
104
+ The following example shows how to perform warm-up to the EMA momentum:
105
+
106
+ .. code-block:: python
107
+
108
+ device = torch.device("cuda:0")
109
+ model = nn.Linear(2, 1).to(device)
110
+ # linearly change the EMA momentum from 0.2 to 0.002 in the first 100 iterations,
111
+ # then keep a constant EMA momentum of 0.002 afterwards
112
+ ema_handler = EMAHandler(model, momentum=0.002, momentum_warmup=0.2, warmup_iters=100)
113
+ engine = Engine(step_fn)
114
+ ema_handler.attach(engine, name="ema_momentum")
115
+ engine.run(...)
116
+
92
117
The following example shows how to attach two handlers to the same trainer:
93
118
94
119
.. code-block:: python
@@ -125,25 +150,19 @@ def __init__(
125
150
momentum_warmup : Optional [float ] = None ,
126
151
warmup_iters : Optional [int ] = None ,
127
152
) -> None :
128
- if momentum_warmup is not None and not 0 < momentum_warmup < 1 :
129
- raise ValueError (f"Invalid momentum_warmup: { momentum_warmup } " )
130
153
if not 0 < momentum < 1 :
131
154
raise ValueError (f"Invalid momentum: { momentum } " )
132
- if momentum_warmup is not None and not momentum_warmup <= momentum :
133
- raise ValueError (
134
- f"momentum_warmup should be less than or equal to momentum, but got "
135
- f"momentum_warmup: { momentum_warmup } and momentum: { momentum } "
136
- )
137
- if warmup_iters is not None and not (isinstance (warmup_iters , int ) and warmup_iters > 0 ):
138
- raise ValueError (f"Invalid warmup_iters: { warmup_iters } " )
155
+ self .momentum = momentum
156
+ self ._momentum_lambda_obj : Optional [EMAWarmUp ] = None
157
+ if momentum_warmup is not None and warmup_iters is not None :
158
+ self .momentum_scheduler : Optional [BaseParamScheduler ] = None
159
+ self ._momentum_lambda_obj = EMAWarmUp (momentum_warmup , warmup_iters , momentum )
160
+
139
161
if not isinstance (model , nn .Module ):
140
162
raise ValueError (
141
163
f"model should be an instance of nn.Module or its subclasses, but got"
142
164
f"model: { model .__class__ .__name__ } "
143
165
)
144
- self .momentum_warmup = momentum_warmup
145
- self .momentum = momentum
146
- self .warmup_iters = warmup_iters
147
166
148
167
if isinstance (model , nn .parallel .DistributedDataParallel ):
149
168
model = model .module
@@ -154,22 +173,6 @@ def __init__(
154
173
param .detach_ ()
155
174
self .ema_model .eval ()
156
175
157
- def _get_momentum (self , curr_iter : int ) -> float :
158
- """Get current momentum, `curr_iter` should be 1-based. When `curr_iter = 1`, `momentum =
159
- self.momentum_warmup`; when `curr_iter >= self.warmup_iters`, `momentum = self.momentum`"""
160
-
161
- # TODO: use ignite's parameter scheduling, see also GitHub issue #2090
162
- if curr_iter < 1 :
163
- raise ValueError (f"curr_iter should be at least 1, but got { curr_iter } ." )
164
-
165
- # no warmup
166
- if self .momentum_warmup is None or self .warmup_iters is None :
167
- return self .momentum
168
-
169
- denominator = max (1 , self .warmup_iters - 1 )
170
- momentum = self .momentum_warmup + (self .momentum - self .momentum_warmup ) * (curr_iter - 1 ) / denominator
171
- return min (self .momentum , momentum )
172
-
173
176
def _update_ema_model (self , engine : Engine , name : str ) -> None :
174
177
"""Update weights of ema model"""
175
178
momentum = getattr (engine .state , name )
@@ -179,36 +182,47 @@ def _update_ema_model(self, engine: Engine, name: str) -> None:
179
182
for ema_b , model_b in zip (self .ema_model .buffers (), self .model .buffers ()):
180
183
ema_b .data = model_b .data
181
184
182
- def _update_ema_momentum (self , engine : Engine , name : str ) -> None :
183
- """Update momentum in engine.state"""
184
- curr_iter = engine .state .iteration
185
- momentum = self ._get_momentum (curr_iter )
186
- setattr (engine .state , name , momentum )
187
-
188
185
def attach (
189
186
self ,
190
187
engine : Engine ,
191
188
name : str = "ema_momentum" ,
189
+ warn_if_exists : bool = True ,
192
190
event : Union [str , Events , CallableEventWithFilter , EventsList ] = Events .ITERATION_COMPLETED ,
193
191
) -> None :
194
192
"""Attach the handler to engine. After the handler is attached, the ``Engine.state`` will add an new attribute
195
- with name ``name``. Then, current momentum can be retrieved by from ``Engine.state`` when the engine runs.
193
+ with name ``name`` if the attribute does not exist. Then, the current momentum can be retrieved from
194
+ ``Engine.state`` when the engine runs.
195
+
196
+
197
+ Note:
198
+ There are two cases where a momentum with name ``name`` already exists: 1. the engine has loaded its
199
+ state dict after resuming. In this case, there is no need to initialize the momentum again, and users
200
+ can set ``warn_if_exists`` to False to suppress the warning message; 2. another handler has created
201
+ a state attribute with the same name. In this case, users should choose another name for the ema momentum.
202
+
196
203
197
204
Args:
198
205
engine: trainer to which the handler will be attached.
199
206
name: attribute name for retrieving EMA momentum from ``Engine.state``. It should be a unique name since a
200
207
trainer can have multiple EMA handlers.
208
+ warn_if_exists: if True, a warning will be thrown if the momentum with name ``name`` already exists.
201
209
event: event when the EMA momentum and EMA model are updated.
202
210
203
211
"""
204
212
if hasattr (engine .state , name ):
205
- raise ValueError (
206
- f"Attribute: '{ name } ' is already in Engine.state. Thus it might be "
207
- f"overridden by other EMA handlers. Please select another name."
208
- )
209
-
210
- setattr (engine .state , name , 0.0 )
211
-
212
- # first update momentum, then update ema model
213
- engine .add_event_handler (event , self ._update_ema_momentum , name )
213
+ if warn_if_exists :
214
+ warnings .warn (
215
+ f"Attribute '{ name } ' already exists in Engine.state. It might because 1. the engine has loaded its "
216
+ f"state dict or 2. { name } is already created by other handlers. Turn off this warning by setting"
217
+ f"warn_if_exists to False." ,
218
+ category = UserWarning ,
219
+ )
220
+ else :
221
+ setattr (engine .state , name , self .momentum )
222
+
223
+ if self ._momentum_lambda_obj is not None :
224
+ self .momentum_scheduler = LambdaStateScheduler (self ._momentum_lambda_obj , param_name = "ema_momentum" )
225
+
226
+ # first update the momentum and then update the EMA model
227
+ self .momentum_scheduler .attach (engine , event )
214
228
engine .add_event_handler (event , self ._update_ema_model , name )
0 commit comments