
Commit a40c620

Merge pull request rushter#56 from Antetokounpo/adamax

Add Adamax optimizer and unit test

2 parents: ba450f8 + fff1c28
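
For context: AdaMax is the infinity-norm variant of Adam proposed in Kingma & Ba, "Adam: A Method for Stochastic Optimization" (arXiv:1412.6980). With gradient g_t, first-moment average m_t, and infinity-norm accumulator u_t, the rule the new class implements is:

\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1 - \beta_1)\, g_t \\
u_t &= \max(\beta_2\, u_{t-1},\ |g_t|) \\
\theta_t &= \theta_{t-1} - \frac{\alpha}{1 - \beta_1^{t}} \cdot \frac{m_t}{u_t + \epsilon}
\end{aligned}

The \epsilon term is this implementation's guard against a zero denominator (u_t stays zero until a nonzero gradient is seen); the paper's formulation omits it.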

File tree

mla/neuralnet/optimizers.py
mla/neuralnet/tests/test_optimizers.py

2 files changed (+30, −0 lines)


mla/neuralnet/optimizers.py

Lines changed: 28 additions & 0 deletions
@@ -197,3 +197,31 @@ def setup(self, network):
             for n in layer.parameters.keys():
                 self.ms[i][n] = np.zeros_like(layer.parameters[n])
                 self.vs[i][n] = np.zeros_like(layer.parameters[n])
+
+
+class Adamax(Optimizer):
+    def __init__(self, learning_rate=0.002, beta_1=0.9, beta_2=0.999, epsilon=1e-8):
+        self.epsilon = epsilon
+        self.beta_2 = beta_2
+        self.beta_1 = beta_1
+        self.lr = learning_rate
+        self.t = 1
+
+    def update(self, network):
+        for i, layer in enumerate(network.parametric_layers):
+            for n in layer.parameters.keys():
+                grad = layer.parameters.grad[n]
+                self.ms[i][n] = self.beta_1 * self.ms[i][n] + (1.0 - self.beta_1) * grad
+                self.us[i][n] = np.maximum(self.beta_2 * self.us[i][n], np.abs(grad))
+
+                step = self.lr / (1 - self.beta_1 ** self.t) * self.ms[i][n] / (self.us[i][n] + self.epsilon)
+                layer.parameters.step(n, -step)
+        self.t += 1
+
+    def setup(self, network):
+        self.ms = defaultdict(dict)
+        self.us = defaultdict(dict)
+        for i, layer in enumerate(network.parametric_layers):
+            for n in layer.parameters.keys():
+                self.ms[i][n] = np.zeros_like(layer.parameters[n])
+                self.us[i][n] = np.zeros_like(layer.parameters[n])
mla/neuralnet/tests/test_optimizers.py

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ def test_adadelta():
 def test_adam():
     assert clasifier(Adam()) > 0.9
 
+def test_adamax():
+    assert clasifier(Adamax()) > 0.9
 
 def test_rmsprop():
     assert clasifier(RMSprop()) > 0.9
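
Presumably the shared clasifier helper (spelled that way in this module; its definition is outside this diff) trains a small fixed network with the supplied optimizer and returns a score that each optimizer must push above 0.9. Assuming pytest is the test runner (the test_* naming follows its convention), the new case can be run alone with:

pytest mla/neuralnet/tests/test_optimizers.py -k test_adamax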
