Add WarmupStepWithFixedGammaLR schedule
Summary:
Add WarmupStepWithFixedGammaLR schedule

Add cfg.SOLVER.INTERVAL_RESCALED to indicate whether WarmupParamScheduler should rescale the scheduler's interval after warmup

Differential Revision: D38133633

fbshipit-source-id: a7478bd76d9d40104684e03834d4c688d62c351d
Xiaofang Wang authored and facebook-github-bot committed Aug 11, 2022
1 parent 5aeb252 commit da9ed8b
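
Not part of the commit, but for context: a minimal sketch of how the new schedule would be selected through the config, assuming the standard detectron2 setup (cfg.SOLVER.GAMMA and cfg.SOLVER.NUM_DECAYS are the existing solver options that build_lr_scheduler reads in the diff below; the stand-in model and hyperparameter values are illustrative only):

    import torch
    from detectron2.config import get_cfg
    from detectron2.solver import build_lr_scheduler

    cfg = get_cfg()
    cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupStepWithFixedGammaLR"
    cfg.SOLVER.BASE_LR = 0.1
    cfg.SOLVER.GAMMA = 0.1           # multiplier applied at each decay step
    cfg.SOLVER.NUM_DECAYS = 3        # assumed to already exist in defaults.py
    cfg.SOLVER.MAX_ITER = 90000
    cfg.SOLVER.WARMUP_ITERS = 1000
    cfg.SOLVER.INTERVAL_RESCALED = True  # the flag added by this commit

    model = torch.nn.Linear(3, 3)    # stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=cfg.SOLVER.BASE_LR)
    scheduler = build_lr_scheduler(cfg, optimizer)  # step once per iteration
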
Showing 4 changed files with 60 additions and 6 deletions.
1 change: 1 addition & 0 deletions detectron2/config/defaults.py
@@ -545,6 +545,7 @@
 _C.SOLVER.WARMUP_FACTOR = 1.0 / 1000
 _C.SOLVER.WARMUP_ITERS = 1000
 _C.SOLVER.WARMUP_METHOD = "linear"
+_C.SOLVER.INTERVAL_RESCALED = False
 
 # Save a checkpoint after every this number of iterations
 _C.SOLVER.CHECKPOINT_PERIOD = 5000
14 changes: 13 additions & 1 deletion detectron2/solver/build.py
@@ -6,7 +6,11 @@
 from enum import Enum
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
 import torch
-from fvcore.common.param_scheduler import CosineParamScheduler, MultiStepParamScheduler
+from fvcore.common.param_scheduler import (
+    CosineParamScheduler,
+    MultiStepParamScheduler,
+    StepWithFixedGammaParamScheduler,
+)
 
 from detectron2.config import CfgNode

@@ -284,6 +288,13 @@ def build_lr_scheduler(
         end_value = cfg.SOLVER.BASE_LR_END / cfg.SOLVER.BASE_LR
         assert end_value >= 0.0 and end_value <= 1.0, end_value
         sched = CosineParamScheduler(1, end_value)
+    elif name == "WarmupStepWithFixedGammaLR":
+        sched = StepWithFixedGammaParamScheduler(
+            base_value=1.0,
+            gamma=cfg.SOLVER.GAMMA,
+            num_decays=cfg.SOLVER.NUM_DECAYS,
+            num_updates=cfg.SOLVER.MAX_ITER,
+        )
     else:
         raise ValueError("Unknown LR scheduler: {}".format(name))

@@ -292,5 +303,6 @@ def build_lr_scheduler(
         cfg.SOLVER.WARMUP_FACTOR,
         min(cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER, 1.0),
         cfg.SOLVER.WARMUP_METHOD,
+        cfg.SOLVER.INTERVAL_RESCALED,
     )
     return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)
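
To make the config-to-schedule mapping concrete, a small sketch of the underlying fvcore scheduler, assuming the semantics the new test below relies on (num_updates is split into num_decays + 1 equal intervals, and the value is multiplied by gamma at each boundary):

    from fvcore.common.param_scheduler import StepWithFixedGammaParamScheduler

    # 30 updates with 4 decays -> 5 equal intervals of 6 updates each.
    sched = StepWithFixedGammaParamScheduler(
        base_value=1.0, gamma=0.1, num_decays=4, num_updates=30
    )

    # fvcore param schedulers are queried with a relative progress in [0, 1).
    for it in (0, 6, 12, 18, 24):
        print(it, sched(it / 30))  # 1.0, 0.1, 0.01, 0.001, 0.0001
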
7 changes: 6 additions & 1 deletion detectron2/solver/lr_scheduler.py
@@ -25,6 +25,7 @@ def __init__(
         warmup_factor: float,
         warmup_length: float,
         warmup_method: str = "linear",
+        interval_rescaled: bool = False,
     ):
         """
         Args:
@@ -33,6 +34,8 @@
             warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire
                 training, e.g. 0.01
             warmup_method: one of "linear" or "constant"
+            interval_rescaled: whether to rescale the scheduler's interval after
+                the warmup period
         """
         end_value = scheduler(warmup_length)  # the value to reach when warmup ends
         start_value = warmup_factor * scheduler(0.0)
@@ -44,7 +47,9 @@ def __init__(
             raise ValueError("Unknown warmup method: {}".format(warmup_method))
         super().__init__(
             [warmup, scheduler],
-            interval_scaling=["rescaled", "fixed"],
+            interval_scaling=["rescaled", "rescaled"]
+            if interval_rescaled
+            else ["rescaled", "fixed"],
             lengths=[warmup_length, 1 - warmup_length],
         )
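
A minimal sketch of what the new flag changes, using the wrapper above (the values follow from the composite-scheduler semantics that the new test exercises; numbers match the test configuration):

    from fvcore.common.param_scheduler import StepWithFixedGammaParamScheduler
    from detectron2.solver import WarmupParamScheduler

    def step_sched():
        return StepWithFixedGammaParamScheduler(
            base_value=1.0, gamma=0.1, num_decays=4, num_updates=30
        )

    fixed = WarmupParamScheduler(step_sched(), 0.001, 5 / 30)  # previous default
    rescaled = WarmupParamScheduler(step_sched(), 0.001, 5 / 30, interval_rescaled=True)

    # Just after warmup ends (progress 6/30): "fixed" evaluates the wrapped
    # scheduler at absolute progress 0.2, which is already one decay in, while
    # "rescaled" stretches the remaining 25/30 of training over the scheduler's
    # full [0, 1] range, so the schedule restarts from its base value.
    print(fixed(6 / 30))     # 0.1
    print(rescaled(6 / 30))  # 1.0
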

44 changes: 40 additions & 4 deletions tests/test_scheduler.py
@@ -1,13 +1,18 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 
 import math
-import numpy as np
 from unittest import TestCase
 
+import numpy as np
 import torch
-from fvcore.common.param_scheduler import CosineParamScheduler, MultiStepParamScheduler
-from torch import nn
 
-from detectron2.solver import LRMultiplier, WarmupParamScheduler, build_lr_scheduler
+from detectron2.solver import build_lr_scheduler, LRMultiplier, WarmupParamScheduler
+from fvcore.common.param_scheduler import (
+    CosineParamScheduler,
+    MultiStepParamScheduler,
+    StepWithFixedGammaParamScheduler,
+)
+from torch import nn
+
 
 class TestScheduler(TestCase):
@@ -117,3 +122,34 @@ def _test_end_value(cfg_dict):
                 }
             }
         )
+
+    def test_warmup_stepwithfixedgamma(self):
+        p = nn.Parameter(torch.zeros(0))
+        opt = torch.optim.SGD([p], lr=5)
+
+        multiplier = WarmupParamScheduler(
+            StepWithFixedGammaParamScheduler(
+                base_value=1.0,
+                gamma=0.1,
+                num_decays=4,
+                num_updates=30,
+            ),
+            0.001,
+            5 / 30,
+            interval_rescaled=True,
+        )
+        sched = LRMultiplier(opt, multiplier, 30)
+
+        p.sum().backward()
+        opt.step()
+
+        lrs = [0.005]
+        for _ in range(29):
+            sched.step()
+            lrs.append(opt.param_groups[0]["lr"])
+        self.assertTrue(np.allclose(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001]))
+        self.assertTrue(np.allclose(lrs[5:10], 5.0))
+        self.assertTrue(np.allclose(lrs[10:15], 0.5))
+        self.assertTrue(np.allclose(lrs[15:20], 0.05))
+        self.assertTrue(np.allclose(lrs[20:25], 0.005))
+        self.assertTrue(np.allclose(lrs[25:], 0.0005))
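
The expected values follow from the definitions above; a back-of-the-envelope check of the warmup portion (a linear ramp of the multiplier from 0.001 to 1.0 over the first 5 of 30 iterations, scaled by the base lr of 5):

    # Reproducing lrs[:5] by hand.
    for i in range(5):
        mult = 0.001 + (i / 5) * (1.0 - 0.001)
        print(round(5 * mult, 3))  # 0.005, 1.004, 2.003, 3.002, 4.001

After warmup, interval_rescaled=True maps iterations 5-29 onto the step schedule's full range, so each block of 5 iterations decays the lr by gamma: 5.0, 0.5, 0.05, 0.005, 0.0005 — exactly the remaining assertions.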
