This repository was archived by the owner on Apr 8, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path distpro.py
113 lines (92 loc) · 4.27 KB
/
distpro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Copyright 2021 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
import copy
def _concat(xs):
return torch.cat([x.view(-1) for x in xs])
class DistPro(object):
def __init__(self, model, args):
self.network_momentum = args.momentum
self.network_weight_decay = args.weight_decay
self.model = model
self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
lr=args.arch_lr, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)
def _compute_unrolled_model(self, input, target, eta, network_optimizer):
loss = self.model._loss(input, target)
theta = _concat(self.model.parameters()).data
try:
moment = _concat(network_optimizer.state[v]['momentum_buffer'] for v in self.model.parameters()).mul_(self.network_momentum)
except:
moment = torch.zeros_like(theta)
dtheta = _concat(torch.autograd.grad(loss, self.model.parameters())).data + self.network_weight_decay*theta
unrolled_model = self._switch_model_from_theta(theta.sub(eta, moment+dtheta))
return unrolled_model
def step(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled):
self.optimizer.zero_grad()
if unrolled:
self._backward_step_unrolled(input_train, target_train, input_valid, target_valid, eta, network_optimizer)
else:
self._backward_step(input_valid, target_valid)
self.optimizer.step()
print('parameter alpha: ' + str(list(self.model.arch_parameters())[0].detach().cpu().numpy().tolist()))
def _backward_step(self, input_valid, target_valid):
loss = self.model._loss(input_valid, target_valid, val=True)
loss.backward()
def _backward_step_unrolled(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer):
unrolled_model = self._compute_unrolled_model(input_train, target_train, eta, network_optimizer)
unrolled_loss = unrolled_model._loss(input_valid, target_valid, val=True)
unrolled_loss.backward()
dalpha = [v.grad for v in unrolled_model.arch_parameters()]
vector = [v.grad.data for v in unrolled_model.parameters()]
self._restore_model()
implicit_grads = self._hessian_vector_product(vector, input_train, target_train)
print('gradient alpha: ' + str(implicit_grads[0].detach().cpu().numpy().tolist()))
for g, ig in zip(dalpha, implicit_grads):
g.data.sub_(eta, ig.data)
for v, g in zip(self.model.arch_parameters(), dalpha):
if v.grad is None:
v.grad = Variable(g.data)
else:
v.grad.data.copy_(g.data)
def _switch_model_from_theta(self, theta):
self.model_dict = self.model.state_dict()
model_dict = copy.deepcopy(self.model_dict)
params, offset = {}, 0
for k, v in self.model.named_parameters():
v_length = np.prod(v.size())
params[k] = theta[offset: offset+v_length].view(v.size())
offset += v_length
assert offset == len(theta)
model_dict.update(params)
self.model.load_state_dict(model_dict)
return self.model
def _restore_model(self):
self.model.load_state_dict(self.model_dict)
return self.model
def _hessian_vector_product(self, vector, input, target, r=1e-2):
R = r / _concat(vector).norm()
for p, v in zip(self.model.parameters(), vector):
p.data.add_(R, v)
loss = self.model._loss(input, target)
grads_p = torch.autograd.grad(loss, self.model.arch_parameters())
for p, v in zip(self.model.parameters(), vector):
p.data.sub_(2*R, v)
loss = self.model._loss(input, target)
grads_n = torch.autograd.grad(loss, self.model.arch_parameters())
for p, v in zip(self.model.parameters(), vector):
p.data.add_(R, v)
return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)]