-
Notifications
You must be signed in to change notification settings - Fork 4
/
hessian.py
270 lines (229 loc) · 9.79 KB
/
hessian.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
#*
# @file Hessian analysis: top eigenvalues, trace, and eigenvalue density
# Copyright (c) Zhewei Yao, Amir Gholami
# All rights reserved.
# This file is part of PyHessian library.
#
# PyHessian is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# PyHessian is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with PyHessian. If not, see <http://www.gnu.org/licenses/>.
#*
import torch
import math
from torch.autograd import Variable
import numpy as np
from pyhessian.utils import group_product, group_add, normalization, get_params_grad, hessian_vector_product, orthnormal
class hessian():
    """
    The class used to compute:
        i) the top 1 (n) eigenvalue(s) of the Hessian of the loss
        ii) the trace of the Hessian
        iii) the estimated eigenvalue density of the Hessian

    Hessian-vector products come either from gradients computed once for a
    single batch (``data``) or, for a ``dataloader``, by re-running the model
    on every batch and averaging the per-batch products weighted by batch size.
    """

    def __init__(self, model, criterion, data=None, dataloader=None, cuda=True):
        """
        model: the model that needs Hessian information
        criterion: the loss function, called as criterion(outputs, targets)
        data: a single batch of data, an (inputs, targets) pair
        dataloader: an iterable of (inputs, targets) batches
        cuda: if True work on device 'cuda', otherwise on 'cpu'

        Exactly one of ``data`` and ``dataloader`` must be given; otherwise an
        AssertionError is raised.
        """
        # make sure we either pass a single batch or a dataloader
        # (identity checks avoid invoking __eq__/__ne__ on user objects)
        assert (data is None) != (dataloader is None), \
            "pass exactly one of `data` or `dataloader`"

        self.model = model.eval()  # make sure the model is in evaluation mode
        self.criterion = criterion

        if data is not None:
            self.data = data
            self.full_dataset = False
        else:
            self.data = dataloader
            self.full_dataset = True

        self.device = 'cuda' if cuda else 'cpu'

        # pre-processing for the single-batch case to simplify the computation
        if not self.full_dataset:
            self.inputs, self.targets = self.data
            if self.device == 'cuda':
                self.inputs = self.inputs.cuda()
                self.targets = self.targets.cuda()

            # with a single batch the gradient graph is built once here and
            # re-used by every Hessian-vector product afterwards
            outputs = self.model(self.inputs)
            loss = self.criterion(outputs, self.targets)
            loss.backward(create_graph=True)

        # this step is used to extract the parameters from the model
        params, gradsH = get_params_grad(self.model)
        self.params = params
        self.gradsH = gradsH  # gradient used for Hessian computation

    def _hv_product(self, v):
        """Return H v via the dataloader or via the cached single-batch gradients."""
        if self.full_dataset:
            _, Hv = self.dataloader_hv_product(v)
            return Hv
        return hessian_vector_product(self.gradsH, self.params, v)

    def _rademacher(self):
        """Draw one random tensor per parameter with entries in {-1, +1}."""
        v = [
            torch.randint_like(p, high=2, device=self.device)
            for p in self.params
        ]
        for v_i in v:
            v_i[v_i == 0] = -1
        return v

    def dataloader_hv_product(self, v):
        """
        Compute the Hessian-vector product H v averaged over the dataloader.

        Each batch's product is weighted by its batch size (inputs.size(0))
        and the sum is divided by the total number of data points.

        v: list of tensors, one per tensor in self.params
        Returns (v^T H v as a Python float, list of averaged Hv tensors).
        Raises ValueError if the dataloader yields no data points.
        """
        device = self.device
        num_data = 0.  # number of data points seen in the dataloader

        # batch-size-weighted accumulator for Hv
        THv = [torch.zeros(p.size(), device=device) for p in self.params]
        for inputs, targets in self.data:
            self.model.zero_grad()
            batch_size = float(inputs.size(0))
            outputs = self.model(inputs.to(device))
            loss = self.criterion(outputs, targets.to(device))
            loss.backward(create_graph=True)
            params, gradsH = get_params_grad(self.model)
            self.model.zero_grad()
            # differentiate the gradients once more, along v, to obtain Hv
            Hv = torch.autograd.grad(gradsH,
                                     params,
                                     grad_outputs=v,
                                     retain_graph=False)
            THv = [THv1 + Hv1 * batch_size for THv1, Hv1 in zip(THv, Hv)]
            num_data += batch_size

        if num_data == 0:
            # dividing by zero here would silently turn every result into NaN
            raise ValueError("dataloader yielded no data points")

        THv = [THv1 / num_data for THv1 in THv]
        eigenvalue = group_product(THv, v).cpu().item()
        return eigenvalue, THv

    def eigenvalues(self, maxIter=100, tol=1e-3, top_n=1):
        """
        Compute the top_n eigenvalues using the power iteration method.

        maxIter: maximum iterations used to compute each single eigenvalue
        tol: the relative tolerance between two consecutive eigenvalue
            estimates from power iteration
        top_n: number of top eigenvalues to compute

        Returns (eigenvalues, eigenvectors), both lists of length top_n; each
        eigenvector is a list of tensors shaped like self.params. Each new
        iterate is passed through orthnormal() against the eigenvectors
        already found.
        """
        assert top_n >= 1

        device = self.device
        eigenvalues = []
        eigenvectors = []

        for _ in range(top_n):
            eigenvalue = None
            # random starting vector, normalized
            v = [torch.randn(p.size()).to(device) for p in self.params]
            v = normalization(v)

            for _ in range(maxIter):
                # keep the iterate orthogonal to previously found eigenvectors
                v = orthnormal(v, eigenvectors)
                self.model.zero_grad()

                Hv = self._hv_product(v)
                tmp_eigenvalue = group_product(Hv, v).cpu().item()  # v^T H v
                v = normalization(Hv)

                if eigenvalue is not None and (
                        abs(eigenvalue - tmp_eigenvalue) /
                        (abs(eigenvalue) + 1e-6) < tol):
                    break
                eigenvalue = tmp_eigenvalue

            eigenvalues.append(eigenvalue)
            eigenvectors.append(v)

        return eigenvalues, eigenvectors

    def trace(self, maxIter=100, tol=1e-3):
        """
        Estimate the trace of the Hessian using Hutchinson's method.

        maxIter: maximum number of random (Rademacher) probe vectors
        tol: relative tolerance on the change of the running mean

        Returns the list of per-probe values v^T H v (their mean is the trace
        estimate). Stops early once the running mean changes by less than
        ``tol`` relative to its previous value.
        """
        trace_vhv = []
        trace = 0.

        for _ in range(maxIter):
            self.model.zero_grad()
            v = self._rademacher()
            Hv = self._hv_product(v)
            trace_vhv.append(group_product(Hv, v).cpu().item())
            mean_vhv = np.mean(trace_vhv)
            if abs(mean_vhv - trace) / (abs(trace) + 1e-6) < tol:
                break
            trace = mean_vhv

        return trace_vhv

    def _lanczos(self, num_steps):
        """Run num_steps Lanczos steps from a Rademacher start; return (alphas, betas)."""
        device = self.device
        v = normalization(self._rademacher())

        # standard Lanczos algorithm initialization
        v_list = [v]
        alpha_list = []
        beta_list = []
        for step in range(num_steps):
            self.model.zero_grad()
            if step == 0:
                w_prime = self._hv_product(v)
                alpha = group_product(w_prime, v)
                alpha_list.append(alpha.cpu().item())
                w = group_add(w_prime, v, alpha=-alpha)
            else:
                beta = torch.sqrt(group_product(w, w))
                beta_list.append(beta.cpu().item())
                if beta_list[-1] == 0.:
                    # Krylov space exhausted: restart from a fresh random vector
                    w = [torch.randn(p.size()).to(device) for p in self.params]
                # full re-orthogonalization against all previous Lanczos vectors
                v = orthnormal(w, v_list)
                v_list.append(v)
                w_prime = self._hv_product(v)
                alpha = group_product(w_prime, v)
                alpha_list.append(alpha.cpu().item())
                w_tmp = group_add(w_prime, v, alpha=-alpha)
                w = group_add(w_tmp, v_list[-2], alpha=-beta)
        return alpha_list, beta_list

    @staticmethod
    def _tridiagonal_spectrum(alpha_list, beta_list, device='cpu'):
        """
        Eigen-decompose the symmetric tridiagonal matrix with diagonal
        alpha_list and off-diagonal beta_list.

        Returns (eigenvalues in ascending order, weights), both Python lists;
        weights are the squared first components of the eigenvectors.
        """
        n = len(alpha_list)
        if n == 0:
            return [], []
        T = torch.zeros(n, n, device=device)
        for i, alpha in enumerate(alpha_list):
            T[i, i] = alpha
            if i < n - 1:
                T[i + 1, i] = beta_list[i]
                T[i, i + 1] = beta_list[i]
        # T is real symmetric: eigh is the right solver (torch.eig was
        # removed in PyTorch 2.0); eigenvectors are returned as columns
        eigen_values, eigen_vectors = torch.linalg.eigh(T)
        weights = eigen_vectors[0, :]**2
        return list(eigen_values.cpu().numpy()), list(weights.cpu().numpy())

    def density(self, iter=100, n_v=1):
        """
        Compute the estimated eigenvalue density using stochastic Lanczos
        quadrature (SLQ).

        iter: number of Lanczos iterations per run
        n_v: number of SLQ runs

        Returns (eigen_list_full, weight_list_full): one list per run holding
        the eigenvalues (ascending) of the Lanczos tridiagonal matrix and the
        matching quadrature weights.
        """
        eigen_list_full = []
        weight_list_full = []

        for _ in range(n_v):
            alpha_list, beta_list = self._lanczos(iter)
            eigen_list, weight_list = self._tridiagonal_spectrum(
                alpha_list, beta_list, device=self.device)
            eigen_list_full.append(eigen_list)
            weight_list_full.append(weight_list)

        return eigen_list_full, weight_list_full