Commit
fix learning rate
convergence-lab committed Aug 26, 2019
1 parent 0318970 commit e18eb4b
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions novograd.py
@@ -4,7 +4,7 @@
 import math
 
 class NovoGrad(optim.Optimizer):
-    def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
+    def __init__(self, params, grad_averaging=False, grad_ema=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super(NovoGrad, self).__init__(params, defaults)
         self._lr = lr
@@ -13,6 +13,7 @@ def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
         self._eps = eps
         self._wd = weight_decay
         self._grad_averaging = grad_averaging
+        self._grad_ema = grad_ema
 
         self._momentum_initialized = False
 
@@ -45,13 +46,17 @@ def step(self, closure=None):
                     continue
                 state = self.state[p]
                 state['step'] += 1
-                step, v, m, grad_ema = state['step'], state['v'], state['m'], state['grad_ema']
 
+                step, v, m = state['step'], state['v'], state['m']
+                if self._grad_ema:
+                    grad_ema = state['grad_ema']
+
                 grad = p.grad.data
                 g2 = torch.norm(grad)**2
-                grad_ema = g2 if grad_ema is None else grad_ema * \
-                    self._beta2 + g2*(1. - self._beta2)
-                grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)
+                if self._grad_ema:
+                    grad_ema = g2 if grad_ema is None else grad_ema * \
+                        self._beta2 + g2*(1. - self._beta2)
+                    grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)
 
                 if self._wd > 0.:
                     grad += self._wd*p
@@ -62,8 +67,10 @@ def step(self, closure=None):
                 m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd*p.data)
                 bias_correction1 = 1 - self._beta1 ** step
                 bias_correction2 = 1 - self._beta2 ** step
-                step_size = self._lr * math.sqrt(bias_correction2) / bias_correction1
+                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
 
-                state['v'], state['m'], state['grad_ema'] = v, m, grad_ema
+                state['v'], state['m'] = v, m
+                if self._grad_ema:
+                    state['grad_ema'] = grad_ema
                 p.data.add_(-step_size, m)
         return loss
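
For context, a minimal usage sketch (not part of the commit). It assumes novograd.py is importable and a PyTorch version contemporary with the commit (the add_(scalar, tensor) form used above was still current in 2019); the model, data, and StepLR schedule are placeholders. The point of the change: step_size now reads group['lr'] instead of the constructor-time self._lr, so an external LR scheduler actually affects the update, and the gradient-EMA normalization becomes opt-in via the new grad_ema flag.

    import torch
    import torch.nn as nn
    from novograd import NovoGrad  # the file changed in this commit

    model = nn.Linear(10, 1)                        # placeholder model
    opt = NovoGrad(model.parameters(), lr=0.1,
                   grad_ema=False)                  # new flag; EMA normalization off by default
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)

    x, y = torch.randn(4, 10), torch.randn(4, 1)    # placeholder data
    for _ in range(3):
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()   # after this fix, the update uses group['lr']
        sched.step() # halves group['lr']; before the fix this had no effect on step_size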
