From 89ec4ab926a793fc37a52dcc2be012c74ef18097 Mon Sep 17 00:00:00 2001 From: Xiaofang Wang Date: Fri, 19 Aug 2022 20:49:08 -0700 Subject: [PATCH] Add WarmupStepWithFixedGammaLR schedule Summary: Pull Request resolved: https://github.com/facebookresearch/detectron2/pull/4473 Add WarmupStepWithFixedGammaLR schedule Add cfg.SOLVER.RESCALE_INTERVAL to indicate whether we will rescale the interval of the scheduler after warmup in WarmupParamScheduler Reviewed By: newstzpz, wat3rBro Differential Revision: D38133633 fbshipit-source-id: 355ab44a274a26d5b99701aa9730f1c77ccae1fb --- detectron2/config/defaults.py | 2 ++ detectron2/solver/build.py | 14 ++++++++++- detectron2/solver/lr_scheduler.py | 5 +++- tests/test_scheduler.py | 41 ++++++++++++++++++++++++++++++- 4 files changed, 59 insertions(+), 3 deletions(-) diff --git a/detectron2/config/defaults.py b/detectron2/config/defaults.py index ea58627ef0..ececfad155 100644 --- a/detectron2/config/defaults.py +++ b/detectron2/config/defaults.py @@ -545,6 +545,8 @@ _C.SOLVER.WARMUP_FACTOR = 1.0 / 1000 _C.SOLVER.WARMUP_ITERS = 1000 _C.SOLVER.WARMUP_METHOD = "linear" +# Whether to rescale the interval for the learning schedule after warmup +_C.SOLVER.RESCALE_INTERVAL = False # Save a checkpoint after every this number of iterations _C.SOLVER.CHECKPOINT_PERIOD = 5000 diff --git a/detectron2/solver/build.py b/detectron2/solver/build.py index d79e23a60f..5e0f476358 100644 --- a/detectron2/solver/build.py +++ b/detectron2/solver/build.py @@ -6,7 +6,11 @@ from enum import Enum from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union import torch -from fvcore.common.param_scheduler import CosineParamScheduler, MultiStepParamScheduler +from fvcore.common.param_scheduler import ( + CosineParamScheduler, + MultiStepParamScheduler, + StepWithFixedGammaParamScheduler, +) from detectron2.config import CfgNode @@ -284,6 +288,13 @@ def build_lr_scheduler( end_value = cfg.SOLVER.BASE_LR_END / cfg.SOLVER.BASE_LR assert 
end_value >= 0.0 and end_value <= 1.0, end_value sched = CosineParamScheduler(1, end_value) + elif name == "WarmupStepWithFixedGammaLR": + sched = StepWithFixedGammaParamScheduler( + base_value=1.0, + gamma=cfg.SOLVER.GAMMA, + num_decays=cfg.SOLVER.NUM_DECAYS, + num_updates=cfg.SOLVER.MAX_ITER, + ) else: raise ValueError("Unknown LR scheduler: {}".format(name)) @@ -292,5 +303,6 @@ def build_lr_scheduler( cfg.SOLVER.WARMUP_FACTOR, min(cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER, 1.0), cfg.SOLVER.WARMUP_METHOD, + cfg.SOLVER.RESCALE_INTERVAL, ) return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER) diff --git a/detectron2/solver/lr_scheduler.py b/detectron2/solver/lr_scheduler.py index 8803e87b9e..bd2270afc5 100644 --- a/detectron2/solver/lr_scheduler.py +++ b/detectron2/solver/lr_scheduler.py @@ -25,6 +25,7 @@ def __init__( warmup_factor: float, warmup_length: float, warmup_method: str = "linear", + rescale_interval: bool = False, ): """ Args: @@ -33,6 +34,8 @@ def __init__( warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire training, e.g. 
0.01 warmup_method: one of "linear" or "constant" + rescale_interval: whether we will rescale the interval of the scheduler after + warmup """ end_value = scheduler(warmup_length) # the value to reach when warmup ends start_value = warmup_factor * scheduler(0.0) @@ -44,7 +47,7 @@ def __init__( raise ValueError("Unknown warmup method: {}".format(warmup_method)) super().__init__( [warmup, scheduler], - interval_scaling=["rescaled", "fixed"], + interval_scaling=["rescaled", "rescaled" if rescale_interval else "fixed"], lengths=[warmup_length, 1 - warmup_length], ) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 51f14b0cb9..5649a4a2e1 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -4,7 +4,11 @@ import numpy as np from unittest import TestCase import torch -from fvcore.common.param_scheduler import CosineParamScheduler, MultiStepParamScheduler +from fvcore.common.param_scheduler import ( + CosineParamScheduler, + MultiStepParamScheduler, + StepWithFixedGammaParamScheduler, +) from torch import nn from detectron2.solver import LRMultiplier, WarmupParamScheduler, build_lr_scheduler @@ -117,3 +121,38 @@ def _test_end_value(cfg_dict): } } ) + + def test_warmup_stepwithfixedgamma(self): + p = nn.Parameter(torch.zeros(0)) + opt = torch.optim.SGD([p], lr=5) + + multiplier = WarmupParamScheduler( + StepWithFixedGammaParamScheduler( + base_value=1.0, + gamma=0.1, + num_decays=4, + num_updates=30, + ), + 0.001, + 5 / 30, + rescale_interval=True, + ) + sched = LRMultiplier(opt, multiplier, 30) + + p.sum().backward() + opt.step() + + lrs = [0.005] + for _ in range(29): + sched.step() + lrs.append(opt.param_groups[0]["lr"]) + self.assertTrue(np.allclose(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001])) + self.assertTrue(np.allclose(lrs[5:10], 5.0)) + self.assertTrue(np.allclose(lrs[10:15], 0.5)) + self.assertTrue(np.allclose(lrs[15:20], 0.05)) + self.assertTrue(np.allclose(lrs[20:25], 0.005)) + self.assertTrue(np.allclose(lrs[25:], 0.0005)) + 
+ # Calling sched.step() after the last training iteration is done will trigger IndexError + with self.assertRaises(IndexError, msg="list index out of range"): + sched.step()