-
Notifications
You must be signed in to change notification settings - Fork 175
/
ssn_models.py
395 lines (336 loc) · 17.4 KB
/
ssn_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
import torch
from torch import nn
from transforms import *
import torchvision.models
from ops.ssn_ops import Identity, StructuredTemporalPyramidPooling
class SSN(torch.nn.Module):
    """Structured Segment Network built on top of a 2D CNN backbone.

    The backbone's classifier layer is replaced by dropout (or an identity
    when ``dropout == 0``) so the backbone emits per-snippet features. These
    are pooled by ``StructuredTemporalPyramidPooling`` over the starting,
    course and ending stages of each proposal and scored by:

    * ``activity_fc``: ``num_class + 1`` outputs,
    * ``completeness_fc``: ``num_class`` outputs,
    * ``regressor_fc`` (optional): ``2 * num_class`` outputs.
    """

    def __init__(self, num_class,
                 starting_segment, course_segment, ending_segment, modality,
                 base_model='resnet101', new_length=None,
                 dropout=0.8,
                 crop_num=1, no_regression=False, test_mode=False,
                 stpp_cfg=(1, (1, 2), 1), bn_mode='frozen'):
        """Build the backbone and the SSN heads.

        Args:
            num_class: number of action classes.
            starting_segment, course_segment, ending_segment: number of
                segments in each stage; their sum is ``num_segments``.
            modality: ``'RGB'``, ``'Flow'`` or ``'RGBDiff'``; Flow and
                RGBDiff rebuild the backbone's first convolution.
            base_model: backbone name, resolved by ``_prepare_base_model``.
            new_length: frames per segment; defaults to 1 for RGB, else 5.
            dropout: dropout applied to backbone features (0 disables it).
            crop_num: stored as ``self.crop_num``; not used by this class.
            no_regression: if True, no location regressor is built.
            test_mode: if True, ``forward`` dispatches to ``test_forward``.
            stpp_cfg: ``configs`` passed to StructuredTemporalPyramidPooling.
            bn_mode: ``'frozen'``, ``'partial'`` or ``'full'``
                (see ``prepare_bn``).

        Raises:
            ValueError: unknown ``base_model`` or ``bn_mode``.
        """
        super(SSN, self).__init__()
        self.modality = modality
        self.num_segments = starting_segment + course_segment + ending_segment
        self.starting_segment = starting_segment
        self.course_segment = course_segment
        self.ending_segment = ending_segment
        self.reshape = True
        self.dropout = dropout
        self.crop_num = crop_num
        self.with_regression = not no_regression
        self.test_mode = test_mode
        self.bn_mode = bn_mode
        if new_length is None:
            self.new_length = 1 if modality == "RGB" else 5
        else:
            self.new_length = new_length
        print(("""
Initializing SSN with base model: {}.
SSN Configurations:
input_modality: {}
starting_segments: {}
course_segments: {}
ending_segments: {}
num_segments: {}
new_length: {}
dropout_ratio: {}
loc. regression: {}
bn_mode: {}
stpp_configs: {}
""".format(base_model, self.modality,
           self.starting_segment, self.course_segment, self.ending_segment,
           self.num_segments, self.new_length, self.dropout, 'ON' if self.with_regression else "OFF",
           self.bn_mode, stpp_cfg)))
        self._prepare_base_model(base_model)
        self._prepare_ssn(num_class, stpp_cfg)
        if self.modality == 'Flow':
            print("Converting the ImageNet model to a flow init model")
            self.base_model = self._construct_flow_model(self.base_model)
            print("Done. Flow model ready...")
        elif self.modality == 'RGBDiff':
            print("Converting the ImageNet model to RGB+Diff init model")
            self.base_model = self._construct_diff_model(self.base_model)
            print("Done. RGBDiff model ready.")
        self.prepare_bn()

    def _prepare_ssn(self, num_class, stpp_cfg):
        """Replace the backbone classifier and create the SSN heads.

        Returns:
            The input dimension of the backbone's original last layer.
        """
        last_layer = self.base_model.last_layer_name
        feature_dim = getattr(self.base_model, last_layer).in_features
        head = Identity() if self.dropout == 0 else nn.Dropout(p=self.dropout)
        setattr(self.base_model, last_layer, head)

        self.stpp = StructuredTemporalPyramidPooling(feature_dim, True, configs=stpp_cfg)
        self.activity_fc = nn.Linear(self.stpp.activity_feat_dim(), num_class + 1)
        self.completeness_fc = nn.Linear(self.stpp.completeness_feat_dim(), num_class)
        for fc in (self.activity_fc, self.completeness_fc):
            nn.init.normal_(fc.weight, 0, 0.001)
            nn.init.constant_(fc.bias, 0)

        # Built lazily by prepare_test_fc().
        self.test_fc = None
        if self.with_regression:
            self.regressor_fc = nn.Linear(self.stpp.completeness_feat_dim(), 2 * num_class)
            nn.init.normal_(self.regressor_fc.weight, 0, 0.001)
            nn.init.constant_(self.regressor_fc.bias, 0)
        else:
            self.regressor_fc = None
        return feature_dim

    def prepare_bn(self):
        """Translate ``bn_mode`` into ``freeze_count``, which ``train()`` uses.

        ``freeze_count`` is the 1-based index of the first BatchNorm2d layer in
        ``base_model`` to freeze (it and all later ones are frozen); ``None``
        means no layer is frozen.

        Raises:
            ValueError: if ``bn_mode`` is not 'partial', 'frozen' or 'full'.
        """
        if self.bn_mode == 'partial':
            print("Freezing BatchNorm2D except the first one.")
            self.freeze_count = 2
        elif self.bn_mode == 'frozen':
            print("Freezing all BatchNorm2D layers")
            self.freeze_count = 1
        elif self.bn_mode == 'full':
            self.freeze_count = None
        else:
            raise ValueError("unknown bn mode")

    def _prepare_base_model(self, base_model):
        """Instantiate the backbone and set its input size/normalization.

        Raises:
            ValueError: for an unsupported backbone name.
        """
        # Explicit import: `np` was previously only reachable through
        # `from transforms import *`.
        import numpy as np

        if 'resnet' in base_model or 'vgg' in base_model:
            self.base_model = getattr(torchvision.models, base_model)(pretrained=True)
            self.base_model.last_layer_name = 'fc'
            self.input_size = 224
            self.input_mean = [0.485, 0.456, 0.406]
            self.input_std = [0.229, 0.224, 0.225]
            if self.modality == 'Flow':
                self.input_mean = [0.5]
                self.input_std = [np.mean(self.input_std)]
            elif self.modality == 'RGBDiff':
                self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length
                self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length
        elif base_model in ('BNInception', 'InceptionV3'):
            import model_zoo
            self.base_model = getattr(model_zoo, base_model)()
            if base_model == 'BNInception':
                self.base_model.last_layer_name = 'fc'
                self.input_size = 224
            else:
                self.base_model.last_layer_name = 'top_cls_fc'
                self.input_size = 299
            self.input_mean = [104, 117, 128]
            self.input_std = [1]
            if self.modality == 'Flow':
                self.input_mean = [128]
            elif self.modality == 'RGBDiff':
                self.input_mean = self.input_mean * (1 + self.new_length)
        elif 'inception' in base_model:
            import model_zoo
            self.base_model = getattr(model_zoo, base_model)()
            self.base_model.last_layer_name = 'classif'
            self.input_size = 299
            self.input_mean = [0.5]
            self.input_std = [0.5]
        else:
            raise ValueError('Unknown base model: {}'.format(base_model))

    def _frozen_bn_layers(self):
        """Return the BatchNorm2d layers of ``base_model`` that ``train()`` freezes."""
        if self.freeze_count is None:
            return []
        bn_layers = [m for m in self.base_model.modules() if isinstance(m, nn.BatchNorm2d)]
        # freeze_count is 1-based: that layer and every later one are frozen.
        return bn_layers[self.freeze_count - 1:]

    def train(self, mode=True):
        """Override the default train() to freeze the BN parameters.

        After the regular mode switch, the BatchNorm2d layers selected by
        ``bn_mode`` are put back into eval mode and their affine parameters
        stop requiring gradients.

        Returns:
            self, as ``nn.Module.train`` does.
        """
        super(SSN, self).train(mode)
        for m in self._frozen_bn_layers():
            m.eval()
            # shutdown update in frozen mode
            for p in m.parameters():
                p.requires_grad = False
        return self

    def prepare_test_fc(self):
        """Fuse the classification/regression heads into one ``test_fc`` layer.

        The pooled heads act on ``stpp.feat_multiplier`` concatenated chunks of
        ``activity_fc.in_features`` features. Their weights are split per
        chunk and stacked along the output dimension, and their biases are
        divided by the chunk count. Summing the matching per-chunk outputs of
        ``test_fc`` therefore reproduces the original head outputs. Output
        layout: activity, then completeness chunks, then (if enabled)
        regression chunks.
        """
        mult = self.stpp.feat_multiplier
        in_dim = self.activity_fc.in_features

        def split_per_chunk(fc):
            # (out, mult * in) -> (mult * out, in), chunk-major rows.
            weight = fc.weight.data.view(fc.out_features, mult, in_dim).transpose(0, 1) \
                .contiguous().view(-1, in_dim)
            bias = fc.bias.data.view(1, -1).expand(mult, fc.out_features).contiguous().view(-1) / mult
            return weight, bias

        pooled_heads = [self.completeness_fc]
        if self.with_regression:
            pooled_heads.append(self.regressor_fc)

        weights = [self.activity_fc.weight.data]
        biases = [self.activity_fc.bias.data]
        for fc in pooled_heads:
            w, b = split_per_chunk(fc)
            weights.append(w)
            biases.append(b)
        weight = torch.cat(weights)
        bias = torch.cat(biases)

        self.test_fc = nn.Linear(in_dim, weight.size(0))
        self.test_fc.weight.data = weight
        self.test_fc.bias.data = bias

    def get_optim_policies(self):
        """Group parameters into optimizer policies with lr/decay multipliers.

        BatchNorm2d layers frozen by ``train()`` get no policy. BatchNorm
        layers left trainable by ``bn_mode`` go into the "BN scale/shift" group.

        Raises:
            ValueError: for a parameterised leaf module type with no policy.
        """
        first_conv_weight = []
        first_conv_bias = []
        normal_weight = []
        normal_bias = []
        bn = []
        frozen_bn = set(self._frozen_bn_layers())
        conv_cnt = 0
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Conv1d)):
                ps = list(m.parameters())
                conv_cnt += 1
                if conv_cnt == 1:
                    first_conv_weight.append(ps[0])
                    if len(ps) == 2:
                        first_conv_bias.append(ps[1])
                else:
                    normal_weight.append(ps[0])
                    if len(ps) == 2:
                        normal_bias.append(ps[1])
            elif isinstance(m, torch.nn.Linear):
                ps = list(m.parameters())
                normal_weight.append(ps[0])
                if len(ps) == 2:
                    normal_bias.append(ps[1])
            elif isinstance(m, torch.nn.BatchNorm1d):
                bn.extend(m.parameters())
            elif isinstance(m, torch.nn.BatchNorm2d):
                # Mirror train(): only BN layers it leaves trainable are optimized.
                if m not in frozen_bn:
                    bn.extend(m.parameters())
            elif len(m._modules) == 0:
                if len(list(m.parameters())) > 0:
                    raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m)))
        return [
            {'params': first_conv_weight, 'lr_mult': 1, 'decay_mult': 1,
             'name': "first_conv_weight"},
            {'params': first_conv_bias, 'lr_mult': 2, 'decay_mult': 0,
             'name': "first_conv_bias"},
            {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1,
             'name': "normal_weight"},
            {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0,
             'name': "normal_bias"},
            {'params': bn, 'lr_mult': 1, 'decay_mult': 0,
             'name': "BN scale/shift"},
        ]

    def forward(self, input, aug_scaling=None, target=None, reg_target=None, prop_type=None):
        """Dispatch to ``train_forward`` or, when ``test_mode`` is set, ``test_forward``.

        The training-only arguments are optional so test-mode callers can pass
        just ``input``.
        """
        if not self.test_mode:
            return self.train_forward(input, aug_scaling, target, reg_target, prop_type)
        return self.test_forward(input)

    def _run_base_model(self, input):
        """Reshape ``input`` into backbone samples and run the backbone."""
        if self.modality == 'RGBDiff':
            sample_len = 3 * self.new_length
            input = self._get_diff(input)
        else:
            sample_len = (3 if self.modality == "RGB" else 2) * self.new_length
        return self.base_model(input.view((-1, sample_len) + input.size()[-2:]))

    def train_forward(self, input, aug_scaling, target, reg_target, prop_type):
        """Score proposals and select the outputs each loss needs.

        Proposals with ``prop_type`` 0 or 2 feed the activity branch, 0 or 1
        the completeness branch, and 0 the regression branch (the meaning of
        each type is defined by the data loader — presumably 0 = positive).

        Returns:
            ``(act_scores, act_targets, comp_scores, comp_targets)``, plus
            ``(reg_preds, reg_class_targets, reg_targets)`` when regression
            is enabled. Every returned tensor keeps its leading proposal
            dimension, even when only one proposal is selected.
        """
        base_out = self._run_base_model(input)
        activity_ft, completeness_ft = self.stpp(base_out, aug_scaling, [self.starting_segment,
                                                                         self.starting_segment + self.course_segment,
                                                                         self.num_segments])
        raw_act_fc = self.activity_fc(activity_ft)
        raw_comp_fc = self.completeness_fc(completeness_ft)

        type_data = prop_type.view(-1).data
        # view(-1) (not squeeze()) keeps a 1-D index even for a single match.
        act_indexer = ((type_data == 0) | (type_data == 2)).nonzero().view(-1)
        comp_indexer = ((type_data == 0) | (type_data == 1)).nonzero().view(-1)
        target = target.view(-1)

        if self.with_regression:
            reg_target = reg_target.view(-1, 2)
            reg_indexer = (type_data == 0).nonzero().view(-1)
            raw_regress_fc = self.regressor_fc(completeness_ft).view(-1, self.completeness_fc.out_features, 2)
            return raw_act_fc[act_indexer, :], target[act_indexer], \
                raw_comp_fc[comp_indexer, :], target[comp_indexer], \
                raw_regress_fc[reg_indexer, :, :], target[reg_indexer], reg_target[reg_indexer, :]
        return raw_act_fc[act_indexer, :], target[act_indexer], \
            raw_comp_fc[comp_indexer, :], target[comp_indexer]

    def test_forward(self, input):
        """Run the backbone and the fused ``test_fc`` head.

        Returns:
            ``(test_fc(features), features)``.

        Raises:
            RuntimeError: if ``prepare_test_fc`` has not been called.
        """
        if self.test_fc is None:
            raise RuntimeError("test_fc is not initialized; call prepare_test_fc() first")
        base_out = self._run_base_model(input)
        return self.test_fc(base_out), base_out

    def _get_diff(self, input, keep_rgb=False):
        """Compute differences between consecutive frames of each segment.

        ``input`` is viewed as ``(-1, num_segments, new_length + 1, C) +
        input.size()[2:]``. Index ``x - 1`` of the result holds frame ``x``
        minus frame ``x - 1``. With ``keep_rgb``, index 0 keeps the first
        frame and the differences sit at indices ``1..new_length``.
        """
        input_c = 3 if self.modality in ["RGB", "RGBDiff"] else 2
        input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:])
        if keep_rgb:
            new_data = input_view.clone()
        else:
            new_data = input_view[:, :, 1:, :, :, :].clone()
        offset = 0 if keep_rgb else 1
        for x in range(self.new_length, 0, -1):
            new_data[:, :, x - offset, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :]
        return new_data

    def _replace_first_conv(self, base_model, make_kernels):
        """Swap the first Conv2d of ``base_model`` for one with new input channels.

        ``make_kernels`` maps the old conv weight to the new one; the new conv's
        input-channel count is taken from the result. Bias is copied over.
        Assumes (as before) that the module just before the conv in
        ``modules()`` order is its parent and that the parent's first
        state-dict key is that conv's weight.
        """
        modules = list(base_model.modules())
        first_conv_idx = next((i for i, m in enumerate(modules) if isinstance(m, nn.Conv2d)), None)
        if first_conv_idx is None:
            raise ValueError("base model has no Conv2d layer to replace")
        conv_layer = modules[first_conv_idx]
        container = modules[first_conv_idx - 1]

        params = [x.clone() for x in conv_layer.parameters()]
        new_kernels = make_kernels(params[0].data)
        new_conv = nn.Conv2d(new_kernels.size(1), conv_layer.out_channels,
                             conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
                             bias=len(params) == 2)
        new_conv.weight.data = new_kernels
        if len(params) == 2:
            new_conv.bias.data = params[1].data  # add bias if necessary
        layer_name = list(container.state_dict().keys())[0][:-7]  # strip ".weight"
        setattr(container, layer_name, new_conv)
        return base_model

    def _construct_flow_model(self, base_model):
        """Adapt the first conv to ``2 * new_length`` flow channels.

        The new kernels are the RGB kernels averaged over input channels.
        """
        def make_kernels(weight):
            size = weight.size()
            new_size = size[:1] + (2 * self.new_length,) + size[2:]
            return weight.mean(dim=1, keepdim=True).expand(new_size).contiguous()

        return self._replace_first_conv(base_model, make_kernels)

    def _construct_diff_model(self, base_model, keep_rgb=False):
        """Adapt the first conv to ``3 * new_length`` RGB-difference channels.

        With ``keep_rgb``, the original 3 RGB kernels are kept in front of
        the channel-averaged difference kernels.
        """
        def make_kernels(weight):
            size = weight.size()
            new_size = size[:1] + (3 * self.new_length,) + size[2:]
            diff_kernels = weight.mean(dim=1, keepdim=True).expand(new_size).contiguous()
            if keep_rgb:
                return torch.cat((weight, diff_kernels), 1)
            return diff_kernels

        return self._replace_first_conv(base_model, make_kernels)

    @property
    def crop_size(self):
        """Spatial crop size, equal to the backbone's ``input_size``."""
        return self.input_size

    @property
    def scale_size(self):
        """Resize target before cropping (``input_size * 256 // 224``)."""
        return self.input_size * 256 // 224

    def get_augmentation(self):
        """Return the training-time augmentation for the current modality.

        Raises:
            ValueError: for a modality other than RGB, Flow or RGBDiff.
        """
        if self.modality == 'RGB':
            return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]),
                                                   GroupRandomHorizontalFlip(is_flow=False)])
        elif self.modality == 'Flow':
            return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
                                                   GroupRandomHorizontalFlip(is_flow=True)])
        elif self.modality == 'RGBDiff':
            return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
                                                   GroupRandomHorizontalFlip(is_flow=False)])
        raise ValueError('Unknown modality: {}'.format(self.modality))