-
Notifications
You must be signed in to change notification settings - Fork 43
/
models.py
451 lines (335 loc) · 15.1 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _quadruple
from torch.autograd import Variable
import config
# Per-layer keyword-argument dicts for the encoder and decoder stacks, taken
# from the project-local `config` module. EnetEncoder passes each encoder
# entry as **kwargs to one EnetEncoderModule; EnetDecoder reads each decoder
# entry's 'upsample' key and passes it as **kwargs to one EnetDecoderModule.
ENCODER_PARAMS = config.ENCODER_PARAMS
DECODER_PARAMS = config.DECODER_PARAMS
class MedianPool2d(nn.Module):
    """Median pooling layer; behaves as a median filter when stride is 1.

    Args:
        kernel_size: pooling window, an int or a 2-tuple (height, width).
        stride: pooling stride, an int or a 2-tuple (height, width).
        padding: explicit padding, an int or a 4-tuple
            (left, right, top, bottom) in the order ``F.pad`` expects.
        same: when True, ``padding`` is ignored and "same"-style padding is
            computed from the input size on every call.
    """

    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super(MedianPool2d, self).__init__()
        self.k = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _quadruple(padding)  # (left, right, top, bottom)
        self.same = same

    def _padding(self, x):
        """Return the (left, right, top, bottom) padding to apply to ``x``."""
        if not self.same:
            return self.padding

        def total_pad(size, kernel, step):
            # Total padding along one axis for "same" behaviour.
            remainder = size % step
            return max(kernel - (step if remainder == 0 else remainder), 0)

        in_h, in_w = x.size()[2:]
        pad_h = total_pad(in_h, self.k[0], self.stride[0])
        pad_w = total_pad(in_w, self.k[1], self.stride[1])
        left = pad_w // 2
        top = pad_h // 2
        return (left, pad_w - left, top, pad_h - top)

    def forward(self, x):
        # Built only from existing differentiable tensor ops so autograd
        # works; a dedicated C/CUDA kernel would likely be faster.
        padded = F.pad(x, self._padding(x), mode='reflect')
        kernel_h, kernel_w = self.k
        stride_h, stride_w = self.stride
        windows = padded.unfold(2, kernel_h, stride_h).unfold(3, kernel_w, stride_w)
        flat_windows = windows.contiguous().view(windows.size()[:4] + (-1,))
        return flat_windows.median(dim=-1)[0]
