-
Notifications
You must be signed in to change notification settings - Fork 2
/
resnet.py
292 lines (228 loc) · 10.1 KB
/
resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.autograd import Variable
import torch.nn.init as init
def to_var(x, requires_grad=True):
    """Wrap tensor ``x`` as an autograd Variable, moving it to the GPU when CUDA is available.

    Args:
        x: tensor to wrap.
        requires_grad: whether the returned Variable should track gradients.

    Returns:
        ``Variable(x_on_device, requires_grad=requires_grad)``.
    """
    target = x.cuda() if torch.cuda.is_available() else x
    return Variable(target, requires_grad=requires_grad)
class resnet_attention(nn.Module):
    """Additive attention scorer: softmax(v(tanh(attn(s)))) taken over dim 0.

    Args:
        enc_hid_dim: size of the last dimension of the input ``s``.
        dec_hid_dim: width of the hidden projection.
    """

    def __init__(self, enc_hid_dim=64, dec_hid_dim=100):
        super().__init__()
        # Creation order (attn, then v) is kept so seeded initialization is unchanged.
        self.attn = nn.Linear(enc_hid_dim, dec_hid_dim, bias=True)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, s):
        """Return attention weights with a trailing size-1 dim, normalized along dim 0."""
        hidden = torch.tanh(self.attn(s))
        scores = self.v(hidden)
        # NOTE(review): normalizes across dim 0 (the leading/batch axis for a 2-D input) — confirm intended.
        return torch.softmax(scores, dim=0)
class MetaModule(nn.Module):
    """Base module whose tensors can be listed by dotted name and replaced in place.

    Subclasses report the tensors they own via :meth:`named_leaves`;
    :meth:`update_params` and :meth:`set_param` overwrite those attributes with
    new tensors (e.g. ``param - lr * grad``) so the result of an inner update
    step can be used by the next forward pass.
    """

    def params(self):
        """Yield every tensor reported by :meth:`named_params` (names dropped)."""
        for _, param in self.named_params(self):
            yield param

    def named_leaves(self):
        """Return ``(name, tensor)`` pairs owned directly by this module.

        The base implementation owns nothing; leaf subclasses override this.
        """
        return []

    def named_submodules(self):
        return []

    def named_params(self, curr_module=None, memo=None, prefix=''):
        """Recursively yield ``(dotted_name, tensor)`` pairs, skipping None and duplicates.

        Args:
            curr_module: module to walk; defaults to ``self`` (calling with no
                argument previously crashed on ``None._parameters``).
            memo: set of tensors already yielded, shared across the recursion.
            prefix: dotted name prefix for tensors of ``curr_module``.

        Modules that define ``named_leaves`` (every MetaModule) contribute those
        tensors; other modules contribute their registered ``_parameters``.
        """
        if curr_module is None:
            curr_module = self
        if memo is None:
            memo = set()

        if hasattr(curr_module, 'named_leaves'):
            own = curr_module.named_leaves()
        else:
            own = curr_module._parameters.items()
        for name, p in own:
            if p is not None and p not in memo:
                memo.add(p)
                yield prefix + ('.' if prefix else '') + name, p

        for mname, module in curr_module.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            yield from self.named_params(module, memo, submodule_prefix)

    def update_params(self, lr_inner, first_order=False, source_params=None, detach=False):
        """Apply one gradient step ``param - lr_inner * grad`` to every tensor.

        Args:
            lr_inner: step size.
            first_order: if True, the gradient is detached and re-wrapped via
                ``to_var`` before use.
            source_params: optional iterable of gradients, matched positionally
                (via ``zip``) to :meth:`named_params` order. When None, each
                tensor's own ``.grad`` is used.
            detach: only used when ``source_params`` is None. If True, tensors
                are detached in place instead of updated.
        """
        if source_params is not None:
            for (name_t, param_t), grad in zip(self.named_params(self), source_params):
                if first_order:
                    grad = to_var(grad.detach().data)
                self.set_param(self, name_t, param_t - lr_inner * grad)
        else:
            for name, param in self.named_params(self):
                if detach:
                    self.set_param(self, name, param.detach_())
                else:
                    # NOTE(review): raises TypeError if param.grad is None (tensor unused in the loss).
                    grad = param.grad
                    if first_order:
                        grad = to_var(grad.detach().data)
                    self.set_param(self, name, param - lr_inner * grad)

    def set_param(self, curr_mod, name, param):
        """Assign ``param`` to the attribute at dotted path ``name`` below ``curr_mod``.

        If no child matches a path component, nothing is assigned.
        """
        if '.' not in name:
            setattr(curr_mod, name, param)
            return
        module_name, rest = name.split('.', 1)
        for child_name, mod in curr_mod.named_children():
            if child_name == module_name:
                self.set_param(mod, rest, param)
                break

    def detach_params(self):
        """Replace every tensor with a detached copy (cuts the autograd history)."""
        for name, param in self.named_params(self):
            self.set_param(self, name, param.detach())

    def copy(self, other, same_var=False):
        """Copy every tensor of ``other`` onto the same dotted names of ``self``.

        Args:
            other: module with the same structure as ``self``.
            same_var: if True, share ``other``'s tensor objects; otherwise clone
                their data into fresh grad-requiring Variables via ``to_var``.
        """
        # Bug fix: walk `other` explicitly and pass the target module to set_param.
        for name, param in other.named_params(other):
            if not same_var:
                param = to_var(param.data.clone(), requires_grad=True)
            self.set_param(self, name, param)
