Skip to content

Commit

Permalink
add pp_multihead.yaml config for NuScenes, modify loc weights from 2.…
Browse files Browse the repository at this point in the history
…0 to 0.25
  • Loading branch information
sshaoshuai committed Jul 8, 2020
1 parent b35146e commit 1bc21f7
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 12 deletions.
32 changes: 23 additions & 9 deletions pcdet/models/backbones_2d/base_bev_backbone.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import numpy as np


class BaseBEVBackbone(nn.Module):
Expand Down Expand Up @@ -44,15 +45,28 @@ def __init__(self, model_cfg, input_channels):
])
self.blocks.append(nn.Sequential(*cur_layers))
if len(upsample_strides) > 0:
self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(
num_filters[idx], num_upsample_filters[idx],
upsample_strides[idx],
stride=upsample_strides[idx], bias=False
),
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
))
stride = upsample_strides[idx]
if stride > 1:
self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(
num_filters[idx], num_upsample_filters[idx],
upsample_strides[idx],
stride=upsample_strides[idx], bias=False
),
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
))
else:
stride = np.round(1 / stride).astype(np.int)
self.deblocks.append(nn.Sequential(
nn.Conv2d(
num_filters[idx], num_upsample_filters[idx],
stride,
stride=stride, bias=False
),
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
))

c_in = sum(num_upsample_filters)
if len(upsample_strides) > num_levels:
Expand Down
2 changes: 1 addition & 1 deletion tools/cfgs/nuscenes_models/cbgs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ MODEL:
REG_LOSS_TYPE: WeightedL1Loss
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 2.0,
'loc_weight': 0.25,
'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
}
Expand Down
4 changes: 2 additions & 2 deletions tools/cfgs/nuscenes_models/cbgs_1conv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ MODEL:
'anchor_bottom_heights': [-0.225],
'align_center': False,
'feature_map_stride': 8,
'matched_threshold': 0.55,
'matched_threshold': 0.5
'unmatched_threshold': 0.35
},
{
Expand Down Expand Up @@ -197,7 +197,7 @@ MODEL:
REG_LOSS_TYPE: WeightedL1Loss
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 2.0,
'loc_weight': 0.25,
'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
}
Expand Down
254 changes: 254 additions & 0 deletions tools/cfgs/nuscenes_models/pp_multihead.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
CLASS_NAMES: ['car','truck', 'construction_vehicle', 'bus', 'trailer',
'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone']

DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/nuscenes_dataset.yaml

POINT_CLOUD_RANGE: [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True

- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
'test': True
}

- NAME: transform_points_to_voxels
VOXEL_SIZE: [0.2, 0.2, 8.0]
MAX_POINTS_PER_VOXEL: 20
MAX_NUMBER_OF_VOXELS: {
'train': 30000,
'test': 30000
}

MODEL:
NAME: PointPillar

VFE:
NAME: PillarVFE
WITH_DISTANCE: False
USE_ABSLOTE_XYZ: True
USE_NORM: True
NUM_FILTERS: [64]

MAP_TO_BEV:
NAME: PointPillarScatter
NUM_BEV_FEATURES: 64

BACKBONE_2D:
NAME: BaseBEVBackbone
LAYER_NUMS: [3, 5, 5]
LAYER_STRIDES: [2, 2, 2]
NUM_FILTERS: [64, 128, 256]
UPSAMPLE_STRIDES: [0.5, 1, 2]
NUM_UPSAMPLE_FILTERS: [128, 128, 128]

DENSE_HEAD:
NAME: AnchorHeadMulti
CLASS_AGNOSTIC: False

USE_DIRECTION_CLASSIFIER: True
DIR_OFFSET: 0.78539
DIR_LIMIT_OFFSET: 0.0
NUM_DIR_BINS: 2

USE_MULTIHEAD: True
SEPARATE_MULTIHEAD: True
ANCHOR_GENERATOR_CONFIG: [
{
'class_name': car,
'anchor_sizes': [[4.63, 1.97, 1.74]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-0.95],
'align_center': False,
'feature_map_stride': 4,
'matched_threshold': 0.6,
'unmatched_threshold': 0.45
},
{
'class_name': truck,
'anchor_sizes': [[6.93, 2.51, 2.84]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-0.6],
'align_center': False,
'feature_map_stride': 4,
'matched_threshold': 0.55,
'unmatched_threshold': 0.4
},
{
'class_name': construction_vehicle,
'anchor_sizes': [[6.37, 2.85, 3.19]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-0.225],
'align_center': False,
'feature_map_stride': 4,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
},
{
'class_name': bus,
'anchor_sizes': [[10.5, 2.94, 3.47]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-0.085],
'align_center': False,
'feature_map_stride': 4,
'matched_threshold': 0.55,
'unmatched_threshold': 0.4
},
{
'class_name': trailer,
'anchor_sizes': [[12.29, 2.90, 3.87]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [0.115],
'align_center': False,
'feature_map_stride': 4,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
},
{
'class_name': barrier,
'anchor_sizes': [[0.50, 2.53, 0.98]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.33],
'align_center': False,
'feature_map_stride': 4,
'matched_threshold': 0.55,
'unmatched_threshold': 0.4
},
{
'class_name': motorcycle,
'anchor_sizes': [[2.11, 0.77, 1.47]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.085],
'align_center': False,
'feature_map_stride': 4,
'matched_threshold': 0.5,
'unmatched_threshold': 0.3
},
{
'class_name': bicycle,
'anchor_sizes': [[1.70, 0.60, 1.28]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.18],
'align_center': False,
'feature_map_stride': 4,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
},
{
'class_name': pedestrian,
'anchor_sizes': [[0.73, 0.67, 1.77]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-0.935],
'align_center': False,
'feature_map_stride': 4,
'matched_threshold': 0.6,
'unmatched_threshold': 0.4
},
{
'class_name': traffic_cone,
'anchor_sizes': [[0.41, 0.41, 1.07]],
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.285],
'align_center': False,
'feature_map_stride': 4,
'matched_threshold': 0.6,
'unmatched_threshold': 0.4
},
]

SHARED_CONV_NUM_FILTER: 64
RPN_HEAD_CFGS: [
{
'HEAD_CLS_NAME': ['car'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['truck', 'construction_vehicle'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['bus', 'trailer'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['barrier'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['motorcycle', 'bicycle'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
{
'HEAD_CLS_NAME': ['pedestrian', 'traffic_cone'],
'LAYER_NUMS': [0],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [64],
},
]

TARGET_ASSIGNER_CONFIG:
NAME: AxisAlignedTargetAssigner
POS_FRACTION: -1.0
SAMPLE_SIZE: 512
NORM_BY_NUM_EXAMPLES: False
MATCH_HEIGHT: False
BOX_CODER: ResidualCoder
BOX_CODER_CONFIG: {
'code_size': 9
}


LOSS_CONFIG:
REG_LOSS_TYPE: WeightedL1Loss
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 0.25,
'dir_weight': 0.2,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
}

POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
SCORE_THRESH: 0.1
OUTPUT_RAW_SCORE: False

EVAL_METRIC: kitti

NMS_CONFIG:
MULTI_CLASSES_NMS: False
NMS_TYPE: nms_gpu
NMS_THRESH: 0.2
NMS_PRE_MAXSIZE: 4096
NMS_POST_MAXSIZE: 100


OPTIMIZATION:
OPTIMIZER: adam_onecycle
LR: 0.001
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9

MOMS: [0.95, 0.85]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001

LR_WARMUP: False
WARMUP_EPOCH: 1

GRAD_NORM_CLIP: 10

0 comments on commit 1bc21f7

Please sign in to comment.