
Commit 99dd7a2

Dan Jia committed
update DrSpaam with option to predict BEV box
1 parent 1a365cd · commit 99dd7a2

File tree

1 file changed (+24 lines, -2 lines)


dr_spaam/dr_spaam/model/dr_spaam.py

Lines changed: 24 additions & 2 deletions
@@ -18,6 +18,7 @@ def __init__(
         cls_loss=None,
         mixup_alpha=0.0,
         mixup_w=0.0,
+        use_box=False,
     ):
         super(DrSpaam, self).__init__()

@@ -44,6 +45,11 @@ def __init__(
         # detection layer
         self.conv_cls = nn.Conv1d(128, 1, kernel_size=1)
         self.conv_reg = nn.Conv1d(128, 2, kernel_size=1)
+        self._use_box = use_box
+        if use_box:
+            self.conv_box = nn.Conv1d(
+                128, 4, kernel_size=1
+            )  # length, width, sin_rot, cos_rot

         # spatial attention
         self.gate = _SpatialAttentionMemory(
@@ -68,6 +74,10 @@ def __init__(
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)

+    @property
+    def use_box(self):
+        return self._use_box
+
     def forward(self, x, inference=False):
         """
         Args:
@@ -110,7 +120,11 @@ def forward(self, x, inference=False):
         pred_cls = self.conv_cls(out).view(B, CT, -1)  # (B, CT, cls)
         pred_reg = self.conv_reg(out).view(B, CT, 2)  # (B, CT, 2)

-        return pred_cls, pred_reg, sim
+        if self._use_box:
+            pred_box = self.conv_box(out).view(B, CT, 4)
+            return pred_cls, pred_reg, pred_box, sim
+        else:
+            return pred_cls, pred_reg, sim

     def _conv_and_pool(self, x, conv_block):
         out = conv_block(x)
@@ -169,9 +183,11 @@ def forward(self, x_new):
             self._memory = x_new
             return self._memory, None

+        # ##########
         # NOTE: Ablation study, DR-AM, no spatial attention
         # self._memory = self._alpha * x_new + (1.0 - self._alpha) * self._memory
         # return self._memory, None
+        # ##########

         n_batch, n_cutout, n_channel, n_pts = x_new.shape

@@ -196,6 +212,7 @@ def forward(self, x_new):
         sim = torch.matmul(emb_x, emb_temp.permute(0, 2, 1))

         # masked softmax
+        # TODO replace with gather and scatter
         sim = sim - 1e10 * (
             1.0 - self.neighbor_masks
         )  # make sure the out-of-window elements have small values
@@ -204,6 +221,12 @@ def forward(self, x_new):
         exps_sum = exps.sum(dim=-1, keepdim=True)
         sim = exps / exps_sum

+        # # NOTE this gather scatter version is only marginally more efficient on memory
+        # sim_w = torch.gather(sim, 2, self.neighbor_inds.unsqueeze(dim=0))
+        # sim_w = sim_w.softmax(dim=2)
+        # sim = torch.zeros_like(sim)
+        # sim.scatter_(2, self.neighbor_inds.unsqueeze(dim=0), sim_w)
+
         # weighted average on the template
         atten_memory = self._memory.view(n_batch, n_cutout, n_channel * n_pts)
         atten_memory = torch.matmul(sim, atten_memory)
@@ -230,7 +253,6 @@ def _generate_neighbor_mask(self, x):
         )
         inds_row = torch.arange(n_cutout).unsqueeze(dim=-1).expand_as(inds_col).long()
         inds_full = torch.stack((inds_row, inds_col), dim=2).view(-1, 2)
-        # self.register_buffer('neighbor_inds', inds_full)

         masks = torch.zeros(n_cutout, n_cutout).float()
         masks[inds_full[:, 0], inds_full[:, 1]] = 1.0
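For reference, a minimal usage sketch of the new option (not part of this commit): the input shape, the number of cutouts, and the defaults of the other constructor arguments are assumptions; only use_box=True and the extra pred_box output come from the diff above.

    import torch
    from dr_spaam.model.dr_spaam import DrSpaam

    # Assumption: all other constructor arguments keep their defaults.
    net = DrSpaam(use_box=True)
    net.eval()

    # Assumption: one scan split into 450 cutouts of 48 points each.
    x = torch.rand(1, 450, 48)

    with torch.no_grad():
        pred_cls, pred_reg, pred_box, sim = net(x)

    # pred_box is (B, CT, 4): length, width, sin_rot, cos_rot per cutout.
    rot = torch.atan2(pred_box[..., 2], pred_box[..., 3])

    # With use_box=False (the default), forward returns (pred_cls, pred_reg, sim) as before.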

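The TODO and the commented-out block in the attention hunks above describe two equivalent ways of restricting the softmax to each cutout's neighborhood window: subtracting 1e10 from out-of-window logits versus gathering the in-window logits, applying softmax, and scattering the result back. A standalone sketch of that equivalence, using hypothetical shapes in place of the neighbor_masks and neighbor_inds buffers built in _generate_neighbor_mask:

    import torch

    CT, W = 5, 3  # hypothetical: 5 cutouts, window of 3 neighbors each
    sim = torch.randn(1, CT, CT)
    neighbor_inds = torch.stack([torch.arange(i, i + W) % CT for i in range(CT)])  # (CT, W)
    neighbor_masks = torch.zeros(CT, CT).scatter_(1, neighbor_inds, 1.0)

    # Variant 1 (current code): push out-of-window logits to a large negative value.
    ref = (sim - 1e10 * (1.0 - neighbor_masks)).softmax(dim=-1)

    # Variant 2 (commented-out alternative): softmax only over the gathered window,
    # then scatter the probabilities back into a zero matrix.
    sim_w = torch.gather(sim, 2, neighbor_inds.unsqueeze(0)).softmax(dim=2)
    out = torch.zeros_like(sim).scatter_(2, neighbor_inds.unsqueeze(0), sim_w)

    print(torch.allclose(ref, out, atol=1e-6))  # True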