class MetaLinear(MetaModule):
    """Linear layer whose weight/bias are buffers rather than nn.Parameters.

    The initial values come from a throwaway ``nn.Linear`` built with the same
    arguments. ``bias=False`` is now supported: previously it crashed on
    ``None.data``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Linear(*args, **kwargs)
        self.in_features = ignore.in_features
        self.out_features = ignore.out_features
        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
        if ignore.bias is not None:
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            # Same convention as MetaConv2d: a None buffer; F.linear accepts bias=None.
            self.register_buffer('bias', None)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

    def named_leaves(self):
        # A None bias is filtered out by MetaModule.named_params.
        return [('weight', self.weight), ('bias', self.bias)]
class MetaLinear_Norm(MetaModule):
    """Bias-free cosine-similarity layer.

    The weight buffer is the transpose of an ``nn.Linear`` weight, i.e. shape
    (in_features, out_features). ``forward`` L2-normalizes the rows of ``x``
    (dim=1) and the columns of the weight (dim=0) before a matrix product.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        proto = nn.Linear(*args, **kwargs)
        # NOTE(review): this initialization of the prototype is overwritten by the
        # re-draw below. It is kept so the random-number stream, and therefore
        # seeded results, stay unchanged.
        proto.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        weight_t = to_var(proto.weight.data.t(), requires_grad=True)
        self.register_buffer('weight', weight_t)
        # Uniform in [-1, 1], each column's L2 norm capped at 1e-5, then scaled by 1e5,
        # so every column ends up with norm <= 1.
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)

    def forward(self, x):
        unit_x = F.normalize(x, dim=1)
        unit_w = F.normalize(self.weight, dim=0)
        return torch.mm(unit_x, unit_w)

    def named_leaves(self):
        return [('weight', self.weight)]
class MetaConv2d(MetaModule):
    """2-D convolution whose weight/bias are buffers rather than nn.Parameters.

    Hyper-parameters and initial tensors are taken from a throwaway
    ``nn.Conv2d`` constructed with the same arguments.
    """

    _COPIED_ATTRS = ('in_channels', 'out_channels', 'stride', 'padding',
                     'dilation', 'groups', 'kernel_size')

    def __init__(self, *args, **kwargs):
        super().__init__()
        reference = nn.Conv2d(*args, **kwargs)
        for attr in self._COPIED_ATTRS:
            setattr(self, attr, getattr(reference, attr))
        self.register_buffer('weight', to_var(reference.weight.data, requires_grad=True))
        ref_bias = reference.bias
        self.register_buffer(
            'bias',
            None if ref_bias is None else to_var(ref_bias.data, requires_grad=True),
        )

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias,
                        stride=self.stride, padding=self.padding,
                        dilation=self.dilation, groups=self.groups)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]
