
Commit 60620e4

Add end_lr in WarmupCosineSchedule (#6662)
Fixes #6527.

### Description

Add `end_lr` in `WarmupCosineSchedule`.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: KumoLiu <yunl@nvidia.com>
1 parent 8bc25b9 commit 60620e4

File tree

2 files changed: +14, -0 lines

monai/optimizers/lr_scheduler.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -68,6 +68,7 @@ def __init__(
         optimizer: Optimizer,
         warmup_steps: int,
         t_total: int,
+        end_lr: float = 0.0,
         cycles: float = 0.5,
         last_epoch: int = -1,
         warmup_multiplier: float = 0,
@@ -77,6 +78,7 @@ def __init__(
             optimizer: wrapped optimizer.
             warmup_steps: number of warmup iterations.
             t_total: total number of training iterations.
+            end_lr: the final learning rate. Defaults to 0.0.
             cycles: cosine cycles parameter.
             last_epoch: the index of last epoch.
             warmup_multiplier: if provided, starts the linear warmup from this fraction of the initial lr.
@@ -88,6 +90,7 @@ def __init__(
         self.warmup_multiplier = warmup_multiplier
         self.t_total = t_total
         self.cycles = cycles
+        self.end_lr = end_lr
         if warmup_multiplier < 0 or warmup_multiplier > 1:
             raise ValueError("warmup_multiplier must be in 0..1 range")
         super().__init__(optimizer, self.lr_lambda, last_epoch)
@@ -98,3 +101,10 @@ def lr_lambda(self, step):
             return self.warmup_multiplier + (1 - self.warmup_multiplier) * f
         progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
         return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
+
+    def get_lr(self):
+        current_lr = [base_lr * lmbda(self.last_epoch) for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
+        if self.last_epoch < self.warmup_steps:
+            return current_lr
+        else:
+            return [max(self.end_lr, _current_lr) for _current_lr in current_lr]
```
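The net effect: the learning rate ramps up linearly for `warmup_steps` iterations, follows a cosine decay afterwards, and the overridden `get_lr` keeps it from falling below `end_lr` once warmup is done. A minimal usage sketch, assuming MONAI at or after this commit (the toy model, base lr, and loop length are illustrative, not from the source):

```python
# Minimal sketch, assuming MONAI at or after this commit.
# The toy model, base lr of 1.0, and step count are illustrative only.
import torch
from monai.optimizers import WarmupCosineSchedule

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
# Linear warmup for 2 steps, cosine decay over 10 total steps,
# clamped so the lr never drops below 0.309 after warmup.
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=2, t_total=10, end_lr=0.309)

for step in range(10):
    optimizer.step()  # a real loop would do forward/backward first
    scheduler.step()
    print(step, scheduler.get_last_lr())
```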

tests/test_lr_scheduler.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -35,6 +35,10 @@ def forward(self, x):
         {"warmup_steps": 2, "t_total": 10, "warmup_multiplier": 0.1},
         [0.1, 0.55, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038],
     ],
+    [
+        {"warmup_steps": 2, "t_total": 10, "warmup_multiplier": 0.1, "end_lr": 0.309},
+        [0.1, 0.55, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.309, 0.309],
+    ],
 ]
```
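Where the clamp kicks in: with `warmup_steps=2` and `t_total=10`, step 7 has progress 5/8, so the cosine factor is 0.5 * (1 + cos(0.625π)) ≈ 0.309; steps 8 and 9 would decay further to 0.146 and 0.038, but `end_lr=0.309` holds them at 0.309, which is exactly what the new test case asserts. A standalone re-derivation of those numbers (not code from the repo):

```python
# Standalone re-derivation of the expected test values (not from the repo).
import math

def cosine_factor(step: int, warmup_steps: int = 2, t_total: int = 10, cycles: float = 0.5) -> float:
    # Mirrors lr_lambda above for post-warmup steps.
    progress = (step - warmup_steps) / max(1, t_total - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2.0 * progress)))

print([round(cosine_factor(s), 3) for s in range(2, 10)])
# -> [1.0, 0.962, 0.854, 0.691, 0.5, 0.309, 0.146, 0.038]
# with end_lr=0.309, the last two values are clamped up to 0.309
```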