from keras import backend as K
from keras.optimizers import Optimizer


class AdaBound(Optimizer):
"""AdaBound optimizer.
Default parameters follow those provided in the original paper.
# Arguments
lr: float >= 0. Learning rate.
final_lr: float >= 0. Final learning rate.
beta_1: float, 0 < beta < 1. Generally close to 1.
beta_2: float, 0 < beta < 1. Generally close to 1.
gamma: float >= 0. Convergence speed of the bound function.
epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
weight_decay: Weight decay weight.
amsbound: boolean. Whether to apply the AMSBound variant of this
algorithm.
# References
- [Adaptive Gradient Methods with Dynamic Bound of Learning Rate]
(https://openreview.net/forum?id=Bkg3g2R9FX)
- [Adam - A Method for Stochastic Optimization]
(https://arxiv.org/abs/1412.6980v8)
- [On the Convergence of Adam and Beyond]
(https://openreview.net/forum?id=ryQu7f-RZ)
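
    # Example
        A minimal usage sketch; `model` is assumed to be an already-built
        Keras model and the hyperparameters shown are illustrative only:

        ```python
        optimizer = AdaBound(lr=1e-3, final_lr=0.1, gamma=1e-3)
        model.compile(loss='categorical_crossentropy', optimizer=optimizer)
        ```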
"""

    def __init__(self, lr=0.001, final_lr=0.1, beta_1=0.9, beta_2=0.999,
                 gamma=1e-3, epsilon=None, decay=0., amsbound=False,
                 weight_decay=0.0, **kwargs):
        super(AdaBound, self).__init__(**kwargs)

        if not 0. <= gamma <= 1.:
            raise ValueError("Invalid `gamma` parameter. Must lie in [0, 1] range.")

        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
            self.decay = K.variable(decay, name='decay')

        self.final_lr = final_lr
        self.gamma = gamma

        if epsilon is None:
            epsilon = K.epsilon()
        self.epsilon = epsilon
        self.initial_decay = decay
        self.amsbound = amsbound

        self.weight_decay = float(weight_decay)
        self.base_lr = float(lr)

    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
                                                      K.dtype(self.decay))))

        t = K.cast(self.iterations, K.floatx()) + 1

        # Bias-corrected step size, as in Adam.
        step_size = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                          (1. - K.pow(self.beta_1, t)))

        # Dynamic bounds on the effective learning rate. `final_lr` is scaled
        # by lr / base_lr so that external lr schedules scale the bounds too.
        final_lr = self.final_lr * lr / self.base_lr
        lower_bound = final_lr * (1. - 1. / (self.gamma * t + 1.))
        upper_bound = final_lr * (1. + 1. / (self.gamma * t))
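
        # Illustration: with gamma = 1e-3, at t = 1 the clip interval is
        # roughly [0.001 * final_lr, 1001 * final_lr], i.e. nearly
        # unconstrained (Adam-like); as t grows, both bounds converge to
        # final_lr, so updates gradually become SGD steps at rate final_lr.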
        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]

        if self.amsbound:
            vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        else:
            # Placeholder slots so the weight layout is identical either way.
            vhats = [K.zeros(1) for _ in params]
        self.weights = [self.iterations] + ms + vs + vhats

        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
            # Apply weight decay by adding `weight_decay * p` to the gradient.
            if self.weight_decay != 0.:
                g += self.weight_decay * K.stop_gradient(p)

            # Update biased first and second moment estimates.
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)

            if self.amsbound:
                # AMSBound: use the running maximum of v_t in the denominator.
                vhat_t = K.maximum(vhat, v_t)
                denom = (K.sqrt(vhat_t) + self.epsilon)
                self.updates.append(K.update(vhat, vhat_t))
            else:
                denom = (K.sqrt(v_t) + self.epsilon)

            # Clip the element-wise step size to [lower_bound, upper_bound].
            step_size_p = step_size * K.ones_like(denom)
            step_size_p_bound = step_size_p / denom
            bounded_lr_t = m_t * K.minimum(K.maximum(step_size_p_bound,
                                                     lower_bound), upper_bound)

            p_t = p - bounded_lr_t

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        config = {'lr': float(K.get_value(self.lr)),
                  'final_lr': float(self.final_lr),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'gamma': float(self.gamma),
                  'decay': float(K.get_value(self.decay)),
                  'epsilon': self.epsilon,
                  'weight_decay': self.weight_decay,
                  'amsbound': self.amsbound}
        base_config = super(AdaBound, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
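

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the optimizer). Builds a tiny dense model,
# compiles it with AdaBound, and fits one epoch on random data. The model,
# shapes, and hyperparameters are illustrative assumptions; only the AdaBound
# class above comes from this file.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense

    model = Sequential([Dense(8, activation='relu', input_shape=(4,)),
                        Dense(1)])
    model.compile(loss='mse', optimizer=AdaBound(lr=1e-3, final_lr=0.1))

    x = np.random.rand(64, 4).astype('float32')
    y = np.random.rand(64, 1).astype('float32')
    model.fit(x, y, batch_size=16, epochs=1, verbose=0)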