class MetaConvTranspose2d(MetaModule):
    """Transposed 2-D convolution whose weight/bias are buffers rather than nn.Parameters.

    Hyper-parameters and initial tensors are taken from a throwaway
    ``nn.ConvTranspose2d`` constructed with the same arguments.

    Bug fix: ``forward`` called ``self._output_padding``, which no base class
    of this module defines, so every call raised AttributeError. The helper is
    now implemented here, together with the attributes it needs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.ConvTranspose2d(*args, **kwargs)
        self.in_channels = ignore.in_channels
        self.out_channels = ignore.out_channels
        self.kernel_size = ignore.kernel_size
        self.stride = ignore.stride
        self.padding = ignore.padding
        self.output_padding = ignore.output_padding
        self.dilation = ignore.dilation
        self.groups = ignore.groups
        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
        if ignore.bias is not None:
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            self.register_buffer('bias', None)

    def _output_padding(self, x, output_size):
        """Return the ``output_padding`` to pass to ``F.conv_transpose2d``.

        Args:
            x: input tensor; its last two dims are treated as the spatial dims.
            output_size: None to use the configured ``output_padding``, or the
                desired spatial size as (H, W), or as a full shape ending in
                (H, W) with as many entries as ``x`` has dims.

        Raises:
            ValueError: if ``output_size`` has the wrong length, or a requested
                size cannot be produced with the current stride/padding.
        """
        if output_size is None:
            return self.output_padding
        output_size = list(output_size)
        if len(output_size) == x.dim():
            output_size = output_size[-2:]
        if len(output_size) != 2:
            raise ValueError('output_size must have 2 or {} elements (got {})'.format(
                x.dim(), len(output_size)))
        result = []
        for d in range(2):
            # Smallest output this transposed conv produces along spatial axis d;
            # output_padding can grow it by at most stride - 1.
            min_size = ((x.size(d - 2) - 1) * self.stride[d] - 2 * self.padding[d]
                        + self.dilation[d] * (self.kernel_size[d] - 1) + 1)
            max_size = min_size + self.stride[d] - 1
            if not min_size <= output_size[d] <= max_size:
                raise ValueError('requested output size {} is not in the valid range [{}, {}]'.format(
                    output_size[d], min_size, max_size))
            result.append(output_size[d] - min_size)
        return tuple(result)

    def forward(self, x, output_size=None):
        output_padding = self._output_padding(x, output_size)
        return F.conv_transpose2d(x, self.weight, self.bias, self.stride, self.padding,
                                  output_padding, self.groups, self.dilation)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]
class MetaBatchNorm2d(MetaModule):
    """2-D batch norm whose affine weight/bias are buffers rather than nn.Parameters.

    The configuration (num_features, eps, momentum, affine,
    track_running_stats) is copied from a throwaway ``nn.BatchNorm2d`` built
    with the same arguments.

    Bug fix: with ``affine=False``, ``weight``/``bias`` were never registered,
    so ``forward`` and ``named_leaves`` raised AttributeError. They are now
    None buffers, which ``F.batch_norm`` treats as "no affine transform".
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.BatchNorm2d(*args, **kwargs)
        self.num_features = ignore.num_features
        self.eps = ignore.eps
        self.momentum = ignore.momentum
        self.affine = ignore.affine
        self.track_running_stats = ignore.track_running_stats
        if self.affine:
            self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            # None buffers are omitted from state_dict, so checkpoints are unaffected.
            self.register_buffer('weight', None)
            self.register_buffer('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_var', torch.ones(self.num_features))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)

    def forward(self, x):
        # Batch statistics are used while training, or whenever running stats are not tracked.
        # NOTE(review): momentum=None (cumulative averaging in nn.BatchNorm2d) is not handled here.
        return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                            self.training or not self.track_running_stats, self.momentum, self.eps)

    def named_leaves(self):
        # None entries (affine=False) are filtered out by MetaModule.named_params.
        return [('weight', self.weight), ('bias', self.bias)]
def _weights_init(m):
    """Kaiming-normal initialize ``m.weight`` when ``m`` is a MetaLinear or MetaConv2d.

    Meant to be passed to ``nn.Module.apply``. Other module types are left untouched.
    """
    # kaiming_normal_ is the in-place API; the un-suffixed kaiming_normal is a
    # deprecated alias that forwards to it and emits a warning.
    if isinstance(m, (MetaLinear, MetaConv2d)):
        init.kaiming_normal_(m.weight)
class LambdaLayer(MetaModule):
    """Module wrapper around an arbitrary callable.

    Args:
        lambd: callable applied to the input in ``forward``.
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)
class BasicBlock(MetaModule):
    """Two-conv residual block: relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x)).

    Args:
        in_planes: input channels.
        planes: output channels.
        stride: stride of the first 3x3 conv.
        option: shortcut type used when the shape changes.
            'A' subsamples spatially with ``[::2, ::2]`` and zero-pads
            ``planes // 4`` channels on each side (this assumes stride 2 and
            planes == 2 * in_planes).
            'B' uses a 1x1 strided conv followed by batch norm.
            Any other value leaves an identity shortcut.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super().__init__()
        self.conv1 = MetaConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = MetaBatchNorm2d(planes)
        self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = MetaBatchNorm2d(planes)
        self.shortcut = self._build_shortcut(in_planes, planes, stride, option)

    def _build_shortcut(self, in_planes, planes, stride, option):
        """Return the module applied to the block input on the residual branch."""
        if stride == 1 and in_planes == planes:
            return nn.Sequential()
        if option == 'A':
            pad = planes // 4
            return LambdaLayer(
                lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, pad, pad), "constant", 0))
        if option == 'B':
            width = self.expansion * planes
            return nn.Sequential(
                MetaConv2d(in_planes, width, kernel_size=1, stride=stride, bias=False),
                MetaBatchNorm2d(width),
            )
        return nn.Sequential()

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        h += self.shortcut(x)
        return F.relu(h)
class ResNet32(MetaModule):
    """ResNet with a 3x3 stem and three stages of residual blocks (16/32/64 channels).

    Args:
        num_classes: output size of the final linear layer.
        block: residual block class. It must accept
            ``(in_planes, planes, stride)`` and define ``expansion``.
        num_blocks: number of blocks in each of the three stages. The default
            is now an immutable tuple instead of a shared mutable list; any
            indexable sequence still works.

    ``forward`` returns ``(features, logits)``, where ``features`` is the
    flattened pooled output of the last stage.
    """

    def __init__(self, num_classes, block=BasicBlock, num_blocks=(5, 5, 5)):
        super().__init__()
        self.in_planes = 16
        self.conv1 = MetaConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = MetaBatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = MetaLinear(64, num_classes)
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage. Only its first block uses ``stride``; ``self.in_planes`` is advanced."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Pool with a kernel equal to the feature-map width (global pooling for square maps).
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        y = self.linear(out)
        return out, y