Skip to content

Commit

Permalink
Code release of MPPNet for multi-frame 3D object detection (open-mmla…
Browse files Browse the repository at this point in the history
…b#1089)

* add mppnet in openpcdet

* add mppnet yamls

* add IOU_WEIGHT Flag

* add IOU_WEIGHT Flag

* add IOU_WEIGHT Flag

* add 16 frame effi_test

* add ctrans

* use effi crop

* update mppnet_head.py

* update mppnet_16frame.yaml

* update test.py

* add mppnet_4frame.yaml

* update mppnet_4frame.yaml

* update deted_template

* update det3d_template

* update yaml and clean mppnet head

* rm unused py and yaml

* update yamls

* fixbug of bs 2 eval

* fixbug of bs>1 eval

* update mppnet training code

* update training code

* rm unused file

* rm unused file

* reorganize code

* reorganize code

* add transformer.py with paper name

* add transformer.py with paper name

* add transformer.py with paper name

* reorganize code

* reorganize code

* reorganize code

* reorganize code

* reorganize code

* reorganize code

* reorganize code

* reorganize code

* rm unused code

* rm unused code

* format codes

* support save_to_file for WOD to save model predicted results

* fix small bug in generate_single_sample_dict

* support to load pred_boxes from result.pkl to avoid massive small object loading

* bugfixed: train with MPPNet

* bugfixed: remove num_frames in transformer.forward()

* bugfixed: remove num_frames in transformer.forward(), continue

* support to configure train/val result.pkl for ROI_BOXES_PATH for MPPNet

* update MPPNet codes

* bugfixed to support float32/float64 GT database

* update document

Co-authored-by: Shaoshuai Shi <shaoshuaics@gmail.com>
  • Loading branch information
Cedarch and sshaoshuai authored Sep 3, 2022
1 parent aa753ec commit ef7da7d
Show file tree
Hide file tree
Showing 23 changed files with 3,230 additions and 63 deletions.
2 changes: 1 addition & 1 deletion docs/guidelines_of_approaches/mppnet.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
# Will be available soon
# The guideline of MPPNet Will be available soon
36 changes: 30 additions & 6 deletions pcdet/datasets/augmentor/augmentor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from ...utils import box_utils


def random_flip_along_x(gt_boxes, points, return_flip=False):
def random_flip_along_x(gt_boxes, points, return_flip=False, enable=None):
"""
Args:
gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
points: (M, 3 + C)
Returns:
"""
enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
if enable is None:
enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
if enable:
gt_boxes[:, 1] = -gt_boxes[:, 1]
gt_boxes[:, 6] = -gt_boxes[:, 6]
Expand All @@ -25,14 +26,15 @@ def random_flip_along_x(gt_boxes, points, return_flip=False):
return gt_boxes, points


def random_flip_along_y(gt_boxes, points, return_flip=False):
def random_flip_along_y(gt_boxes, points, return_flip=False, enable=None):
"""
Args:
gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
points: (M, 3 + C)
Returns:
"""
enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
if enable is None:
enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
if enable:
gt_boxes[:, 0] = -gt_boxes[:, 0]
gt_boxes[:, 6] = -(gt_boxes[:, 6] + np.pi)
Expand All @@ -45,15 +47,16 @@ def random_flip_along_y(gt_boxes, points, return_flip=False):
return gt_boxes, points


def global_rotation(gt_boxes, points, rot_range, return_rot=False):
def global_rotation(gt_boxes, points, rot_range, return_rot=False, noise_rotation=None):
"""
Args:
gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
points: (M, 3 + C),
rot_range: [min, max]
Returns:
"""
noise_rotation = np.random.uniform(rot_range[0], rot_range[1])
if noise_rotation is None:
noise_rotation = np.random.uniform(rot_range[0], rot_range[1])
points = common_utils.rotate_points_along_z(points[np.newaxis, :, :], np.array([noise_rotation]))[0]
gt_boxes[:, 0:3] = common_utils.rotate_points_along_z(gt_boxes[np.newaxis, :, 0:3], np.array([noise_rotation]))[0]
gt_boxes[:, 6] += noise_rotation
Expand Down Expand Up @@ -81,10 +84,31 @@ def global_scaling(gt_boxes, points, scale_range, return_scale=False):
noise_scale = np.random.uniform(scale_range[0], scale_range[1])
points[:, :3] *= noise_scale
gt_boxes[:, :6] *= noise_scale
if gt_boxes.shape[1] > 7:
gt_boxes[:, 7:] *= noise_scale

if return_scale:
return gt_boxes, points, noise_scale
return gt_boxes, points

def global_scaling_with_roi_boxes(gt_boxes, roi_boxes, points, scale_range, return_scale=False):
    """
    Apply one random global scale factor to the points, the GT boxes and the ROI boxes.

    The three arrays are scaled in place and are also returned.

    Args:
        gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading]
        roi_boxes: 3-D array indexed as [:, :, channel]. Channels 0-5 and 7-8 are scaled and
            channel 6 is left untouched, so the last axis needs at least 9 channels whenever
            scaling is actually applied.
            NOTE(review): presumably (num_frames, num_rois, 9) laid out as
            [x, y, z, dx, dy, dz, heading, vx, vy] -- confirm against the caller.
        points: (M, 3 + C),
        scale_range: [min, max]
        return_scale: if True, the applied scale factor is appended to the returned tuple
    Returns:
        (gt_boxes, roi_boxes, points), or (gt_boxes, roi_boxes, points, noise_scale) when
        return_scale is True. If the range is narrower than 1e-3 nothing is scaled and the
        reported noise_scale is 1.0.
    """
    if scale_range[1] - scale_range[0] < 1e-3:
        # Degenerate range: leave the data untouched. The return arity must still match the
        # normal path, otherwise callers unpacking (gt_boxes, roi_boxes, points[, scale])
        # fail with a ValueError.
        noise_scale = 1.0
    else:
        noise_scale = np.random.uniform(scale_range[0], scale_range[1])
        points[:, :3] *= noise_scale
        gt_boxes[:, :6] *= noise_scale
        # Index 6 is skipped on purpose: it is not multiplied by the scale factor.
        roi_boxes[:, :, [0, 1, 2, 3, 4, 5, 7, 8]] *= noise_scale

    if return_scale:
        return gt_boxes, roi_boxes, points, noise_scale
    return gt_boxes, roi_boxes, points


def random_image_flip_horizontal(image, depth_map, gt_boxes, calib):
"""
Expand Down
28 changes: 25 additions & 3 deletions pcdet/datasets/augmentor/data_augmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ def random_world_flip(self, data_dict=None, config=None):
gt_boxes, points, return_flip=True
)
data_dict['flip_%s'%cur_axis] = enable
if 'roi_boxes' in data_dict.keys():
num_frame, num_rois,dim = data_dict['roi_boxes'].shape
roi_boxes, _, _ = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)(
data_dict['roi_boxes'].reshape(-1,dim), np.zeros([1,3]), return_flip=True, enable=enable
)
data_dict['roi_boxes'] = roi_boxes.reshape(num_frame, num_rois,dim)

data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points
Expand All @@ -64,6 +70,11 @@ def random_world_rotation(self, data_dict=None, config=None):
gt_boxes, points, noise_rot = augmentor_utils.global_rotation(
data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range, return_rot=True
)
if 'roi_boxes' in data_dict.keys():
num_frame, num_rois,dim = data_dict['roi_boxes'].shape
roi_boxes, _, _ = augmentor_utils.global_rotation(
data_dict['roi_boxes'].reshape(-1, dim), np.zeros([1, 3]), rot_range=rot_range, return_rot=True, noise_rotation=noise_rot)
data_dict['roi_boxes'] = roi_boxes.reshape(num_frame, num_rois,dim)

data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points
Expand All @@ -73,9 +84,16 @@ def random_world_rotation(self, data_dict=None, config=None):
def random_world_scaling(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.random_world_scaling, config=config)
gt_boxes, points, noise_scale = augmentor_utils.global_scaling(
data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True
)

if 'roi_boxes' in data_dict.keys():
gt_boxes, roi_boxes, points, noise_scale = augmentor_utils.global_scaling_with_roi_boxes(
data_dict['gt_boxes'], data_dict['roi_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True
)
data_dict['roi_boxes'] = roi_boxes
else:
gt_boxes, points, noise_scale = augmentor_utils.global_scaling(
data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True
)

data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points
Expand Down Expand Up @@ -115,6 +133,10 @@ def random_world_translation(self, data_dict=None, config=None):
gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
points[:, :3] += noise_translate
gt_boxes[:, :3] += noise_translate

if 'roi_boxes' in data_dict.keys():
data_dict['roi_boxes'][:, :3] += noise_translate

data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points
return data_dict
Expand Down
10 changes: 7 additions & 3 deletions pcdet/datasets/augmentor/database_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, root_path, sampler_cfg, class_names, logger=None):
sampler_cfg.DB_DATA_PATH[0] = sampler_cfg.BACKUP_DB_INFO['DB_DATA_PATH']
db_info_path = self.root_path.resolve() / sampler_cfg.DB_INFO_PATH[0]
sampler_cfg.NUM_POINT_FEATURES = sampler_cfg.BACKUP_DB_INFO['NUM_POINT_FEATURES']

with open(str(db_info_path), 'rb') as f:
infos = pickle.load(f)
[self.db_infos[cur_class].extend(infos[cur_class]) for cur_class in class_names]
Expand Down Expand Up @@ -391,10 +391,14 @@ def add_sampled_boxes_to_scene(self, data_dict, sampled_gt_boxes, total_valid_sa
obj_points = copy.deepcopy(gt_database_data[start_offset:end_offset])
else:
file_path = self.root_path / info['path']

obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape(
[-1, self.sampler_cfg.NUM_POINT_FEATURES])
if obj_points.shape[0] != info['num_points_in_gt']:
obj_points = np.fromfile(str(file_path), dtype=np.float64).reshape(-1, self.sampler_cfg.NUM_POINT_FEATURES)

obj_points[:, :3] += info['box3d_lidar'][:3]
assert obj_points.shape[0] == info['num_points_in_gt']
obj_points[:, :3] += info['box3d_lidar'][:3].astype(np.float32)

if self.sampler_cfg.get('USE_ROAD_PLANE', False):
# mv height
Expand All @@ -417,7 +421,7 @@ def add_sampled_boxes_to_scene(self, data_dict, sampled_gt_boxes, total_valid_sa
else:
assert obj_points.shape[-1] == points.shape[-1] + 1
# transform multi-frame GT points to single-frame GT points
min_time = max_time = 0.0
min_time = max_time = 0.0

time_mask = np.logical_and(obj_points[:, -1] < max_time + 1e-6, obj_points[:, -1] > min_time - 1e-6)
obj_points = obj_points[time_mask]
Expand Down
17 changes: 16 additions & 1 deletion pcdet/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def generate_prediction_dicts(self, batch_dict, pred_dicts, class_names, output_
Returns:
"""

def get_template_prediction(num_samples):
box_dim = 9 if self.dataset_cfg.get('TRAIN_WITH_SPEED', False) else 7
ret_dict = {
Expand Down Expand Up @@ -216,6 +216,21 @@ def collate_batch(batch_list, _unused=False):
for k in range(batch_size):
batch_gt_boxes3d[k, :val[k].__len__(), :] = val[k]
ret[key] = batch_gt_boxes3d

elif key in ['roi_boxes']:
max_gt = max([x.shape[1] for x in val])
batch_gt_boxes3d = np.zeros((batch_size, val[0].shape[0], max_gt, val[0].shape[-1]), dtype=np.float32)
for k in range(batch_size):
batch_gt_boxes3d[k,:, :val[k].shape[1], :] = val[k]
ret[key] = batch_gt_boxes3d

elif key in ['roi_scores', 'roi_labels']:
max_gt = max([x.shape[1] for x in val])
batch_gt_boxes3d = np.zeros((batch_size, val[0].shape[0], max_gt), dtype=np.float32)
for k in range(batch_size):
batch_gt_boxes3d[k,:, :val[k].shape[1]] = val[k]
ret[key] = batch_gt_boxes3d

elif key in ['gt_boxes2d']:
max_boxes = 0
max_boxes = max([len(x) for x in val])
Expand Down
Loading

0 comments on commit ef7da7d

Please sign in to comment.