lr_schedule.py
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import Adam


class WarmupLinearLRSchedule:
    """
    Linear warmup from `init_lr` to `peak_lr` over the first `warmup_epochs` steps,
    followed by linear decay to `end_lr` over the remaining `epochs - warmup_epochs`
    steps. The current learning rate is written into every param group of the
    wrapped optimizer.
    """
    def __init__(self, optimizer, init_lr, peak_lr, end_lr, warmup_epochs, epochs=100, current_step=0):
        self.init_lr = init_lr
        self.peak_lr = peak_lr
        self.optimizer = optimizer
        # Per-step lr increments for the warmup and decay phases.
        self.warmup_rate = (peak_lr - init_lr) / warmup_epochs
        self.decay_rate = (end_lr - peak_lr) / (epochs - warmup_epochs)
        self.update_steps = current_step
        self.lr = init_lr
        self.warmup_steps = warmup_epochs
        self.epochs = epochs
        if current_step > 0:
            # When resuming mid-training, recover the lr the schedule would have
            # reached just before `current_step`.
            self.lr = self.peak_lr + self.decay_rate * (current_step - 1 - warmup_epochs)

    def set_lr(self, lr):
        for g in self.optimizer.param_groups:
            g['lr'] = lr

    def step(self):
        if self.update_steps <= self.warmup_steps:
            lr = self.init_lr + self.warmup_rate * self.update_steps
        else:
            lr = max(0., self.lr + self.decay_rate)
        self.set_lr(lr)
        self.lr = lr
        self.update_steps += 1
        return self.lr
if __name__ == '__main__':
    # Sanity check: plot a fresh schedule and one resumed from step 50;
    # the resumed curve should overlap the tail of the fresh one.
    m = nn.Linear(10, 10)
    opt = Adam(m.parameters(), lr=1e-4)
    s = WarmupLinearLRSchedule(opt, 1e-6, 1e-4, 0., 2)
    lrs = []
    for i in range(101):
        s.step()
        lrs.append(s.lr)
        print(s.lr)

    m = nn.Linear(10, 10)
    opt = Adam(m.parameters(), lr=1e-4)
    s = WarmupLinearLRSchedule(opt, 1e-6, 1e-4, 0., 2, current_step=50)
    lrs_s = []
    for i in range(50, 101):
        s.step()
        lrs_s.append(s.lr)
        print(s.lr)

    plt.plot(lrs)
    plt.plot(range(50, 101), lrs_s)
    plt.show()
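
A minimal sketch of an alternative, not part of the original module: the same warmup-then-linear-decay shape can be expressed with PyTorch's built-in torch.optim.lr_scheduler.LambdaLR, which multiplies the optimizer's base lr by a user-supplied factor each step. It assumes the optimizer's base learning rate is set to `peak_lr`; the helper name `warmup_linear_lambda` and the hyperparameter values below are illustrative assumptions.

from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR


def warmup_linear_lambda(init_lr, peak_lr, end_lr, warmup_epochs, epochs):
    # Returns a multiplicative factor relative to the base lr (assumed to equal peak_lr).
    def fn(epoch):
        if epoch <= warmup_epochs:
            lr = init_lr + (peak_lr - init_lr) / warmup_epochs * epoch
        else:
            lr = peak_lr + (end_lr - peak_lr) / (epochs - warmup_epochs) * (epoch - warmup_epochs)
        return max(lr, 0.) / peak_lr
    return fn


model = nn.Linear(10, 10)
optimizer = Adam(model.parameters(), lr=1e-4)  # base lr == peak_lr
scheduler = LambdaLR(optimizer, lr_lambda=warmup_linear_lambda(1e-6, 1e-4, 0., 2, 100))

for epoch in range(100):
    optimizer.step()   # one epoch of training would go here
    scheduler.step()

Unlike the hand-rolled class above, LambdaLR is driven purely by the epoch counter passed to its lambda and supports state_dict/load_state_dict, which simplifies resuming from a checkpoint.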