@@ -18,6 +18,7 @@ def __init__(
         cls_loss=None,
         mixup_alpha=0.0,
         mixup_w=0.0,
+        use_box=False,
     ):
         super(DrSpaam, self).__init__()

@@ -44,6 +45,11 @@ def __init__(
         # detection layer
         self.conv_cls = nn.Conv1d(128, 1, kernel_size=1)
         self.conv_reg = nn.Conv1d(128, 2, kernel_size=1)
+        self._use_box = use_box
+        if use_box:
+            self.conv_box = nn.Conv1d(
+                128, 4, kernel_size=1
+            )  # length, width, sin_rot, cos_rot

         # spatial attention
         self.gate = _SpatialAttentionMemory(
@@ -68,6 +74,10 @@ def __init__(
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)

+    @property
+    def use_box(self):
+        return self._use_box
+
     def forward(self, x, inference=False):
         """
         Args:
@@ -110,7 +120,11 @@ def forward(self, x, inference=False):
         pred_cls = self.conv_cls(out).view(B, CT, -1)  # (B, CT, cls)
         pred_reg = self.conv_reg(out).view(B, CT, 2)  # (B, CT, 2)

-        return pred_cls, pred_reg, sim
+        if self._use_box:
+            pred_box = self.conv_box(out).view(B, CT, 4)
+            return pred_cls, pred_reg, pred_box, sim
+        else:
+            return pred_cls, pred_reg, sim

     def _conv_and_pool(self, x, conv_block):
         out = conv_block(x)
@@ -169,9 +183,11 @@ def forward(self, x_new):
             self._memory = x_new
             return self._memory, None

+        # ##########
         # NOTE: Ablation study, DR-AM, no spatial attention
         # self._memory = self._alpha * x_new + (1.0 - self._alpha) * self._memory
         # return self._memory, None
+        # ##########

         n_batch, n_cutout, n_channel, n_pts = x_new.shape

@@ -196,6 +212,7 @@ def forward(self, x_new):
         sim = torch.matmul(emb_x, emb_temp.permute(0, 2, 1))

         # masked softmax
+        # TODO replace with gather and scatter
         sim = sim - 1e10 * (
             1.0 - self.neighbor_masks
         )  # make sure the out-of-window elements have small values
@@ -204,6 +221,12 @@ def forward(self, x_new):
         exps_sum = exps.sum(dim=-1, keepdim=True)
         sim = exps / exps_sum

+        # # NOTE this gather scatter version is only marginally more efficient on memory
+        # sim_w = torch.gather(sim, 2, self.neighbor_inds.unsqueeze(dim=0))
+        # sim_w = sim_w.softmax(dim=2)
+        # sim = torch.zeros_like(sim)
+        # sim.scatter_(2, self.neighbor_inds.unsqueeze(dim=0), sim_w)
+
         # weighted average on the template
         atten_memory = self._memory.view(n_batch, n_cutout, n_channel * n_pts)
         atten_memory = torch.matmul(sim, atten_memory)
@@ -230,7 +253,6 @@ def _generate_neighbor_mask(self, x):
         )
         inds_row = torch.arange(n_cutout).unsqueeze(dim=-1).expand_as(inds_col).long()
         inds_full = torch.stack((inds_row, inds_col), dim=2).view(-1, 2)
-        # self.register_buffer('neighbor_inds', inds_full)

         masks = torch.zeros(n_cutout, n_cutout).float()
         masks[inds_full[:, 0], inds_full[:, 1]] = 1.0
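
With `use_box` enabled, `forward` returns four values instead of three. Below is a hedged sketch of the branch a caller might write; `net` and `x` are placeholders for a constructed `DrSpaam` and a valid input batch (the constructor arguments ahead of `use_box` are not shown in this diff), and decoding the rotation from its sin/cos parametrization is an assumption based on the inline comment, not code from this commit:

```python
import torch

# Placeholders: `net` is assumed to be a constructed DrSpaam and `x` a
# valid input batch; neither is fully specified by this diff.
if net.use_box:
    pred_cls, pred_reg, pred_box, sim = net(x)
    # pred_box: (B, CT, 4) = (length, width, sin_rot, cos_rot) per cutout.
    # Assumed decode of the box orientation from its (sin, cos) pair:
    rot = torch.atan2(pred_box[..., 2], pred_box[..., 3])
else:
    pred_cls, pred_reg, sim = net(x)
```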
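The kept attention path implements the masked softmax densely: out-of-window logits are pushed down by 1e10 so that `exp()` underflows to zero, while the commented-out alternative gathers only the in-window entries, softmaxes them, and scatters the weights back. Here is a self-contained sketch contrasting the two; all sizes are made up, and a wrap-around window is used so each row's index set is duplicate-free (the two variants only agree under that condition; this is not the repository's actual mask construction):

```python
import torch

torch.manual_seed(0)
n_cutout = 6  # made-up number of cutouts, for illustration only

# Toy similarity matrix (batch of 1), standing in for the emb_x/emb_temp product.
sim = torch.randn(1, n_cutout, n_cutout)

# Wrap-around 3-wide window: each cutout attends to itself and one
# neighbor on each side, with no duplicate indices per row.
idx = torch.arange(n_cutout)
inds_col = torch.stack(((idx - 1) % n_cutout, idx, (idx + 1) % n_cutout), dim=1)

# Dense-mask variant (the kept path): out-of-window logits drop to ~-1e10,
# so exp() underflows to exactly zero and they vanish from the sum.
masks = torch.zeros(n_cutout, n_cutout)
masks.scatter_(1, inds_col, 1.0)
sim_masked = sim - 1e10 * (1.0 - masks)
exps = torch.exp(sim_masked)
attn_dense = exps / exps.sum(dim=-1, keepdim=True)

# Gather/scatter variant (the commented-out alternative): softmax over
# the in-window entries only, then scatter back into a dense matrix.
sim_w = torch.gather(sim, 2, inds_col.unsqueeze(dim=0))
sim_w = sim_w.softmax(dim=2)
attn_sparse = torch.zeros_like(sim)
attn_sparse.scatter_(2, inds_col.unsqueeze(dim=0), sim_w)

print(torch.allclose(attn_dense, attn_sparse, atol=1e-6))  # True
```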