Skip to content

Commit

Permalink
support PointRCNN with PartA2 head
Browse files Browse the repository at this point in the history
  • Loading branch information
sshaoshuai committed Jul 22, 2020
1 parent 4edd264 commit 8ec2a2d
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pcdet/models/detectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from .PartA2_net import PartA2Net
from .pv_rcnn import PVRCNN
from .pointpillar import PointPillar
from .point_rcnn import PointRCNN

__all__ = {
'Detector3DTemplate': Detector3DTemplate,
'SECONDNet': SECONDNet,
'PartA2Net': PartA2Net,
'PVRCNN': PVRCNN,
'PointPillar': PointPillar
'PointPillar': PointPillar,
'PointRCNN': PointRCNN
}


Expand Down
30 changes: 30 additions & 0 deletions pcdet/models/detectors/point_rcnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from .detector3d_template import Detector3DTemplate


class PointRCNN(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()

def forward(self, batch_dict):
for cur_module in self.module_list:
batch_dict = cur_module(batch_dict)

if self.training:
loss, tb_dict, disp_dict = self.get_training_loss()

ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts

def get_training_loss(self):
disp_dict = {}
loss_point, tb_dict = self.point_head.get_loss()
loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict)

loss = loss_point + loss_rcnn
return loss, tb_dict, disp_dict
144 changes: 144 additions & 0 deletions tools/cfgs/kitti_models/pointrcnn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist']

DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/kitti_dataset.yaml

DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True

- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
'test': False
}

MODEL:
NAME: PointRCNN

BACKBONE_3D:
NAME: PointNet2Backbone
SA_CONFIG:
NPOINTS: [4096, 1024, 256, 64]
RADIUS: [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]]
NSAMPLE: [[16, 32], [16, 32], [16, 32], [16, 32]]
MLPS: [[[16, 16, 32], [32, 32, 64]],
[[64, 64, 128], [64, 96, 128]],
[[128, 196, 256], [128, 196, 256]],
[[256, 256, 512], [256, 384, 512]]]
FP_MLPS: [[128, 128], [256, 256], [512, 512], [512, 512]]

POINT_HEAD:
NAME: PointHeadBox
CLS_FC: [256, 256]
REG_FC: [256, 256]
CLASS_AGNOSTIC: False
USE_POINT_FEATURES_BEFORE_FUSION: False
TARGET_CONFIG:
GT_EXTRA_WIDTH: [0.2, 0.2, 0.2]
BOX_CODER: PointResidualCoder
BOX_CODER_CONFIG: {
'use_mean_size': False ,
'mean_size': [
[3.9, 1.6, 1.56],
[0.8, 0.6, 1.73],
[1.76, 0.6, 1.73]
]
}

LOSS_CONFIG:
LOSS_REG: WeightedSmoothL1Loss
LOSS_WEIGHTS: {
'point_cls_weight': 1.0,
'point_box_weight': 1.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
}

ROI_HEAD:
NAME: PartA2FCHead
CLASS_AGNOSTIC: True

SHARED_FC: [256, 256, 256]
CLS_FC: [256, 256]
REG_FC: [256, 256]
DP_RATIO: 0.3
DISABLE_PART: True
SEG_MASK_SCORE_THRESH: 0.0

NMS_CONFIG:
TRAIN:
NMS_TYPE: nms_gpu
MULTI_CLASSES_NMS: False
NMS_PRE_MAXSIZE: 9000
NMS_POST_MAXSIZE: 512
NMS_THRESH: 0.8
TEST:
NMS_TYPE: nms_gpu
MULTI_CLASSES_NMS: False
NMS_PRE_MAXSIZE: 1024
NMS_POST_MAXSIZE: 100
NMS_THRESH: 0.7

ROI_AWARE_POOL:
POOL_SIZE: 12
NUM_FEATURES: 128
MAX_POINTS_PER_VOXEL: 128

TARGET_CONFIG:
BOX_CODER: ResidualCoder
ROI_PER_IMAGE: 128
FG_RATIO: 0.5

SAMPLE_ROI_BY_EACH_CLASS: True
CLS_SCORE_TYPE: roi_iou

CLS_FG_THRESH: 0.75
CLS_BG_THRESH: 0.25
CLS_BG_THRESH_LO: 0.1
HARD_BG_RATIO: 0.8

REG_FG_THRESH: 0.65

LOSS_CONFIG:
CLS_LOSS: BinaryCrossEntropy
REG_LOSS: smooth-l1
CORNER_LOSS_REGULARIZATION: True
LOSS_WEIGHTS: {
'rcnn_cls_weight': 1.0,
'rcnn_reg_weight': 1.0,
'rcnn_corner_weight': 1.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
}

POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
SCORE_THRESH: 0.1
OUTPUT_RAW_SCORE: False

EVAL_METRIC: kitti

NMS_CONFIG:
MULTI_CLASSES_NMS: False
NMS_TYPE: nms_gpu
NMS_THRESH: 0.1
NMS_PRE_MAXSIZE: 4096
NMS_POST_MAXSIZE: 500


OPTIMIZATION:
OPTIMIZER: adam_onecycle
LR: 0.01
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9

MOMS: [0.95, 0.85]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001

LR_WARMUP: False
WARMUP_EPOCH: 1

GRAD_NORM_CLIP: 10

0 comments on commit 8ec2a2d

Please sign in to comment.