Skip to content

Commit 62742a9

Browse files
Yawei Li and vfdev-5 authored
Ema momentum (#2333)
* deprecate warmup function in ema_handler.py
  Signed-off-by: sandylaker <yawei.li@tum.de>
* keep the previous API but throw warnings
  Signed-off-by: sandylaker <yawei.li@tum.de>
* revert changes to state_param_scheduler.py and test_state_param_scheduler.py
  Signed-off-by: sandylaker <yawei.li@tum.de>
* use `LambdaStateScheduler` to schedule EMA momentum
* fix docs
* fix mypy
* Apply suggestions from code review
  Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 59bff0b commit 62742a9

File tree

3 files changed

+123
-113
lines changed

3 files changed

+123
-113
lines changed

ignite/handlers/ema_handler.py

Lines changed: 69 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,53 @@
1+
import warnings
12
from copy import deepcopy
23
from typing import Optional, Union
34

45
import torch.nn as nn
56

67
from ignite.engine import CallableEventWithFilter, Engine, Events, EventsList
8+
from ignite.handlers.param_scheduler import BaseParamScheduler
9+
from ignite.handlers.state_param_scheduler import LambdaStateScheduler
710

811
__all__ = ["EMAHandler"]
912

1013

14+
class EMAWarmUp:
    """Callable that linearly interpolates the EMA momentum during warm-up.

    The returned value moves from ``momentum_warmup`` (at ``event_index == 1``) to
    ``momentum`` (at ``event_index == warmup_iters``) and stays at ``momentum``
    afterwards. Both increasing and decreasing schedules are supported.

    Args:
        momentum_warmup: momentum returned for the first event.
        warmup_iters: number of events over which the momentum is interpolated.
        momentum: momentum returned once the warm-up is over.
    """

    def __init__(self, momentum_warmup: float, warmup_iters: int, momentum: float) -> None:
        self.momentum_warmup = momentum_warmup
        self.warmup_iters = warmup_iters
        self.momentum = momentum

    def __call__(self, event_index: int) -> float:
        """Return the momentum for the 1-based ``event_index``.

        The result is always kept between ``momentum_warmup`` and ``momentum``.
        """
        # `max(1, ...)` avoids a zero division when `warmup_iters` is 1 (or smaller)
        denominator = max(1, self.warmup_iters - 1)
        progress = (event_index - 1) / denominator
        curr_momentum = self.momentum_warmup + (self.momentum - self.momentum_warmup) * progress

        # Clamp on both sides: the upper/lower bound on the `momentum` side stops the
        # schedule after warm-up, the bound on the `momentum_warmup` side prevents
        # extrapolation (e.g. a negative momentum) when `event_index` is below 1.
        lower, upper = sorted((self.momentum_warmup, self.momentum))
        return min(upper, max(lower, curr_momentum))
27+
28+
1129
class EMAHandler:
1230
r"""Exponential moving average (EMA) handler can be used to compute a smoothed version of model.
1331
The EMA model is updated as follows:
1432
1533
.. math:: \theta_{\text{EMA}, t+1} = (1 - \lambda) \cdot \theta_{\text{EMA}, t} + \lambda \cdot \theta_{t}
1634
1735
where :math:`\theta_{\text{EMA}, t}` and :math:`\theta_{t}` are the EMA weights and online model weights at
18-
:math:`t`-th iteration, respectively; :math:`\lambda` is the update momentum. The handler allows for linearly
19-
warming up the momentum in the beginning when training process is not stable. Current momentum can be retrieved
36+
:math:`t`-th iteration, respectively; :math:`\lambda` is the update momentum. Current momentum can be retrieved
2037
from ``Engine.state.ema_momentum``.
2138
2239
Args:
2340
model: the online model for which an EMA model will be computed. If ``model`` is ``DataParallel`` or
2441
``DistributedDataParallel``, the EMA smoothing will be applied to ``model.module`` .
2542
momentum: the update momentum after warmup phase, should be float in range :math:`\left(0, 1 \right)`.
26-
momentum_warmup: the initial update momentum during warmup phase, the value should be smaller than
27-
``momentum``. Momentum will linearly increase from this value to ``momentum`` in ``warmup_iters``
28-
iterations. If ``None``, no warmup will be performed.
29-
warmup_iters: iterations of warmup. If ``None``, no warmup will be performed.
43+
momentum_warmup: the initial update momentum during warmup phase.
44+
warmup_iters: iterations of warmup.
3045
3146
Attributes:
3247
ema_model: the exponential moving averaged model.
3348
model: the online model that is tracked by EMAHandler. It is ``model.module`` if ``model`` in
3449
the initialization method is an instance of ``DistributedDataParallel``.
35-
momentum: the update momentum after warmup phase.
36-
momentum_warmup: the initial update momentum.
37-
warmup_iters: number of warmup iterations.
50+
momentum: the update momentum.
3851
3952
Note:
4053
The EMA model is already in ``eval`` mode. If model in the arguments is an ``nn.Module`` or
@@ -56,8 +69,7 @@ class EMAHandler:
5669
device = torch.device("cuda:0")
5770
model = nn.Linear(2, 1).to(device)
5871
# update the ema every 5 iterations
59-
ema_handler = EMAHandler(
60-
model, momentum=0.0002, momentum_warmup=0.0001, warmup_iters=10000)
72+
ema_handler = EMAHandler(model, momentum=0.0002)
6173
# get the ema model, which is an instance of nn.Module
6274
ema_model = ema_handler.ema_model
6375
trainer = Engine(train_step_fn)
@@ -89,6 +101,19 @@ def run_validation(engine):
89101
90102
trainer.run(...)
91103
104+
The following example shows how to perform warm-up to the EMA momentum:
105+
106+
.. code-block:: python
107+
108+
device = torch.device("cuda:0")
109+
model = nn.Linear(2, 1).to(device)
110+
# linearly change the EMA momentum from 0.2 to 0.002 in the first 100 iterations,
111+
# then keep a constant EMA momentum of 0.002 afterwards
112+
ema_handler = EMAHandler(model, momentum=0.002, momentum_warmup=0.2, warmup_iters=100)
113+
engine = Engine(step_fn)
114+
ema_handler.attach(engine, name="ema_momentum")
115+
engine.run(...)
116+
92117
The following example shows how to attach two handlers to the same trainer:
93118
94119
.. code-block:: python
@@ -125,25 +150,19 @@ def __init__(
125150
momentum_warmup: Optional[float] = None,
126151
warmup_iters: Optional[int] = None,
127152
) -> None:
128-
if momentum_warmup is not None and not 0 < momentum_warmup < 1:
129-
raise ValueError(f"Invalid momentum_warmup: {momentum_warmup}")
130153
if not 0 < momentum < 1:
131154
raise ValueError(f"Invalid momentum: {momentum}")
132-
if momentum_warmup is not None and not momentum_warmup <= momentum:
133-
raise ValueError(
134-
f"momentum_warmup should be less than or equal to momentum, but got "
135-
f"momentum_warmup: {momentum_warmup} and momentum: {momentum}"
136-
)
137-
if warmup_iters is not None and not (isinstance(warmup_iters, int) and warmup_iters > 0):
138-
raise ValueError(f"Invalid warmup_iters: {warmup_iters}")
155+
self.momentum = momentum
156+
self._momentum_lambda_obj: Optional[EMAWarmUp] = None
157+
if momentum_warmup is not None and warmup_iters is not None:
158+
self.momentum_scheduler: Optional[BaseParamScheduler] = None
159+
self._momentum_lambda_obj = EMAWarmUp(momentum_warmup, warmup_iters, momentum)
160+
139161
if not isinstance(model, nn.Module):
140162
raise ValueError(
141163
f"model should be an instance of nn.Module or its subclasses, but got"
142164
f"model: {model.__class__.__name__}"
143165
)
144-
self.momentum_warmup = momentum_warmup
145-
self.momentum = momentum
146-
self.warmup_iters = warmup_iters
147166

148167
if isinstance(model, nn.parallel.DistributedDataParallel):
149168
model = model.module
@@ -154,22 +173,6 @@ def __init__(
154173
param.detach_()
155174
self.ema_model.eval()
156175

157-
def _get_momentum(self, curr_iter: int) -> float:
158-
"""Get current momentum, `curr_iter` should be 1-based. When `curr_iter = 1`, `momentum =
159-
self.momentum_warmup`; when `curr_iter >= self.warmup_iters`, `momentum = self.momentum`"""
160-
161-
# TODO: use ignite's parameter scheduling, see also GitHub issue #2090
162-
if curr_iter < 1:
163-
raise ValueError(f"curr_iter should be at least 1, but got {curr_iter}.")
164-
165-
# no warmup
166-
if self.momentum_warmup is None or self.warmup_iters is None:
167-
return self.momentum
168-
169-
denominator = max(1, self.warmup_iters - 1)
170-
momentum = self.momentum_warmup + (self.momentum - self.momentum_warmup) * (curr_iter - 1) / denominator
171-
return min(self.momentum, momentum)
172-
173176
def _update_ema_model(self, engine: Engine, name: str) -> None:
174177
"""Update weights of ema model"""
175178
momentum = getattr(engine.state, name)
@@ -179,36 +182,47 @@ def _update_ema_model(self, engine: Engine, name: str) -> None:
179182
for ema_b, model_b in zip(self.ema_model.buffers(), self.model.buffers()):
180183
ema_b.data = model_b.data
181184

182-
def _update_ema_momentum(self, engine: Engine, name: str) -> None:
183-
"""Update momentum in engine.state"""
184-
curr_iter = engine.state.iteration
185-
momentum = self._get_momentum(curr_iter)
186-
setattr(engine.state, name, momentum)
187-
188185
def attach(
    self,
    engine: Engine,
    name: str = "ema_momentum",
    warn_if_exists: bool = True,
    event: Union[str, Events, CallableEventWithFilter, EventsList] = Events.ITERATION_COMPLETED,
) -> None:
    """Attach the handler to engine. After the handler is attached, the ``Engine.state`` will add a new attribute
    with name ``name`` if the attribute does not exist. Then, the current momentum can be retrieved from
    ``Engine.state`` when the engine runs.


    Note:
        There are two cases where a momentum with name ``name`` already exists: 1. the engine has loaded its
        state dict after resuming. In this case, there is no need to initialize the momentum again, and users
        can set ``warn_if_exists`` to False to suppress the warning message; 2. another handler has created
        a state attribute with the same name. In this case, users should choose another name for the ema momentum.


    Args:
        engine: trainer to which the handler will be attached.
        name: attribute name for retrieving EMA momentum from ``Engine.state``. It should be a unique name since a
            trainer can have multiple EMA handlers.
        warn_if_exists: if True, a warning will be thrown if the momentum with name ``name`` already exists.
        event: event when the EMA momentum and EMA model are updated.

    """
    if hasattr(engine.state, name):
        # keep the existing value (e.g. restored from a state dict) and only notify the user
        if warn_if_exists:
            warnings.warn(
                f"Attribute '{name}' already exists in Engine.state. It might be because 1. the engine has loaded "
                f"its state dict or 2. {name} is already created by other handlers. Turn off this warning by "
                f"setting warn_if_exists to False.",
                category=UserWarning,
            )
    else:
        setattr(engine.state, name, self.momentum)

    if self._momentum_lambda_obj is not None:
        # The scheduler must write to the same state attribute that `_update_ema_model` reads
        # (`getattr(engine.state, name)`), otherwise the warm-up has no effect for custom names.
        self.momentum_scheduler = LambdaStateScheduler(self._momentum_lambda_obj, param_name=name)

        # first update the momentum and then update the EMA model
        self.momentum_scheduler.attach(engine, event)
    engine.add_event_handler(event, self._update_ema_model, name)

ignite/handlers/state_param_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any, List, Sequence, Tuple, Union
55

66
from ignite.engine import CallableEventWithFilter, Engine, Events, EventsList
7-
from ignite.handlers import BaseParamScheduler
7+
from ignite.handlers.param_scheduler import BaseParamScheduler
88

99

1010
class StateParamScheduler(BaseParamScheduler):

tests/ignite/handlers/test_ema_handler.py

Lines changed: 53 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,51 @@ def test_ema_invalid_momentum(get_dummy_model, momentum):
4747
EMAHandler(get_dummy_model(), momentum=momentum)
4848

4949

50-
@pytest.mark.parametrize("momentum_warmup", [-1, 2])
51-
def test_ema_invalid_momentum_warmup(get_dummy_model, momentum_warmup):
52-
with pytest.raises(ValueError, match="Invalid momentum_warmup"):
53-
EMAHandler(get_dummy_model, momentum_warmup=momentum_warmup)
54-
50+
def test_has_momentum_scheduler(get_dummy_model):
    """Check that a handler configured with warm-up exposes both scheduler-related attributes."""
    handler = EMAHandler(get_dummy_model(), momentum_warmup=0.0, warmup_iters=10)
    for attr_name in ("momentum_scheduler", "_momentum_lambda_obj"):
        assert hasattr(handler, attr_name)
57+
58+
59+
def test_ema_warmup_func(get_dummy_model):
    """Test the built-in linear warmup function for the EMA momentum"""
    momentum = 0.5
    warmup_iters = 5

    def check_ema_momentum(engine: Engine, momentum_warmup, final_momentum, warmup_iters):
        """Assert the momentum starts at ``momentum_warmup``, ends at ``final_momentum`` and stays in between."""
        curr_iter = engine.state.iteration
        curr_momentum = engine.state.ema_momentum
        if curr_iter == 1:
            assert curr_momentum == momentum_warmup
        elif curr_iter >= 1 + warmup_iters:
            assert curr_momentum == final_momentum
        else:
            # use the arguments of this checker rather than variables of the enclosing test
            lower = min(final_momentum, momentum_warmup)
            upper = max(final_momentum, momentum_warmup)
            assert lower <= curr_momentum <= upper

    # 0.0 covers momentum_warmup < momentum, 1.0 covers momentum_warmup > momentum
    for momentum_warmup in (0.0, 1.0):
        model = get_dummy_model()
        engine = Engine(_get_dummy_step_fn(model))
        ema_handler = EMAHandler(model, momentum, momentum_warmup, warmup_iters)
        ema_handler.attach(engine)
        # attached after the EMA handler, so the momentum is updated before it is checked
        engine.add_event_handler(
            Events.ITERATION_COMPLETED, check_ema_momentum, momentum_warmup, momentum, warmup_iters
        )
        engine.run(range(10))
6295

6396

6497
def test_ema_invalid_model():
@@ -98,54 +131,19 @@ def test_ema_load_state_dict(get_dummy_model):
98131
assert ema_model.weight.data.allclose(model_1.weight.data)
99132

100133

101-
def test_ema_no_warmup_momentum(get_dummy_model):
134+
def test_ema_get_const_momentum(get_dummy_model):
    """Without warm-up, the momentum stored in the engine state must always equal the handler's momentum."""
    model = get_dummy_model()
    engine = Engine(_get_dummy_step_fn(model))
    ema_handler = EMAHandler(model, momentum=0.002)

    def assert_const_momentum(engine: Engine, const_momentum):
        assert engine.state.ema_momentum == const_momentum

    ema_handler.attach(engine)
    # registered after the EMA handler so the check sees the value used for the update
    engine.add_event_handler(Events.ITERATION_COMPLETED, assert_const_momentum, ema_handler.momentum)
    engine.run(range(10))
149147

150148

151149
def test_ema_buffer():
@@ -180,11 +178,10 @@ def check_buffers():
180178
def test_ema_two_handlers(get_dummy_model):
181179
"""Test when two EMA handlers are attached to a trainer"""
182180
model_1 = get_dummy_model()
183-
# momentum will be constantly 0.5
184-
ema_handler_1 = EMAHandler(model_1, momentum_warmup=0.5, momentum=0.5, warmup_iters=1)
181+
ema_handler_1 = EMAHandler(model_1, momentum=0.5)
185182

186183
model_2 = get_dummy_model()
187-
ema_handler_2 = EMAHandler(model_2, momentum_warmup=0.5, momentum=0.5, warmup_iters=1)
184+
ema_handler_2 = EMAHandler(model_2, momentum=0.5)
188185

189186
def _step_fn(engine: Engine, batch: Any):
190187
model_1.weight.data.add_(1)
@@ -214,8 +211,8 @@ def _step_fn(engine: Engine, batch: Any):
214211

215212
model_3 = get_dummy_model()
216213
ema_handler_3 = EMAHandler(model_3)
217-
with pytest.raises(ValueError, match="Please select another name"):
218-
ema_handler_3.attach(engine, "ema_momentum_2")
214+
with pytest.warns(UserWarning, match="Attribute 'ema_momentum_1' already exists"):
215+
ema_handler_3.attach(engine, name="ema_momentum_1")
219216

220217

221218
def _test_ema_final_weight(model, device=None, ddp=False, interval=1):
@@ -231,8 +228,7 @@ def _test_ema_final_weight(model, device=None, ddp=False, interval=1):
231228
step_fn = _get_dummy_step_fn(model)
232229
engine = Engine(step_fn)
233230

234-
# momentum will be constantly 0.5
235-
ema_handler = EMAHandler(model, momentum_warmup=0.5, momentum=0.5, warmup_iters=1)
231+
ema_handler = EMAHandler(model, momentum=0.5)
236232
ema_handler.attach(engine, "model", event=Events.ITERATION_COMPLETED(every=interval))
237233

238234
# engine will run 4 iterations

0 commit comments

Comments
 (0)