Skip to content

Commit

Permalink
support more flexible setting of conv head; slice inputs when batch s…
Browse files Browse the repository at this point in the history
…ize is too large in PFNLayer to avoid bugs (open-mmlab#124)

* support more flexible setting

* slice inputs of nn.Linear when batch size is too large
  • Loading branch information
Gus-Guo authored Jul 4, 2020
1 parent b79844c commit b42d7a9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
32 changes: 18 additions & 14 deletions pcdet/models/backbones_2d/base_bev_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ def __init__(self, model_cfg, input_channels):
super().__init__()
self.model_cfg = model_cfg

assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == \
len(self.model_cfg.NUM_FILTERS) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
layer_nums = self.model_cfg.LAYER_NUMS
layer_strides = self.model_cfg.LAYER_STRIDES
num_filters = self.model_cfg.NUM_FILTERS
Expand Down Expand Up @@ -36,16 +36,16 @@ def __init__(self, model_cfg, input_channels):
nn.ReLU()
])
self.blocks.append(nn.Sequential(*cur_layers))

self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(
num_filters[idx], num_upsample_filters[idx],
upsample_strides[idx],
stride=upsample_strides[idx], bias=False
),
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
))
if len(upsample_strides) > 0:
self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(
num_filters[idx], num_upsample_filters[idx],
upsample_strides[idx],
stride=upsample_strides[idx], bias=False
),
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
))

c_in = sum(num_upsample_filters)
if len(upsample_strides) > num_levels:
Expand Down Expand Up @@ -73,12 +73,16 @@ def forward(self, data_dict):

stride = int(spatial_features.shape[2] / x.shape[2])
ret_dict['spatial_features_%dx' % stride] = x
ups.append(self.deblocks[i](x))
if len(self.deblocks) > 0:
ups.append(self.deblocks[i](x))
else:
ups.append(x)

if len(ups) > 1:
x = torch.cat(ups, dim=1)
else:
elif len(ups) == 1:
x = ups[0]

if len(self.deblocks) > len(self.blocks):
x = self.deblocks[-1](x)

Expand Down
10 changes: 9 additions & 1 deletion pcdet/models/backbones_3d/vfe/pillar_vfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,16 @@ def __init__(self,
else:
self.linear = nn.Linear(in_channels, out_channels, bias=True)

self.part = 50000

def forward(self, inputs):
x = self.linear(inputs)
if inputs.shape[0] > self.part:
# nn.Linear performs randomly when batch size is too large
num_parts = inputs.shape[0] // self.part
part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part]) for num_part in range(num_parts+1)]
x = torch.cat(part_linear_out, dim=0)
else:
x = self.linear(inputs)
total_points, voxel_points, channels = x.shape
x = self.norm(x.view(-1, channels)).view(total_points, voxel_points, channels) if self.use_norm else x
x = F.relu(x)
Expand Down

0 comments on commit b42d7a9

Please sign in to comment.