Skip to content

Commit

Permalink
Some detailed modifications for Argo2 and VoxelNeXt (open-mmlab#1327)
Browse files Browse the repository at this point in the history
* Add files via upload

* Delete cbgs_voxel01_voxelnext_headkernel3.yaml

* Delete voxelnext_ioubranch.yaml
  • Loading branch information
yukang2017 authored Apr 24, 2023
1 parent 81763e7 commit ad9c25c
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 150 deletions.
26 changes: 13 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,18 +176,19 @@ By default, all models are trained with **a single frame** of **20% data (~32k f

Here we also provide the performance of several models trained on the full training set (refer to the paper of [PV-RCNN++](https://arxiv.org/abs/2102.00463)):

| Performance@(train with 100\% Data) | Vec_L1 | Vec_L2 | Ped_L1 | Ped_L2 | Cyc_L1 | Cyc_L2 |
|---------------------------------------------|----------:|:-------:|:-------:|:-------:|:-------:|:-------:|
| [SECOND](tools/cfgs/waymo_models/second.yaml) | 72.27/71.69 | 63.85/63.33 | 68.70/58.18 | 60.72/51.31 | 60.62/59.28 | 58.34/57.05 |
| [CenterPoint-Pillar](tools/cfgs/waymo_models/centerpoint_pillar_1x.yaml)| 73.37/72.86 | 65.09/64.62 | 75.35/65.11 | 67.61/58.25 | 67.76/66.22 | 65.25/63.77 |
| [Part-A2-Anchor](tools/cfgs/waymo_models/PartA2.yaml) | 77.05/76.51 | 68.47/67.97 | 75.24/66.87 | 66.18/58.62 | 68.60/67.36 | 66.13/64.93 |
| [VoxelNeXt-2D](tools/cfgs/waymo_models/voxelnext2d_ioubranch.yaml) | 77.94/77.47 |69.68/69.25 |80.24/73.47 |72.23/65.88 |73.33/72.20 |70.66/69.56 |
| [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 78.00/77.50 | 69.43/68.98 | 79.21/73.03 | 70.42/64.72 | 71.46/70.27 | 68.95/67.79 |
| [PV-RCNN++](tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml) | 79.10/78.63 | 70.34/69.91 | 80.62/74.62 | 71.86/66.30 | 73.49/72.38 | 70.70/69.62 |
| [PV-RCNN++ (ResNet)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet.yaml) | 79.25/78.78 | 70.61/70.18 | 81.83/76.28 | 73.17/68.00 | 73.72/72.66 | 71.21/70.19 |
| Performance@(train with 100\% Data) | Vec_L1 | Vec_L2 | Ped_L1 | Ped_L2 | Cyc_L1 | Cyc_L2 |
|-------------------------------------------------------------------------------------------|----------:|:-------:|:-------:|:-------:|:-------:|:-------:|
| [SECOND](tools/cfgs/waymo_models/second.yaml) | 72.27/71.69 | 63.85/63.33 | 68.70/58.18 | 60.72/51.31 | 60.62/59.28 | 58.34/57.05 |
| [CenterPoint-Pillar](tools/cfgs/waymo_models/centerpoint_pillar_1x.yaml) | 73.37/72.86 | 65.09/64.62 | 75.35/65.11 | 67.61/58.25 | 67.76/66.22 | 65.25/63.77 |
| [Part-A2-Anchor](tools/cfgs/waymo_models/PartA2.yaml) | 77.05/76.51 | 68.47/67.97 | 75.24/66.87 | 66.18/58.62 | 68.60/67.36 | 66.13/64.93 |
| [VoxelNeXt-2D](tools/cfgs/waymo_models/voxelnext2d_ioubranch.yaml) | 77.94/77.47 |69.68/69.25 |80.24/73.47 |72.23/65.88 |73.33/72.20 |70.66/69.56 |
| [VoxelNeXt](tools/cfgs/waymo_models/voxelnext_ioubranch_large.yaml) | 78.16/77.70 |69.86/69.42 |81.47/76.30 |73.48/68.63 |76.06/74.90 |73.29/72.18 |
| [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 78.00/77.50 | 69.43/68.98 | 79.21/73.03 | 70.42/64.72 | 71.46/70.27 | 68.95/67.79 |
| [PV-RCNN++](tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml) | 79.10/78.63 | 70.34/69.91 | 80.62/74.62 | 71.86/66.30 | 73.49/72.38 | 70.70/69.62 |
| [PV-RCNN++ (ResNet)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet.yaml) | 79.25/78.78 | 70.61/70.18 | 81.83/76.28 | 73.17/68.00 | 73.72/72.66 | 71.21/70.19 |
| [PV-RCNN++ (ResNet, 2 frames)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet_2frames.yaml) | 80.17/79.70 | 72.14/71.70 | 83.48/80.42 | 75.54/72.61 | 74.63/73.75 | 72.35/71.50 |
| [MPPNet (4 frames)](docs/guidelines_of_approaches/mppnet.md) | 81.54/81.06 | 74.07/73.61 | 84.56/81.94 | 77.20/74.67 | 77.15/76.50 | 75.01/74.38 |
| [MPPNet (16 frames)](docs/guidelines_of_approaches/mppnet.md) | 82.74/82.28 | 75.41/74.96 | 84.69/82.25 | 77.43/75.06 | 77.28/76.66 | 75.13/74.52 |
| [MPPNet (4 frames)](docs/guidelines_of_approaches/mppnet.md) | 81.54/81.06 | 74.07/73.61 | 84.56/81.94 | 77.20/74.67 | 77.15/76.50 | 75.01/74.38 |
| [MPPNet (16 frames)](docs/guidelines_of_approaches/mppnet.md) | 82.74/82.28 | 75.41/74.96 | 84.69/82.25 | 77.43/75.06 | 77.28/76.66 | 75.13/74.52 |



Expand Down Expand Up @@ -226,8 +227,7 @@ All models are trained with 4 GPUs.

| | mAP | download |
|---------------------------------------------------------|:----:|:--------------------------------------------------------------------------------------------------:|
| [VoxelNeXt](tools/cfgs/argo2_models/cbgs_voxel01_voxelnext.yaml) | 30.0 | [model-30M](https://drive.google.com/file/d/1zr-it1ERJzLQ3a3hP060z_EQqS_RkNaC/view?usp=share_link) |
| [VoxelNeXt-K3](tools/cfgs/argo2_models/cbgs_voxel01_voxelnext_headkernel3.yaml) | 30.7 | [model-45M](https://drive.google.com/file/d/1NrYRsiKbuWyL8jE4SY27IHpFMY9K0o__/view?usp=share_link) |
| [VoxelNeXt](tools/cfgs/argo2_models/cbgs_voxel01_voxelnext.yaml) | 30.5 | [model-32M](https://drive.google.com/file/d/1YP2UOz-yO-cWfYQkIqILEu6bodvCBVrR/view?usp=share_link) |

### Other datasets
Welcome to support other datasets by submitting pull request.
Expand Down
29 changes: 6 additions & 23 deletions pcdet/datasets/argo2/argo2_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,39 +488,23 @@ def parse_config():
parser = argparse.ArgumentParser(description='arg parser')
parser.add_argument('--root_path', type=str, default="/data/argo2/sensor")
parser.add_argument('--output_dir', type=str, default="/data/argo2/processed")
parser.add_argument('--num_process', type=int, default=16)
args = parser.parse_args()
return args

def main(seg_path_list, seg_split_list, info_list, ts2idx, output_dir, save_bin, token, num_process):
for seg_i, seg_path in enumerate(seg_path_list):
if seg_i % num_process != token:
continue
print(f'processing segment: {seg_i}/{len(seg_path_list)}')
split = seg_split_list[seg_i]
process_single_segment(seg_path, split, info_list, ts2idx, output_dir, save_bin)

if __name__ == '__main__':
args = parse_config()
root = args.root_path
output_dir = args.output_dir
num_process = args.num_process
save_bin = True
ts2idx, seg_path_list, seg_split_list = prepare(root)

if num_process > 1:
with mp.Manager() as manager:
info_list = manager.list()
pool = mp.Pool(num_process)
for token in range(num_process):
result = pool.apply_async(main, args=(
seg_path_list, seg_split_list, info_list, ts2idx, output_dir, save_bin, token, num_process))
pool.close()
pool.join()
info_list = list(info_list)
else:
info_list = []
main(seg_path_list, seg_split_list, info_list, ts2idx, output_dir, save_bin, 0, 1)
velodyne_dir = Path(output_dir) / 'training' / 'velodyne'
if not velodyne_dir.exists():
velodyne_dir.mkdir(parents=True, exist_ok=True)

info_list = []
create_argo2_infos(seg_path_list, seg_split_list, info_list, ts2idx, output_dir, save_bin, 0, 1)

assert len(info_list) > 0

Expand Down Expand Up @@ -551,4 +535,3 @@ def main(seg_path_list, seg_split_list, info_list, ts2idx, output_dir, save_bin,

gts = pd.concat(seg_anno_list).reset_index()
gts.to_feather(save_feather_path)

2 changes: 2 additions & 0 deletions pcdet/models/backbones_3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .spconv_backbone_2d import PillarBackBone8x, PillarRes18BackBone8x
from .spconv_backbone_focal import VoxelBackBone8xFocal
from .spconv_backbone_voxelnext import VoxelResBackBone8xVoxelNeXt
from .spconv_backbone_voxelnext2d import VoxelResBackBone8xVoxelNeXt2D
from .spconv_unet import UNetV2

__all__ = {
Expand All @@ -13,6 +14,7 @@
'VoxelResBackBone8x': VoxelResBackBone8x,
'VoxelBackBone8xFocal': VoxelBackBone8xFocal,
'VoxelResBackBone8xVoxelNeXt': VoxelResBackBone8xVoxelNeXt,
'VoxelResBackBone8xVoxelNeXt2D': VoxelResBackBone8xVoxelNeXt2D,
'PillarBackBone8x': PillarBackBone8x,
'PillarRes18BackBone8x': PillarRes18BackBone8x
}
219 changes: 219 additions & 0 deletions pcdet/models/backbones_3d/spconv_backbone_voxelnext2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
from functools import partial
import torch
import torch.nn as nn

from ...utils.spconv_utils import replace_feature, spconv


def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stride=1, padding=0,
conv_type='subm', norm_fn=None):

if conv_type == 'subm':
conv = spconv.SubMConv2d(in_channels, out_channels, kernel_size, bias=False, indice_key=indice_key)
elif conv_type == 'spconv':
conv = spconv.SparseConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
bias=False, indice_key=indice_key)
elif conv_type == 'inverseconv':
conv = spconv.SparseInverseConv2d(in_channels, out_channels, kernel_size, indice_key=indice_key, bias=False)
else:
raise NotImplementedError

m = spconv.SparseSequential(
conv,
norm_fn(out_channels),
nn.ReLU(),
)

return m


class SparseBasicBlock(spconv.SparseModule):
expansion = 1

def __init__(self, inplanes, planes, stride=1, norm_fn=None, downsample=None, indice_key=None):
super(SparseBasicBlock, self).__init__()

assert norm_fn is not None
bias = norm_fn is not None
self.conv1 = spconv.SubMConv2d(
inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key
)
self.bn1 = norm_fn(planes)
self.relu = nn.ReLU()
self.conv2 = spconv.SubMConv2d(
planes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key
)
self.bn2 = norm_fn(planes)
self.downsample = downsample
self.stride = stride

def forward(self, x):
identity = x

out = self.conv1(x)
out = replace_feature(out, self.bn1(out.features))
out = replace_feature(out, self.relu(out.features))

out = self.conv2(out)
out = replace_feature(out, self.bn2(out.features))

if self.downsample is not None:
identity = self.downsample(x)

out = replace_feature(out, out.features + identity.features)
out = replace_feature(out, self.relu(out.features))

return out


class VoxelResBackBone8xVoxelNeXt2D(nn.Module):
def __init__(self, model_cfg, input_channels, grid_size, **kwargs):
super().__init__()
self.model_cfg = model_cfg
norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
self.sparse_shape = grid_size[[1, 0]]

block = post_act_block

spconv_kernel_sizes = model_cfg.get('SPCONV_KERNEL_SIZES', [3, 3, 3, 3])

self.conv1 = spconv.SparseSequential(
SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res1'),
SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res1'),
SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res1'),
)

self.conv2 = spconv.SparseSequential(
# [1600, 1408] <- [800, 704]
block(32, 64, spconv_kernel_sizes[0], norm_fn=norm_fn, stride=2, padding=int(spconv_kernel_sizes[0]//2), indice_key='spconv2', conv_type='spconv'),
SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res2'),
SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res2'),
SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res2'),
SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res2'),
)

self.conv3 = spconv.SparseSequential(
# [800, 704] <- [400, 352]
block(64, 128, spconv_kernel_sizes[1], norm_fn=norm_fn, stride=2, padding=int(spconv_kernel_sizes[1]//2), indice_key='spconv3', conv_type='spconv'),
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'),
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'),
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'),
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'),
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'),
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'),
)

self.conv4 = spconv.SparseSequential(
# [400, 352] <- [200, 176]
block(128, 256, spconv_kernel_sizes[2], norm_fn=norm_fn, stride=2, padding=int(spconv_kernel_sizes[2]//2), indice_key='spconv4', conv_type='spconv'),
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res4'),
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res4'),
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res4'),
)

self.conv5 = spconv.SparseSequential(
# [400, 352] <- [200, 176]
block(256, 256, spconv_kernel_sizes[3], norm_fn=norm_fn, stride=2, padding=int(spconv_kernel_sizes[3]//2), indice_key='spconv5', conv_type='spconv'),
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res5'),
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res5'),
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res5'),
)

self.conv6 = spconv.SparseSequential(
# [400, 352] <- [200, 176]
block(256, 256, spconv_kernel_sizes[3], norm_fn=norm_fn, stride=2, padding=int(spconv_kernel_sizes[3]//2), indice_key='spconv6', conv_type='spconv'),
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res6'),
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res6'),
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res6'),
)

self.conv_out = spconv.SparseSequential(
# [200, 150, 5] -> [200, 150, 2]
spconv.SparseConv2d(256, 256, 3, stride=1, padding=1, bias=False, indice_key='spconv_down2'),
norm_fn(256),
nn.ReLU(),
)

self.shared_conv = spconv.SparseSequential(
spconv.SubMConv2d(256, 256, 3, stride=1, padding=1, bias=True),
nn.BatchNorm1d(256),
nn.ReLU(True),
)

self.num_point_features = 256
self.backbone_channels = {
'x_conv1': 32,
'x_conv2': 64,
'x_conv3': 128,
'x_conv4': 256,
'x_conv5': 256
}
self.forward_ret_dict = {}

def bev_out(self, x_conv):
features_cat = x_conv.features
indices_cat = x_conv.indices

indices_unique, _inv = torch.unique(indices_cat, dim=0, return_inverse=True)
features_unique = features_cat.new_zeros((indices_unique.shape[0], features_cat.shape[1]))
features_unique.index_add_(0, _inv, features_cat)

x_out = spconv.SparseConvTensor(
features=features_unique,
indices=indices_unique,
spatial_shape=x_conv.spatial_shape,
batch_size=x_conv.batch_size
)
return x_out

def forward(self, batch_dict):
pillar_features, pillar_coords = batch_dict['pillar_features'], batch_dict['pillar_coords']
batch_size = batch_dict['batch_size']
input_sp_tensor = spconv.SparseConvTensor(
features=pillar_features,
indices=pillar_coords.int(),
spatial_shape=self.sparse_shape,
batch_size=batch_size
)

x_conv1 = self.conv1(input_sp_tensor)
x_conv2 = self.conv2(x_conv1)
x_conv3 = self.conv3(x_conv2)
x_conv4 = self.conv4(x_conv3)
x_conv5 = self.conv5(x_conv4)
x_conv6 = self.conv6(x_conv5)

x_conv5.indices[:, 1:] *= 2
x_conv6.indices[:, 1:] *= 4
x_conv4 = x_conv4.replace_feature(torch.cat([x_conv4.features, x_conv5.features, x_conv6.features]))
x_conv4.indices = torch.cat([x_conv4.indices, x_conv5.indices, x_conv6.indices])

out = self.bev_out(x_conv4)

out = self.conv_out(out)
out = self.shared_conv(out)

batch_dict.update({
'encoded_spconv_tensor': out,
'encoded_spconv_tensor_stride': 8
})
batch_dict.update({
'multi_scale_2d_features': {
'x_conv1': x_conv1,
'x_conv2': x_conv2,
'x_conv3': x_conv3,
'x_conv4': x_conv4,
'x_conv5': x_conv5,
}
})
batch_dict.update({
'multi_scale_2d_strides': {
'x_conv1': 1,
'x_conv2': 2,
'x_conv3': 4,
'x_conv4': 8,
'x_conv5': 16,
}
})

return batch_dict
Loading

0 comments on commit ad9c25c

Please sign in to comment.