class EnetInitialBlock(nn.Module):
    """ENet stem: 3x3 stride-2 conv (6 -> 32 channels), batch norm, PReLU.

    With padding 1 the output spatial size is ceil(H / 2) x ceil(W / 2).
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(6, 32, (3, 3), stride=2, padding=1, bias=True)
        self.batch_norm = nn.BatchNorm2d(32, eps=1e-3)
        self.actf = nn.PReLU()

    def forward(self, input):
        return self.actf(self.batch_norm(self.conv(input)))
class EnetEncoderMainPath(nn.Module):
    """Main (bottleneck) branch of an ENet encoder module.

    Projection conv -> 3x3 (optionally dilated) conv -> 1x1 expansion conv,
    each followed by batch norm; PReLU after the first two stages and
    ``nn.Dropout2d`` at the end.

    Args:
        internal_scale: bottleneck width divisor
            (internal channels = output_channels // internal_scale).
        use_relu: stored as an attribute only; PReLU is always used.
        asymmetric: stored as an attribute only; the middle conv is always 3x3.
        dilated: dilation (and padding) of the middle conv; falsy means 1.
        input_channels: channels of the input tensor.
        output_channels: channels produced by this branch.
        downsample: if truthy, the projection conv has kernel 2 / stride 2
            (halving H and W); otherwise it is a 1x1 stride-1 conv.
        dropout_prob: probability passed to ``nn.Dropout2d``.
    """

    def __init__(self, internal_scale=None, use_relu=None, asymmetric=None, dilated=None, input_channels=None,
                 output_channels=None, downsample=None, dropout_prob=None):
        super().__init__()
        internal_channels = output_channels // internal_scale
        input_stride = 2 if downsample else 1
        dilation = dilated if dilated else 1

        # Constructor settings kept as attributes, set explicitly (instead of
        # copying locals(), which also dragged in the `__class__` cell).
        self.internal_scale = internal_scale
        self.use_relu = use_relu
        self.asymmetric = asymmetric
        self.dilated = dilated
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.downsample = downsample
        self.dropout_prob = dropout_prob
        self.internal_channels = internal_channels
        self.input_stride = input_stride

        # Kernel size equals the stride: 2x2/2 when downsampling, else 1x1.
        self.input_conv = nn.Conv2d(
            input_channels, internal_channels, input_stride,
            stride=input_stride, padding=0, bias=False)
        self.input_batch_norm = nn.BatchNorm2d(internal_channels, eps=1e-03)
        # padding == dilation keeps H and W unchanged for a 3x3 kernel.
        self.middle_conv = nn.Conv2d(
            internal_channels, internal_channels, 3, stride=1, bias=True,
            dilation=dilation, padding=dilation)
        self.middle_batch_norm = nn.BatchNorm2d(internal_channels, eps=1e-03)
        self.output_conv = nn.Conv2d(
            internal_channels, output_channels, 1,
            stride=1, padding=0, bias=False)
        self.output_batch_norm = nn.BatchNorm2d(output_channels, eps=1e-03)
        self.dropout = nn.Dropout2d(dropout_prob)
        self.input_actf = nn.PReLU()
        self.middle_actf = nn.PReLU()

    def forward(self, input):
        output = self.input_actf(self.input_batch_norm(self.input_conv(input)))
        output = self.middle_actf(self.middle_batch_norm(self.middle_conv(output)))
        output = self.output_batch_norm(self.output_conv(output))
        return self.dropout(output)
class EnetEncoderOtherPath(nn.Module):
    """Shortcut branch of an ENet encoder module.

    Optionally applies a 2x2 / stride-2 max pool, then tiles channels with
    ``Tensor.repeat`` up to ``output_channels``. When downsampling, the pool
    indices of the most recent forward pass are stored in ``self.indices``;
    EnetDecoderOtherPath reads them (via its ``pooling_module``) to unpool.

    Args:
        internal_scale, use_relu, asymmetric, dilated: stored as attributes
            only; not used by this branch.
        input_channels: channels of the input tensor.
        output_channels: channels to produce; when different from
            ``input_channels`` it must be a positive multiple of it.
        downsample: if truthy, apply the max pool.
        **kwargs: other per-layer settings shared with the main path
            (e.g. ``dropout_prob``); ignored here.

    Raises:
        ValueError: if the channel counts differ and ``output_channels`` is
            not a positive multiple of ``input_channels``.
    """

    def __init__(self, internal_scale=None, use_relu=None, asymmetric=None, dilated=None, input_channels=None,
                 output_channels=None, downsample=None, **kwargs):
        super().__init__()
        self.internal_scale = internal_scale
        self.use_relu = use_relu
        self.asymmetric = asymmetric
        self.dilated = dilated
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.downsample = downsample
        # Channel tiling only works for whole multiples; otherwise the sum
        # with the main branch would fail later with a shape mismatch.
        if output_channels != input_channels and (
                output_channels < input_channels or output_channels % input_channels):
            raise ValueError(
                'output_channels ({}) must be a positive multiple of input_channels ({})'.format(
                    output_channels, input_channels))
        if downsample:
            self.pool = nn.MaxPool2d(2, stride=2, return_indices=True)

    def forward(self, input):
        output = input
        if self.downsample:
            # Indices are kept for the matching decoder layer's max-unpool.
            output, self.indices = self.pool(input)
        if self.output_channels != self.input_channels:
            output = output.repeat(1, self.output_channels // self.input_channels, 1, 1)
        return output
class EnetEncoderModule(nn.Module):
    """One ENet encoder bottleneck: PReLU(main branch + shortcut branch).

    Every keyword argument is handed unchanged to both
    ``EnetEncoderMainPath`` and ``EnetEncoderOtherPath``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.main = EnetEncoderMainPath(**kwargs)
        self.other = EnetEncoderOtherPath(**kwargs)
        self.actf = nn.PReLU()

    def forward(self, input):
        # Main branch is evaluated first, then the shortcut.
        return self.actf(self.main(input) + self.other(input))
class EnetEncoder(nn.Module):
    """ENet encoder: the initial block followed by a stack of bottlenecks.

    Args:
        params: iterable of keyword-argument dicts, one per
            ``EnetEncoderModule``.
        nclasses: output channels of ``self.output_conv``.

    Note: ``output_conv`` (1x1, 128 -> nclasses) is created but not used by
    ``forward``, which returns the last bottleneck's feature map.
    """

    def __init__(self, params, nclasses):
        super().__init__()
        self.initial_block = EnetInitialBlock()
        self.layers = []
        for index, layer_params in enumerate(params):
            module = EnetEncoderModule(**layer_params)
            # Registered under a numbered name so parameters are tracked;
            # `self.layers` keeps the execution order.
            self.add_module('encoder_{:02d}'.format(index), module)
            self.layers.append(module)
        self.output_conv = nn.Conv2d(128, nclasses, 1, stride=1, padding=0, bias=True)

    def forward(self, input):
        features = self.initial_block(input)
        for bottleneck in self.layers:
            features = bottleneck(features)
        return features
class EnetDecoderMainPath(nn.Module):
    """Main (bottleneck) branch of an ENet decoder module.

    1x1 projection to ``output_channels // 4`` channels -> 3x3 conv (or, when
    upsampling, a stride-2 transposed conv that doubles H and W) -> 1x1
    expansion to ``output_channels``; batch norm after each conv and PReLU
    after the first two.

    Args:
        input_channels: channels of the input tensor.
        output_channels: channels produced by this branch.
        upsample: if truthy, the middle conv upsamples 2x.
        pooling_module: accepted because the same kwargs configure both
            decoder branches; stored but not used here.
    """

    def __init__(self, input_channels=None, output_channels=None, upsample=None, pooling_module=None):
        super().__init__()
        internal_channels = output_channels // 4

        # Constructor settings kept as attributes, set explicitly (instead of
        # copying locals(), which also dragged in the `__class__` cell).
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.upsample = upsample
        self.internal_channels = internal_channels
        self.input_stride = 2 if upsample is True else 1  # informational only
        # Written straight into __dict__ to bypass nn.Module.__setattr__: the
        # encoder's module must NOT be re-registered as a child of this branch.
        self.__dict__['pooling_module'] = pooling_module

        self.input_conv = nn.Conv2d(
            input_channels, internal_channels, 1,
            stride=1, padding=0, bias=False)
        self.input_batch_norm = nn.BatchNorm2d(internal_channels, eps=1e-03)
        if not upsample:
            self.middle_conv = nn.Conv2d(
                internal_channels, internal_channels, 3,
                stride=1, padding=1, bias=True)
        else:
            # (H - 1) * 2 - 2 * 1 + 3 + 1 = 2H: exact 2x upsampling.
            self.middle_conv = nn.ConvTranspose2d(
                internal_channels, internal_channels, 3,
                stride=2, padding=1, output_padding=1,
                bias=True)
        self.middle_batch_norm = nn.BatchNorm2d(internal_channels, eps=1e-03)
        self.output_conv = nn.Conv2d(
            internal_channels, output_channels, 1,
            stride=1, padding=0, bias=False)
        self.output_batch_norm = nn.BatchNorm2d(output_channels, eps=1e-03)
        self.input_actf = nn.PReLU()
        self.middle_actf = nn.PReLU()

    def forward(self, input):
        output = self.input_actf(self.input_batch_norm(self.input_conv(input)))
        output = self.middle_actf(self.middle_batch_norm(self.middle_conv(output)))
        return self.output_batch_norm(self.output_conv(output))
class EnetDecoderOtherPath(nn.Module):
    """Shortcut branch of an ENet decoder module.

    Applies a 1x1 conv + batch norm when the channel count changes or when
    upsampling. When upsampling with a pooling module, it then max-unpools 2x
    using that module's ``indices`` attribute (set by the encoder's most
    recent forward pass).

    Args:
        input_channels: channels of the input tensor.
        output_channels: channels produced by this branch.
        upsample: if truthy, project and (with a pooling module) unpool 2x.
        pooling_module: encoder module exposing ``indices`` for unpooling.
    """

    def __init__(self, input_channels=None, output_channels=None, upsample=None, pooling_module=None):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.upsample = upsample
        # Written straight into __dict__ to bypass nn.Module.__setattr__: the
        # encoder's module must NOT be re-registered as a child of this branch.
        self.__dict__['pooling_module'] = pooling_module
        if output_channels != input_channels or upsample:
            self.conv = nn.Conv2d(
                input_channels, output_channels, 1,
                stride=1, padding=0, bias=False)
            self.batch_norm = nn.BatchNorm2d(output_channels, eps=1e-03)
        if upsample and pooling_module is not None:
            self.unpool = nn.MaxUnpool2d(2, stride=2, padding=0)

    def forward(self, input):
        output = input
        if self.output_channels != self.input_channels or self.upsample:
            output = self.batch_norm(self.conv(output))
        if self.upsample and self.pooling_module is not None:
            n, c, h, w = output.size()
            output = self.unpool(
                output, self.pooling_module.indices,
                output_size=[n, c, 2 * h, 2 * w])
        return output
class EnetDecoderModule(nn.Module):
    """One ENet decoder bottleneck: PReLU(main branch + shortcut branch).

    Keyword arguments are handed unchanged to both ``EnetDecoderMainPath``
    and ``EnetDecoderOtherPath``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.main = EnetDecoderMainPath(**kwargs)
        self.other = EnetDecoderOtherPath(**kwargs)
        self.actf = nn.PReLU()

    def forward(self, input):
        main_out = self.main(input)
        shortcut_out = self.other(input)
        return self.actf(main_out + shortcut_out)
class EnetDecoder(nn.Module):
    """ENet decoder: a stack of EnetDecoderModule layers and a final 2x
    transposed-conv classifier (32 -> nclasses channels).

    Each upsampling layer is paired with one downsampling shortcut branch of
    the encoder, taken in reverse module order (last downsample <-> first
    upsample), so it can max-unpool with that branch's pooling indices.

    Args:
        params: iterable of keyword-argument dicts, one per decoder layer;
            each must have an ``'upsample'`` key. The dicts are NOT modified.
        nclasses: output channels of the final classifier.
        encoder: the encoder whose pooling indices are reused; it is also
            registered as a submodule here (``self.encoder``).

    Raises:
        IndexError: if there are more upsampling layers than downsampling
            encoder modules.
    """

    def __init__(self, params, nclasses, encoder):
        super().__init__()
        self.encoder = encoder
        # Shortcut branches of encoder modules that downsample, in module order.
        self.pooling_modules = []
        for mod in self.encoder.modules():
            other = getattr(mod, 'other', None)
            if getattr(other, 'downsample', False):
                self.pooling_modules.append(other)
        self.layers = []
        for i, layer_params in enumerate(params):
            # Work on a copy: writing 'pooling_module' into the caller's dict
            # (e.g. the module-level config) mutated shared state and kept a
            # reference to this encoder alive after the model was dropped.
            layer_params = dict(layer_params)
            if layer_params['upsample']:
                layer_params['pooling_module'] = self.pooling_modules.pop(-1)
            layer = EnetDecoderModule(**layer_params)
            self.layers.append(layer)
            self.add_module('decoder{:02d}'.format(i), layer)
        self.output_conv = nn.ConvTranspose2d(
            32, nclasses, 2,
            stride=2, padding=0, output_padding=0, bias=True)

    def forward(self, input):
        output = input
        for layer in self.layers:
            output = layer(output)
        return self.output_conv(output)
class EnetGnn(nn.Module):
    """Graph neural network refining encoder features over a 3D point cloud.

    Every encoder-resolution pixel is a graph node placed at (x, y, depth).
    x/y come from ``xy`` channels 0/1 and depth from channel 3 of the network
    input, all median-pooled with an 8x8 window and stride 8. The pooled maps
    must come out at the encoder's H x W.

    For ``gnn_iterations`` rounds, each node aggregates messages from its k
    nearest neighbours (MLP ``g`` then mean) and updates its state with MLP
    ``q``. The final states are concatenated with the original encoder
    features and fused by a 3x3 conv. Feature width is fixed at 128 channels
    by the Linear/Conv2d layer sizes.

    Args:
        mlp_num_layers: number of Linear(128, 128) + PReLU layers in ``g``.
        use_gpu: stored as ``self.use_gpu``. Tensors are created on the
            input's device, so it does not affect the computation here.
    """

    def __init__(self, mlp_num_layers, use_gpu):
        super().__init__()
        self.median_pool = MedianPool2d(kernel_size=8, stride=8, padding=0, same=False)
        self.g_rnn_layers = nn.ModuleList([nn.Linear(128, 128) for _ in range(mlp_num_layers)])
        self.g_rnn_actfs = nn.ModuleList([nn.PReLU() for _ in range(mlp_num_layers)])
        self.q_rnn_layer = nn.Linear(256, 128)
        self.q_rnn_actf = nn.PReLU()
        self.output_conv = nn.Conv2d(256, 128, 3, stride=1, padding=1, bias=True)
        self.use_gpu = use_gpu

    def get_knn_indices(self, batch_mat, k):
        """Return, for every point, the indices of its ``k`` nearest points.

        Uses ||a - b||^2 = a.a + b.b - 2 a.b (adapted from
        https://discuss.pytorch.org/t/build-your-own-loss-function-in-pytorch/235/6).

        Args:
            batch_mat: (N, P, D) tensor holding N sets of P points.
            k: neighbours to return per point.

        Returns:
            (N, P, k) float tensor of indices into the same sample's points,
            nearest first, on ``batch_mat``'s device. A point's own index is
            normally included, since its self-distance is exactly 0.
        """
        gram = torch.bmm(batch_mat, batch_mat.permute(0, 2, 1))  # N P P
        num_sets, num_points = gram.size(0), gram.size(1)
        batch_indices = torch.zeros((num_sets, num_points, k), device=gram.device)
        for idx, val in enumerate(gram):
            sq_norms = val.diag().unsqueeze(0).expand_as(val)  # [i, j] = |p_j|^2
            # Rounding can push the expansion slightly below 0 for nearly
            # coincident points; taking sqrt() of that gave NaN and corrupted
            # the ranking. Clamp instead. sqrt is dropped because it is
            # monotonic and does not change which neighbours are nearest.
            sq_dist = (sq_norms + sq_norms.t() - 2 * val).clamp(min=0)
            _, indices = torch.topk(sq_dist, k=k, largest=False)
            batch_indices[idx] = indices
        return batch_indices

    def forward(self, cnn_encoder_output, original_input, gnn_iterations, k, xy):
        """Run ``gnn_iterations`` rounds of k-NN message passing.

        Args:
            cnn_encoder_output: (N, 128, H, W) encoder features.
            original_input: network input; channel 3 is used as depth.
            gnn_iterations: number of message-passing rounds (0 = none).
            k: neighbours per node.
            xy: tensor whose channels 0 and 1 hold x and y coordinates; must
                share dtype/device and spatial size with ``original_input``
                (assumed — both are pooled together).

        Returns:
            (N, 128, H, W) refined features.
        """
        N, C, H, W = cnn_encoder_output.size()
        HW = H * W
        K = k

        # Point cloud: one (x, y, depth) point per encoder pixel. MedianPool2d
        # works per channel, so pooling the stacked channels once equals
        # pooling each of them separately.
        coords = torch.cat((xy[:, 0:2], original_input[:, 3:4]), 1)
        coords = self.median_pool(coords)  # N 3 H W
        proj_3d = coords.view(N, 3, HW).permute(0, 2, 1).contiguous()  # N HW 3

        # k nearest neighbours of every point, within its own sample.
        knn = self.get_knn_indices(proj_3d, k=K).long().view(N, HW * K)  # N HWK
        gather_index = knn.unsqueeze(2).expand(N, HW * K, C)  # N HWK C

        h = cnn_encoder_output.permute(0, 2, 3, 1).contiguous().view(N, HW, C)  # N HW C
        for _ in range(gnn_iterations):
            # Batched neighbour lookup: row j of sample n is h[n, knn[n, j]].
            neighbor_features = torch.gather(h, 1, gather_index).view(N, HW, K, C)
            for g_layer, g_actf in zip(self.g_rnn_layers, self.g_rnn_actfs):
                neighbor_features = g_actf(g_layer(neighbor_features))
            m = neighbor_features.mean(dim=2)  # N HW C: mean message
            h = self.q_rnn_actf(self.q_rnn_layer(torch.cat((h, m), 2)))  # N HW C

        h = h.view(N, H, W, C).permute(0, 3, 1, 2).contiguous()  # N C H W
        # Concatenate (N 2C H W) and fuse back to C channels.
        return self.output_conv(torch.cat((cnn_encoder_output, h), 1))
class Model(nn.Module):
    """Full network: ENet encoder -> optional 3D GNN -> ENet decoder.

    Args:
        nclasses: number of output classes (encoder and decoder heads).
        mlp_num_layers: Linear+PReLU layers in the GNN message MLP.
        use_gpu: forwarded to ``EnetGnn``.
    """

    def __init__(self, nclasses, mlp_num_layers, use_gpu):
        super().__init__()
        self.encoder = EnetEncoder(ENCODER_PARAMS, nclasses)
        self.gnn = EnetGnn(mlp_num_layers, use_gpu)
        self.decoder = EnetDecoder(DECODER_PARAMS, nclasses, self.encoder)

    def forward(self, input, gnn_iterations, k, xy, use_gnn, only_encode=False):
        """Run the network.

        Args:
            input: network input (the encoder's first conv takes 6 channels;
                the GNN reads channel 3 as depth).
            gnn_iterations: message-passing rounds for the GNN.
            k: neighbours per node in the GNN.
            xy: coordinate maps passed to the GNN.
            use_gnn: if truthy, refine encoder features with the GNN.
            only_encode: if True, return the encoder features immediately.

        Returns:
            Encoder features if ``only_encode``, else the decoder output.
        """
        # Call the submodules rather than `.forward()` so registered
        # forward hooks run. The encoder must run before the decoder, which
        # reuses the encoder's pooling indices from this pass.
        x = self.encoder(input)
        if only_encode:
            return x
        if use_gnn:
            x = self.gnn(x, input, gnn_iterations, k, xy)
        return self.decoder(x)