Commit 169551e

sihyeong671, sadra-barikbin, and vfdev-5 authored
feat: Added warmup each cycle feature in CyclicalScheduler (#3064)
* feat: add per-cycle warm-up feature in cosine annealing
  - add _get_cycle_param method in CyclicalScheduler
  - add warmup_each_cycle, warmup_duration variables in CyclicalScheduler
  - add warmup phase in CosineAnnealingScheduler (issue #3036)
* fix: f-string
  - remove the f prefix from strings that do not interpolate variables
* refactor: add _get_cycle_param, get_param in CyclicalScheduler
  - rename get_param in LinearCyclicalScheduler, CosineAnnealingScheduler to _get_cycle_param
* fix: add total_cycle_size to _state_attrs
* fix: add docstring, change function to abstractmethod
  - fix typo
  - add docstring requested in PR review
  - change _get_cycle_param to abstractmethod
  - raise ValueError when warmup_each_cycle=True and warmup_duration is None
* feat: add test functions
  - add test_cyclical_scheduler_asserts
  - add test_cosine_annealing_scheduler_warmup
* fix: keep previous version tag, fix test wording
* fix: correct the TypeError match string
* feat: refactor _get_cycle_param to _get_param
  - change _get_cycle_param to _get_param
  - add _get_param in ParamScheduler
  - remove warmup_each_cycle variable
* docs: restore docstring
* feat: remove & fix
  - remove test_cyclical_scheduler
  - remove warmup_each_cycle variable
* feat: remove first-cycle warmup
* feat: fix lrs values
* A few improvements
* Update param_scheduler.py
* Update test_param_scheduler.py

---------

Co-authored-by: Sadra Barikbin <sadraqazvin1@yahoo.com>
Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent ea7cb1d commit 169551e

2 files changed: +79 additions, -106 deletions


ignite/handlers/param_scheduler.py

Lines changed: 46 additions & 4 deletions
@@ -193,7 +193,7 @@ def __init__(
         self._state_attrs += ["param_group_index"]
 
     def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:
-        value = self.get_param()
+        value = self._get_param()
 
         if isinstance(value, list):
             if len(value) != len(self.optimizer_param_groups):
@@ -261,6 +261,11 @@ def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> List[List[
             values.append([i, scheduler.optimizer_param_groups[0][scheduler.param_name]])
         return values
 
+    def _get_param(self) -> Union[List[float], float]:
+        # `ParamScheduler` does nothing special, only returning what child class returns.
+        # Intermediate child classes edit this method
+        return self.get_param()
+
 
 class CyclicalScheduler(ParamScheduler):
     """An abstract class for updating an optimizer's parameter value over a
@@ -279,6 +284,9 @@ class CyclicalScheduler(ParamScheduler):
             end of each cycle (default=1.0).
         end_value_mult: ratio by which to change the end value at the
             end of each cycle (default=1.0).
+        warmup_duration: duration of warm-up to be applied before each cycle.
+            Through this warm-up, the parameter starts from the last cycle's end value
+            and linearly goes to next cycle's start value. Default is no cyclic warm-up.
         save_history: whether to log the parameter values to
             `engine.state.param_history`, (default=False).
         param_group_index: optimizer's parameters group to use.
@@ -288,6 +296,9 @@ class CyclicalScheduler(ParamScheduler):
         usually be the number of batches in an epoch.
 
     .. versionadded:: 0.4.5
+
+    .. versionchanged:: 0.4.13
+        Added cyclic warm-up to the scheduler using ``warmup_duration``.
     """
 
     def __init__(
@@ -300,6 +311,7 @@ def __init__(
         cycle_mult: float = 1.0,
         start_value_mult: float = 1.0,
         end_value_mult: float = 1.0,
+        warmup_duration: int = 0,
         save_history: bool = False,
         param_group_index: Optional[int] = None,
     ):
@@ -308,11 +320,13 @@ def __init__(
         )
         self.start_value = start_value
         self.end_value = end_value
-        self.cycle_size = int(cycle_size)  # Ensure cycle_size is integer
+        self.cycle_size = cycle_size
         self.cycle_mult = cycle_mult
         self.cycle = 0
         self.start_value_mult = start_value_mult
         self.end_value_mult = end_value_mult
+        self.warmup_duration = warmup_duration
+        self.total_cycle_size = self.warmup_duration + self.cycle_size
 
         if self.cycle_size < 2:
             raise ValueError(f"Argument cycle_size should be positive and larger than 1, but given {cycle_size}")
@@ -325,18 +339,33 @@ def __init__(
             "cycle",
             "start_value_mult",
             "end_value_mult",
+            "warmup_duration",
+            "total_cycle_size",
         ]
 
     def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:
-        if self.event_index != 0 and self.event_index % self.cycle_size == 0:
+        if self.event_index != 0 and self.event_index == self.cycle_size:
+            self.start_value *= self.start_value_mult
+        if self.event_index != 0 and self.event_index == self.total_cycle_size:
             self.event_index = 0
             self.cycle_size = int(self.cycle_size * self.cycle_mult)
+            self.warmup_duration = int(self.warmup_duration * self.cycle_mult)
+            self.total_cycle_size = self.warmup_duration + self.cycle_size
             self.cycle += 1
-            self.start_value *= self.start_value_mult
             self.end_value *= self.end_value_mult
 
         return super(CyclicalScheduler, self).__call__(engine, name)
 
+    def _get_param(self) -> Union[List[float], float]:
+        """Applies warm-up if the scheduler is in the warm-up phase,
+        otherwise returns what is returned by `self.get_param()`
+        """
+        if self.event_index > self.cycle_size:
+            warmup_progress = (self.event_index - self.cycle_size) / self.warmup_duration
+            return self.end_value + (self.start_value - self.end_value) * warmup_progress
+
+        return self.get_param()
+
 
 class LinearCyclicalScheduler(CyclicalScheduler):
     """Linearly adjusts param value to 'end_value' for a half-cycle, then linearly
@@ -355,6 +384,9 @@ class LinearCyclicalScheduler(CyclicalScheduler):
             end of each cycle (default=1.0).
         end_value_mult: ratio by which to change the end value at the
             end of each cycle (default=1.0).
+        warmup_duration: duration of warm-up to be applied before each cycle.
+            Through this warm-up, the parameter starts from the last cycle's end value
+            and linearly goes to next cycle's start value. Default is no cyclic warm-up.
         save_history: whether to log the parameter values to
             `engine.state.param_history`, (default=False).
         param_group_index: optimizer's parameters group to use.
@@ -430,9 +462,13 @@ def print_lr():
             ...
 
     .. versionadded:: 0.4.5
+
+    .. versionchanged:: 0.4.13
+        Added cyclic warm-up to the scheduler using ``warmup_duration``.
     """
 
     def get_param(self) -> float:
+        """Method to get current optimizer's parameter value"""
         cycle_progress = self.event_index / self.cycle_size
         return self.end_value + (self.start_value - self.end_value) * abs(cycle_progress - 0.5) * 2
 
@@ -456,6 +492,9 @@ class CosineAnnealingScheduler(CyclicalScheduler):
             end of each cycle (default=1.0).
         end_value_mult: ratio by which to change the end value at the
             end of each cycle (default=1.0).
+        warmup_duration: duration of warm-up to be applied before each cycle.
+            Through this warm-up, the parameter starts from the last cycle's end value
+            and linearly goes to next cycle's start value. Default is no cyclic warm-up.
         save_history: whether to log the parameter values to
             `engine.state.param_history`, (default=False).
         param_group_index: optimizer's parameters group to use.
@@ -534,6 +573,9 @@ def print_lr():
         Applications of Computer Vision (WACV), 2017 IEEE Winter Conference on. IEEE, 2017
 
     .. versionadded:: 0.4.5
+
+    .. versionchanged:: 0.4.13
+        Added cyclic warm-up to the scheduler using ``warmup_duration``.
     """
 
     def get_param(self) -> float:
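
For orientation, here is a minimal usage sketch of the new ``warmup_duration`` argument. It is not part of this commit: the handler wiring mirrors the updated test below, and the concrete values (cycle_size=10, warmup_duration=2) are only illustrative.

    import torch
    from ignite.engine import Engine, Events
    from ignite.handlers.param_scheduler import CosineAnnealingScheduler

    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0.0)

    # One cosine cycle over 10 iterations, then a 2-iteration linear warm-up that
    # ramps the parameter from the cycle's end value back to the next cycle's start value.
    scheduler = CosineAnnealingScheduler(
        optimizer, "lr", start_value=0.0, end_value=1.0, cycle_size=10, warmup_duration=2
    )

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # The schedule can also be previewed without an engine, as the updated test does:
    values = CosineAnnealingScheduler.simulate_values(
        num_events=24, param_name="lr", start_value=0.0, end_value=1.0, cycle_size=10, warmup_duration=2
    )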

tests/ignite/handlers/test_param_scheduler.py

Lines changed: 33 additions & 102 deletions
@@ -55,7 +55,7 @@ def test_param_scheduler_asserts():
         FakeParamScheduler({}, "lr")
 
 
-def test_linear_scheduler():
+def test_linear_scheduler_asserts():
     with pytest.raises(TypeError, match=r"Argument optimizer should be torch.optim.Optimizer"):
         LinearCyclicalScheduler({}, "lr", 1, 0, cycle_size=0)
 
@@ -68,6 +68,11 @@ def test_linear_scheduler():
     with pytest.raises(ValueError, match=r"Argument cycle_size should be positive and larger than 1"):
         LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=1)
 
+
+def test_linear_scheduler():
+    tensor = torch.zeros([1], requires_grad=True)
+    optimizer = torch.optim.SGD([tensor], lr=0.0)
+
     scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10)
     state_dict = scheduler.state_dict()
 
@@ -77,38 +82,12 @@ def save_lr(engine):
     trainer = Engine(lambda engine, batch: None)
     trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
     trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
-
+    lr_values_in_cycle = [1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.2, 0.4, 0.6, 0.8]
     for _ in range(2):
         lrs = []
-        trainer.run([0] * 9, max_epochs=2)
+        trainer.run([0] * 10, max_epochs=2)
 
-        assert lrs == list(
-            map(
-                pytest.approx,
-                [
-                    # Cycle 1
-                    1.0,
-                    0.8,
-                    0.6,
-                    0.4,
-                    0.2,
-                    0.0,
-                    0.2,
-                    0.4,
-                    0.6,
-                    0.8,
-                    # Cycle 2
-                    1.0,
-                    0.8,
-                    0.6,
-                    0.4,
-                    0.2,
-                    0.0,
-                    0.2,
-                    0.4,  # 0.6, 0.8,
-                ],
-            )
-        )
+        assert lrs == pytest.approx([*lr_values_in_cycle, *lr_values_in_cycle])
         scheduler.load_state_dict(state_dict)
 
     optimizer = torch.optim.SGD([tensor], lr=0)
@@ -164,49 +143,6 @@ def save_lr(engine):
         )
         scheduler.load_state_dict(state_dict)
 
-    # With float cycle_size
-    optimizer = torch.optim.SGD([tensor], lr=0)
-    scheduler = LinearCyclicalScheduler(
-        optimizer, "lr", start_value=1.2, end_value=0.2, cycle_size=10.00000012, cycle_mult=1.0
-    )
-    state_dict = scheduler.state_dict()
-
-    trainer = Engine(lambda engine, batch: None)
-    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
-    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
-
-    for _ in range(2):
-        lrs = []
-        trainer.run([0] * 9, max_epochs=2)
-        assert lrs == list(
-            map(
-                pytest.approx,
-                [
-                    # Cycle 1
-                    1.2,
-                    1.0,
-                    0.8,
-                    0.6,
-                    0.4,
-                    0.2,
-                    0.4,
-                    0.6,
-                    0.8,
-                    1.0,
-                    # Cycle 2
-                    1.2,
-                    1.0,
-                    0.8,
-                    0.6,
-                    0.4,
-                    0.2,
-                    0.4,
-                    0.6,  # 0.8, 1.0,
-                ],
-            )
-        )
-        scheduler.load_state_dict(state_dict)
 
 def test_linear_scheduler_cycle_size_two():
     tensor = torch.zeros([1], requires_grad=True)
@@ -239,17 +175,23 @@ def save_lr(engine):
     assert lrs == pytest.approx([v for i, v in simulated_values])
 
 
-def test_cosine_annealing_scheduler():
+@pytest.mark.parametrize("cyclic_warmup", [False, True])
+def test_cosine_annealing_scheduler(cyclic_warmup):
     tensor = torch.zeros([1], requires_grad=True)
     optimizer = torch.optim.SGD([tensor], lr=0)
 
-    scheduler = CosineAnnealingScheduler(optimizer, "lr", 0, 1, 10)
+    scheduler = CosineAnnealingScheduler(optimizer, "lr", 0, 1, 10, warmup_duration=2 if cyclic_warmup else 0)
     state_dict = scheduler.state_dict()
 
-    data = [0] * 9
+    data = [0] * (10 + int(cyclic_warmup))
     max_epochs = 2
     simulated_values = CosineAnnealingScheduler.simulate_values(
-        num_events=len(data) * max_epochs, param_name="lr", start_value=0, end_value=1, cycle_size=10
+        num_events=len(data) * max_epochs,
+        param_name="lr",
+        start_value=0,
+        end_value=1,
+        cycle_size=10,
+        warmup_duration=2 if cyclic_warmup else 0,
    )
 
     def save_lr(engine):
@@ -258,36 +200,25 @@ def save_lr(engine):
     trainer = Engine(lambda engine, batch: None)
     trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
     trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
+    lr_values_in_cycle = [
+        0.0,
+        0.02447174185242318,
+        0.09549150281252627,
+        0.20610737385376332,
+        0.3454915028125263,
+        0.5,
+        0.6545084971874737,
+        0.7938926261462365,
+        0.9045084971874737,
+        0.9755282581475768,
+    ]
+    lr_values_in_warmup = np.linspace(1.0, 0.0, 2 + 1)[:-1].tolist() if cyclic_warmup else []
 
     for _ in range(2):
         lrs = []
         trainer.run(data, max_epochs=max_epochs)
 
-        assert lrs == list(
-            map(
-                pytest.approx,
-                [
-                    0.0,
-                    0.02447174185242318,
-                    0.09549150281252627,
-                    0.20610737385376332,
-                    0.3454915028125263,
-                    0.5,
-                    0.6545084971874737,
-                    0.7938926261462365,
-                    0.9045084971874737,
-                    0.9755282581475768,
-                    0.0,
-                    0.02447174185242318,
-                    0.09549150281252627,
-                    0.20610737385376332,
-                    0.3454915028125263,
-                    0.5,
-                    0.6545084971874737,
-                    0.7938926261462365,  # 0.9045084971874737, 0.9755282581475768
-                ],
-            )
-        )
+        assert lrs == pytest.approx([*lr_values_in_cycle, *lr_values_in_warmup, *lr_values_in_cycle])
         scheduler.load_state_dict(state_dict)
 
     assert lrs == pytest.approx([v for i, v in simulated_values])
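
As a side check (not part of the commit), the two warm-up values the parametrized cosine test expects follow directly from the warm-up formula added in ``_get_param``; with ``start_value=0``, ``end_value=1``, ``cycle_size=10`` and ``warmup_duration=2``:

    import numpy as np

    # event_index == 10 is not yet in the warm-up branch (10 > 10 is False), so the cosine
    # get_param() is evaluated at full cycle progress and returns the end value, 1.0.
    # event_index == 11 gives warmup_progress = (11 - 10) / 2 = 0.5 -> 1.0 + (0.0 - 1.0) * 0.5 = 0.5.
    # At event_index == 12 the counters reset, so the next value (0.0) already belongs to the next cycle.
    # The test encodes exactly these two warm-up points as:
    warmup_values = np.linspace(1.0, 0.0, 2 + 1)[:-1].tolist()
    print(warmup_values)  # [1.0, 0.5]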
