adahessian.py
import math
import types

import torch
import torch.optim
import torch.distributed as dist

from copy import deepcopy
import numpy as np

from . import FairseqOptimizer, register_optimizer

@register_optimizer('adahessian')
class FairseqAdahess(FairseqOptimizer):
    """AdaHessian optimizer for fairseq.

    Important note: this optimizer corresponds to the "AdamW" variant of
    Adam in its weight-decay behavior. As such, it is most closely
    analogous to torch.optim.AdamW from PyTorch.
    """

    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = Adahess(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
                            help='betas for Adam optimizer')
        parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
                            help='epsilon for Adam optimizer')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        parser.add_argument('--block-length', default=1, type=int,
                            help='block size over which the Hessian diagonal estimate is averaged')
        parser.add_argument('--hessian-power', type=float, default=1, metavar='H',
                            help='Hessian power')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            'lr': self.args.lr[0],
            'betas': eval(self.args.adam_betas),
            'eps': self.args.adam_eps,
            'weight_decay': self.args.weight_decay,
            'block_length': self.args.block_length,
            'single_gpu': self.args.single_gpu,
            'hessian_power': self.args.hessian_power,
        }
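
    # Hypothetical command-line usage with fairseq-train (flags as declared in
    # add_args() above; the exact values are illustrative only):
    #
    #   fairseq-train DATA_DIR --optimizer adahessian \
    #       --adam-betas '(0.9, 0.98)' --adam-eps 1e-8 --weight-decay 0.0 \
    #       --block-length 32 --hessian-power 1
    #
    # Note that `single_gpu` is read from args here but is not registered in
    # add_args(); it is assumed to be added elsewhere by the training script.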

    def average_params(self):
        """Reduce Params is only used during BMUF distributed training."""
        state_dict = self.optimizer.state_dict()
        total_gpus = float(dist.get_world_size())

        for _, value in state_dict["state"].items():
            # Adahess stores its second-moment statistic under
            # 'exp_hessian_diag_sq'; Adam's 'exp_avg_sq' key does not exist here.
            value["exp_avg"] /= total_gpus
            value["exp_hessian_diag_sq"] /= total_gpus
            dist.all_reduce(value["exp_avg"], op=dist.ReduceOp.SUM)
            dist.all_reduce(value["exp_hessian_diag_sq"], op=dist.ReduceOp.SUM)
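

# ---------------------------------------------------------------------------
# Background sketch (illustrative only, never called by the training code):
# Adahess.get_trace() below relies on Hutchinson's identity
# E[z * (H z)] = diag(H) for a Rademacher vector z with +/-1 entries.
# The hypothetical helper below checks that identity against an explicit,
# known Hessian on the CPU.
# ---------------------------------------------------------------------------
def _hutchinson_diag_demo(dim=5, n_samples=2000):
    torch.manual_seed(0)
    A = torch.randn(dim, dim)
    H = A @ A.t()                       # Hessian of f(x) = 0.5 * x^T H x
    estimate = torch.zeros(dim)
    for _ in range(n_samples):
        z = 2. * torch.randint(0, 2, (dim,)).float() - 1.   # Rademacher sample
        estimate += z * (H @ z)         # one sample of z * (H z)
    # The first tensor approaches the second as n_samples grows.
    return estimate / n_samples, torch.diagonal(H)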


class Adahess(torch.optim.Optimizer):
    """Implements the AdaHessian algorithm: an Adam-style update in which the
    second moment is computed from a Hutchinson estimate of the Hessian
    diagonal rather than from the squared gradient.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, block_length=1, hessian_power=1, single_gpu=False):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super(Adahess, self).__init__(params, defaults)

        self.block_length = block_length
        self.single_gpu = single_gpu
        self.hessian_power = hessian_power

    def get_trace(self, gradsH):
        """
        Estimate the (absolute) diagonal of the Hessian with Hutchinson's
        method: draw a Rademacher vector v, compute the Hessian-vector
        product Hv as the gradient of <gradsH, v>, and use v * Hv as the
        per-parameter estimate, optionally averaged over blocks.

        :param gradsH: a list of gradient tensors built with create_graph=True
        :return: a list of tensors, one per parameter, estimating |diag(H)|
        """
        params = self.param_groups[0]['params']
        params = list(filter(lambda x: x.requires_grad, params))

        # Rademacher vectors: entries drawn uniformly from {-1, +1}.
        v = [2 * torch.randint_like(p, high=2, device='cuda') - 1 for p in params]

        # This handles the distributed setting with a single node and multiple
        # GPUs; multi-node training is not supported yet.  All-reducing and
        # re-binarizing v keeps the random vectors identical across workers.
        if not self.single_gpu:
            for v1 in v:
                dist.all_reduce(v1)
            for v_i in v:
                v_i[v_i < 0.] = -1.
                v_i[v_i >= 0.] = 1.

        hvs = torch.autograd.grad(gradsH, params, grad_outputs=v,
                                  only_inputs=True, retain_graph=True)

        hutchinson_trace = []
        for hv, vi in zip(hvs, v):
            param_size = hv.size()
            if len(param_size) <= 1:
                # Biases and LayerNorm parameters: element-wise estimate.
                tmp_output = torch.abs(hv * vi) + 0.
                hutchinson_trace.append(tmp_output)
            elif len(param_size) == 2:
                # Matrices: flatten into blocks of length self.block_length and
                # average the estimate within each block.
                tmp_output1 = torch.abs((hv * vi + 0.)).view(-1, self.block_length)
                tmp_output2 = torch.abs(torch.sum(tmp_output1, dim=[1])).view(-1) / float(self.block_length)
                tmp_output3 = tmp_output2.repeat_interleave(self.block_length).view(param_size)
                hutchinson_trace.append(tmp_output3)
            else:
                # Fallback for higher-dimensional parameters (not used by the
                # Transformer models this was written for): element-wise
                # estimate, so indices stay aligned with params in step().
                hutchinson_trace.append(torch.abs(hv * vi))

        # Average the estimates across GPUs (single node, multi-GPU only).
        if not self.single_gpu:
            for output1 in hutchinson_trace:
                dist.all_reduce(output1)
                output1.div_(float(torch.cuda.device_count()))

        return hutchinson_trace

    def step(self, gradsH=None, closure=None):
        """Performs a single optimization step.

        Arguments:
            gradsH (list of Tensors): gradients built with create_graph=True,
                used to estimate the Hessian diagonal.
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        hut_trace = self.get_trace(gradsH)

        for group in self.param_groups:
            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue

                # grad = p.grad.data.float()
                grad = deepcopy(gradsH[i].data.float())
                if grad.is_sparse:
                    raise RuntimeError('AdaHessian does not support sparse gradients, '
                                       'please consider SparseAdam instead')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    # Exponential moving average of the squared Hessian-diagonal estimate
                    state['exp_hessian_diag_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_hessian_diag_sq'] = state['exp_hessian_diag_sq'].type_as(p_data_fp32)

                exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficients
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_hessian_diag_sq.mul_(beta2).addcmul_(hut_trace[i], hut_trace[i], value=1 - beta2)

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                if self.hessian_power < 1:
                    denom = ((exp_hessian_diag_sq.sqrt() / math.sqrt(bias_correction2)) **
                             self.hessian_power).add_(group['eps'])
                else:
                    denom = (exp_hessian_diag_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                # Decoupled (AdamW-style) weight decay.
                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
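

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the fairseq integration).  It
# assumes a single CUDA device and a hypothetical toy model; the essential
# point is that the gradients passed as `gradsH` must be created with
# create_graph=True so that get_trace() can take a second backward pass
# through them.  This helper is never called by the training code.
# ---------------------------------------------------------------------------
def _example_adahess_step():
    import torch.nn as nn

    model = nn.Linear(16, 4).cuda()
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = Adahess(params, lr=1e-3, block_length=1, single_gpu=True)

    x = torch.randn(8, 16, device='cuda')
    loss = model(x).pow(2).mean()

    # Keep the graph of the gradients so that the Hessian-vector product in
    # get_trace() can be computed by a second autograd.grad() call.
    gradsH = torch.autograd.grad(loss, params, create_graph=True)
    for p, g in zip(params, gradsH):
        p.grad = g  # step() skips parameters whose .grad is None

    optimizer.step(gradsH=list(gradsH))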