Add WarmupStepWithFixedGammaLR schedule
Summary:
Pull Request resolved: facebookresearch#4473

Add WarmupStepWithFixedGammaLR schedule

Add cfg.SOLVER.RESCALE_INTERVAL to indicate whether WarmupParamScheduler should rescale the interval of the wrapped scheduler after warmup

Reviewed By: newstzpz, wat3rBro

Differential Revision: D38133633

fbshipit-source-id: 355ab44a274a26d5b99701aa9730f1c77ccae1fb
Xiaofang Wang authored and facebook-github-bot committed Aug 20, 2022
1 parent 36a65a0 commit 89ec4ab
Showing 4 changed files with 59 additions and 3 deletions.
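For context, a minimal sketch (not part of this commit) of how the new schedule could be enabled through the config keys that build.py reads below. The model and optimizer are illustrative stand-ins, and cfg.SOLVER.NUM_DECAYS is assumed to already be defined in the solver defaults, since build.py only reads it:

import torch

from detectron2.config import get_cfg
from detectron2.solver import build_lr_scheduler

cfg = get_cfg()
cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupStepWithFixedGammaLR"
cfg.SOLVER.GAMMA = 0.1              # LR is multiplied by GAMMA at each decay
cfg.SOLVER.NUM_DECAYS = 4           # number of decays over the whole run
cfg.SOLVER.MAX_ITER = 30
cfg.SOLVER.WARMUP_ITERS = 5
cfg.SOLVER.RESCALE_INTERVAL = True  # the flag added by this commit

model = torch.nn.Linear(4, 4)       # stand-in model
opt = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = build_lr_scheduler(cfg, opt)  # returns an LRMultiplier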
2 changes: 2 additions & 0 deletions detectron2/config/defaults.py
@@ -545,6 +545,8 @@
 _C.SOLVER.WARMUP_FACTOR = 1.0 / 1000
 _C.SOLVER.WARMUP_ITERS = 1000
 _C.SOLVER.WARMUP_METHOD = "linear"
+# Whether to rescale the interval for the learning schedule after warmup
+_C.SOLVER.RESCALE_INTERVAL = False

 # Save a checkpoint after every this number of iterations
 _C.SOLVER.CHECKPOINT_PERIOD = 5000
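Note: the new key defaults to False, so configs that do not set it keep the previous behavior. A quick illustrative check, assuming a detectron2 build that includes this commit:

from detectron2.config import get_cfg

cfg = get_cfg()
assert cfg.SOLVER.RESCALE_INTERVAL is False  # off unless explicitly enabled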
14 changes: 13 additions & 1 deletion detectron2/solver/build.py
@@ -6,7 +6,11 @@
 from enum import Enum
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
 import torch
-from fvcore.common.param_scheduler import CosineParamScheduler, MultiStepParamScheduler
+from fvcore.common.param_scheduler import (
+    CosineParamScheduler,
+    MultiStepParamScheduler,
+    StepWithFixedGammaParamScheduler,
+)

 from detectron2.config import CfgNode

@@ -284,6 +288,13 @@ def build_lr_scheduler(
         end_value = cfg.SOLVER.BASE_LR_END / cfg.SOLVER.BASE_LR
         assert end_value >= 0.0 and end_value <= 1.0, end_value
         sched = CosineParamScheduler(1, end_value)
+    elif name == "WarmupStepWithFixedGammaLR":
+        sched = StepWithFixedGammaParamScheduler(
+            base_value=1.0,
+            gamma=cfg.SOLVER.GAMMA,
+            num_decays=cfg.SOLVER.NUM_DECAYS,
+            num_updates=cfg.SOLVER.MAX_ITER,
+        )
     else:
         raise ValueError("Unknown LR scheduler: {}".format(name))

@@ -292,5 +303,6 @@ def build_lr_scheduler(
         cfg.SOLVER.WARMUP_FACTOR,
         min(cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER, 1.0),
         cfg.SOLVER.WARMUP_METHOD,
+        cfg.SOLVER.RESCALE_INTERVAL,
     )
     return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)
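A note on the new branch (illustrative, not part of the diff): StepWithFixedGammaParamScheduler divides num_updates into num_decays + 1 equal phases and multiplies base_value by gamma at each phase boundary. A small sketch, assuming those fvcore semantics:

import math

from fvcore.common.param_scheduler import StepWithFixedGammaParamScheduler

s = StepWithFixedGammaParamScheduler(
    base_value=1.0, gamma=0.1, num_decays=4, num_updates=30
)
# Five equal phases over [0, 1]; value = base_value * gamma ** phase_index.
assert math.isclose(s(0.0), 1.0)    # phase 0 covers [0.0, 0.2)
assert math.isclose(s(0.25), 0.1)   # phase 1 covers [0.2, 0.4)
assert math.isclose(s(0.5), 0.01)   # phase 2
assert math.isclose(s(0.99), 1e-4)  # phase 4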
5 changes: 4 additions & 1 deletion detectron2/solver/lr_scheduler.py
@@ -25,6 +25,7 @@ def __init__(
         warmup_factor: float,
         warmup_length: float,
         warmup_method: str = "linear",
+        rescale_interval: bool = False,
     ):
         """
         Args:
@@ -33,6 +34,8 @@
             warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire
                 training, e.g. 0.01
             warmup_method: one of "linear" or "constant"
+            rescale_interval: whether we will rescale the interval of the scheduler after
+                warmup
         """
         end_value = scheduler(warmup_length)  # the value to reach when warmup ends
         start_value = warmup_factor * scheduler(0.0)
@@ -44,7 +47,7 @@ def __init__(
             raise ValueError("Unknown warmup method: {}".format(warmup_method))
         super().__init__(
             [warmup, scheduler],
-            interval_scaling=["rescaled", "fixed"],
+            interval_scaling=["rescaled", "rescaled" if rescale_interval else "fixed"],
             lengths=[warmup_length, 1 - warmup_length],
         )
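For intuition (not part of the diff): with rescale_interval=True the wrapped scheduler's whole [0, 1] domain is remapped onto the post-warmup interval, while the previous "fixed" behavior keeps evaluating it on the original timeline, so the portion overlapping warmup is effectively skipped. A small sketch, assuming fvcore's CompositeParamScheduler semantics:

import math

from fvcore.common.param_scheduler import StepWithFixedGammaParamScheduler

from detectron2.solver import WarmupParamScheduler

step = StepWithFixedGammaParamScheduler(
    base_value=1.0, gamma=0.1, num_decays=4, num_updates=30
)
rescaled = WarmupParamScheduler(step, 0.001, 5 / 30, rescale_interval=True)
fixed = WarmupParamScheduler(step, 0.001, 5 / 30)  # previous behavior

# At 45% of training the two disagree: the rescaled schedule is still in its
# second phase (post-warmup boundaries land every 1/6 of the run), while the
# fixed one evaluates the step schedule on the global timeline, where the
# boundary at 0.4 has already passed.
assert math.isclose(rescaled(0.45), 0.1)
assert math.isclose(fixed(0.45), 0.01)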

41 changes: 40 additions & 1 deletion tests/test_scheduler.py
@@ -4,7 +4,11 @@
 import numpy as np
 from unittest import TestCase
 import torch
-from fvcore.common.param_scheduler import CosineParamScheduler, MultiStepParamScheduler
+from fvcore.common.param_scheduler import (
+    CosineParamScheduler,
+    MultiStepParamScheduler,
+    StepWithFixedGammaParamScheduler,
+)
 from torch import nn

 from detectron2.solver import LRMultiplier, WarmupParamScheduler, build_lr_scheduler
@@ -117,3 +121,38 @@ def _test_end_value(cfg_dict):
                 }
             }
         )
+
+    def test_warmup_stepwithfixedgamma(self):
+        p = nn.Parameter(torch.zeros(0))
+        opt = torch.optim.SGD([p], lr=5)
+
+        multiplier = WarmupParamScheduler(
+            StepWithFixedGammaParamScheduler(
+                base_value=1.0,
+                gamma=0.1,
+                num_decays=4,
+                num_updates=30,
+            ),
+            0.001,
+            5 / 30,
+            rescale_interval=True,
+        )
+        sched = LRMultiplier(opt, multiplier, 30)
+
+        p.sum().backward()
+        opt.step()
+
+        lrs = [0.005]
+        for _ in range(29):
+            sched.step()
+            lrs.append(opt.param_groups[0]["lr"])
+        self.assertTrue(np.allclose(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001]))
+        self.assertTrue(np.allclose(lrs[5:10], 5.0))
+        self.assertTrue(np.allclose(lrs[10:15], 0.5))
+        self.assertTrue(np.allclose(lrs[15:20], 0.05))
+        self.assertTrue(np.allclose(lrs[20:25], 0.005))
+        self.assertTrue(np.allclose(lrs[25:], 0.0005))
+
+        # Calling sched.step() after the last training iteration is done will trigger IndexError
+        with self.assertRaises(IndexError, msg="list index out of range"):
+            sched.step()
