@@ -193,7 +193,7 @@ def __init__(
             self._state_attrs += ["param_group_index"]

     def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:
-        value = self.get_param()
+        value = self._get_param()

         if isinstance(value, list):
             if len(value) != len(self.optimizer_param_groups):
@@ -261,6 +261,11 @@ def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> List[List[
             values.append([i, scheduler.optimizer_param_groups[0][scheduler.param_name]])
         return values

+    def _get_param(self) -> Union[List[float], float]:
+        # `ParamScheduler` does nothing special here, only returning what the child
+        # class's `get_param()` returns. Intermediate child classes override this method.
+        return self.get_param()
+

 class CyclicalScheduler(ParamScheduler):
     """An abstract class for updating an optimizer's parameter value over a
@@ -279,6 +284,9 @@ class CyclicalScheduler(ParamScheduler):
             end of each cycle (default=1.0).
         end_value_mult: ratio by which to change the end value at the
             end of each cycle (default=1.0).
+        warmup_duration: duration of warm-up to be applied before each cycle.
+            Through this warm-up, the parameter starts from the last cycle's end value
+            and linearly goes to the next cycle's start value. Default is no cyclic warm-up.
         save_history: whether to log the parameter values to
             `engine.state.param_history`, (default=False).
         param_group_index: optimizer's parameters group to use.
@@ -288,6 +296,9 @@ class CyclicalScheduler(ParamScheduler):
             usually be the number of batches in an epoch.

     .. versionadded:: 0.4.5
+
+    .. versionchanged:: 0.4.13
+        Added cyclic warm-up to the scheduler using ``warmup_duration``.
     """

     def __init__(
@@ -300,6 +311,7 @@ def __init__(
         cycle_mult: float = 1.0,
         start_value_mult: float = 1.0,
         end_value_mult: float = 1.0,
+        warmup_duration: int = 0,
         save_history: bool = False,
         param_group_index: Optional[int] = None,
     ):
@@ -308,11 +320,13 @@ def __init__(
         )
         self.start_value = start_value
         self.end_value = end_value
-        self.cycle_size = int(cycle_size)  # Ensure cycle_size is integer
+        self.cycle_size = cycle_size
         self.cycle_mult = cycle_mult
         self.cycle = 0
         self.start_value_mult = start_value_mult
         self.end_value_mult = end_value_mult
+        self.warmup_duration = warmup_duration
+        self.total_cycle_size = self.warmup_duration + self.cycle_size

         if self.cycle_size < 2:
             raise ValueError(f"Argument cycle_size should be positive and larger than 1, but given {cycle_size}")
@@ -325,18 +339,33 @@ def __init__(
             "cycle",
             "start_value_mult",
             "end_value_mult",
+            "warmup_duration",
+            "total_cycle_size",
         ]

     def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:
-        if self.event_index != 0 and self.event_index % self.cycle_size == 0:
+        if self.event_index != 0 and self.event_index == self.cycle_size:
+            self.start_value *= self.start_value_mult
+        if self.event_index != 0 and self.event_index == self.total_cycle_size:
             self.event_index = 0
             self.cycle_size = int(self.cycle_size * self.cycle_mult)
+            self.warmup_duration = int(self.warmup_duration * self.cycle_mult)
+            self.total_cycle_size = self.warmup_duration + self.cycle_size
             self.cycle += 1
-            self.start_value *= self.start_value_mult
             self.end_value *= self.end_value_mult

         return super(CyclicalScheduler, self).__call__(engine, name)

+    def _get_param(self) -> Union[List[float], float]:
+        """Applies warm-up if the scheduler is in the warm-up phase,
+        otherwise returns what is returned by `self.get_param()`.
+        """
+        if self.event_index > self.cycle_size:
+            warmup_progress = (self.event_index - self.cycle_size) / self.warmup_duration
+            return self.end_value + (self.start_value - self.end_value) * warmup_progress
+
+        return self.get_param()
+

 class LinearCyclicalScheduler(CyclicalScheduler):
     """Linearly adjusts param value to 'end_value' for a half-cycle, then linearly
@@ -355,6 +384,9 @@ class LinearCyclicalScheduler(CyclicalScheduler):
             end of each cycle (default=1.0).
         end_value_mult: ratio by which to change the end value at the
             end of each cycle (default=1.0).
+        warmup_duration: duration of warm-up to be applied before each cycle.
+            Through this warm-up, the parameter starts from the last cycle's end value
+            and linearly goes to the next cycle's start value. Default is no cyclic warm-up.
         save_history: whether to log the parameter values to
             `engine.state.param_history`, (default=False).
         param_group_index: optimizer's parameters group to use.
@@ -430,9 +462,13 @@ def print_lr():
             ...

     .. versionadded:: 0.4.5
+
+    .. versionchanged:: 0.4.13
+        Added cyclic warm-up to the scheduler using ``warmup_duration``.
     """

     def get_param(self) -> float:
+        """Method to get current optimizer's parameter value"""
         cycle_progress = self.event_index / self.cycle_size
         return self.end_value + (self.start_value - self.end_value) * abs(cycle_progress - 0.5) * 2

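
The triangular `get_param` above gives, for `start_value=1.0`, `end_value=0.0`, `cycle_size=4`, the values 1.0, 0.5, 0.0, 0.5, 1.0 at indices 0..4. The full warmed-up schedule can be previewed with the `simulate_values` class method shown earlier. A sketch, assuming the `ignite.handlers` import path and a build that already includes `warmup_duration`:

```python
from ignite.handlers import LinearCyclicalScheduler

# Triangular schedule with a 2-event linear warm-up between cycles.
values = LinearCyclicalScheduler.simulate_values(
    num_events=14,
    param_name="lr",
    start_value=1.0,
    end_value=0.0,
    cycle_size=4,
    warmup_duration=2,
)
for event_index, lr in values:
    print(event_index, lr)
```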
@@ -456,6 +492,9 @@ class CosineAnnealingScheduler(CyclicalScheduler):
             end of each cycle (default=1.0).
         end_value_mult: ratio by which to change the end value at the
             end of each cycle (default=1.0).
+        warmup_duration: duration of warm-up to be applied before each cycle.
+            Through this warm-up, the parameter starts from the last cycle's end value
+            and linearly goes to the next cycle's start value. Default is no cyclic warm-up.
         save_history: whether to log the parameter values to
             `engine.state.param_history`, (default=False).
         param_group_index: optimizer's parameters group to use.
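
The same keyword applies to the cosine scheduler, where warm-up is most natural: the cosine half-wave ends each cycle at `end_value`, and the warm-up ramps linearly back to the next cycle's start value. A quick sketch under the same assumptions as above:

```python
from ignite.handlers import CosineAnnealingScheduler

# Cosine anneal from start_value to end_value over each cycle_size events,
# followed by a 2-event linear ramp back up between cycles.
values = CosineAnnealingScheduler.simulate_values(
    num_events=12,
    param_name="lr",
    start_value=0.1,
    end_value=0.001,
    cycle_size=4,
    warmup_duration=2,
)
```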
@@ -534,6 +573,9 @@ def print_lr():
        Applications of Computer Vision (WACV), 2017 IEEE Winter Conference on. IEEE, 2017

     .. versionadded:: 0.4.5
+
+    .. versionchanged:: 0.4.13
+        Added cyclic warm-up to the scheduler using ``warmup_duration``.
     """

     def get_param(self) -> float: