diff --git a/docs/guidelines_of_approaches/mppnet.md b/docs/guidelines_of_approaches/mppnet.md
index 14bfb59c1..61184a727 100644
--- a/docs/guidelines_of_approaches/mppnet.md
+++ b/docs/guidelines_of_approaches/mppnet.md
@@ -1 +1 @@
-# Will be available soon
\ No newline at end of file
+# The guideline of MPPNet will be available soon
\ No newline at end of file
diff --git a/pcdet/datasets/augmentor/augmentor_utils.py b/pcdet/datasets/augmentor/augmentor_utils.py
index db0e0c0fb..3c088e33c 100644
--- a/pcdet/datasets/augmentor/augmentor_utils.py
+++ b/pcdet/datasets/augmentor/augmentor_utils.py
@@ -5,14 +5,15 @@
 from ...utils import box_utils
 
 
-def random_flip_along_x(gt_boxes, points, return_flip=False):
+def random_flip_along_x(gt_boxes, points, return_flip=False, enable=None):
     """
     Args:
         gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
        points: (M, 3 + C)
    Returns:
    """
-    enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
+    if enable is None:
+        enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
     if enable:
         gt_boxes[:, 1] = -gt_boxes[:, 1]
         gt_boxes[:, 6] = -gt_boxes[:, 6]
@@ -25,14 +26,15 @@ def random_flip_along_x(gt_boxes, points, return_flip=False):
     return gt_boxes, points
 
 
-def random_flip_along_y(gt_boxes, points, return_flip=False):
+def random_flip_along_y(gt_boxes, points, return_flip=False, enable=None):
     """
     Args:
         gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
         points: (M, 3 + C)
     Returns:
     """
-    enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
+    if enable is None:
+        enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
     if enable:
         gt_boxes[:, 0] = -gt_boxes[:, 0]
         gt_boxes[:, 6] = -(gt_boxes[:, 6] + np.pi)
@@ -45,7 +47,7 @@ def random_flip_along_y(gt_boxes, points, return_flip=False):
     return gt_boxes, points
 
 
-def global_rotation(gt_boxes, points, rot_range, return_rot=False):
+def global_rotation(gt_boxes, points, rot_range, return_rot=False, noise_rotation=None):
     """
     Args:
         gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
@@ -53,7 +55,8 @@ def global_rotation(gt_boxes, points, rot_range, return_rot=False):
         rot_range: [min, max]
     Returns:
     """
-    noise_rotation = np.random.uniform(rot_range[0], rot_range[1])
+    if noise_rotation is None:
+        noise_rotation = np.random.uniform(rot_range[0], rot_range[1])
     points = common_utils.rotate_points_along_z(points[np.newaxis, :, :], np.array([noise_rotation]))[0]
     gt_boxes[:, 0:3] = common_utils.rotate_points_along_z(gt_boxes[np.newaxis, :, 0:3], np.array([noise_rotation]))[0]
     gt_boxes[:, 6] += noise_rotation
@@ -81,10 +84,32 @@ def global_scaling(gt_boxes, points, scale_range, return_scale=False):
     noise_scale = np.random.uniform(scale_range[0], scale_range[1])
     points[:, :3] *= noise_scale
     gt_boxes[:, :6] *= noise_scale
+    if gt_boxes.shape[1] > 7:
+        gt_boxes[:, 7:] *= noise_scale
+
     if return_scale:
         return gt_boxes, points, noise_scale
     return gt_boxes, points
 
+
+def global_scaling_with_roi_boxes(gt_boxes, roi_boxes, points, scale_range, return_scale=False):
+    """
+    Args:
+        gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading]
+        roi_boxes: (num_frames, num_rois, 7 + C)
+        points: (M, 3 + C)
+        scale_range: [min, max]
+    Returns:
+    """
+    if scale_range[1] - scale_range[0] < 1e-3:
+        return gt_boxes, roi_boxes, points
+    noise_scale = np.random.uniform(scale_range[0], scale_range[1])
+    points[:, :3] *= noise_scale
+    gt_boxes[:, :6] *= noise_scale
+    roi_boxes[:, :, [0, 1, 2, 3, 4, 5, 7, 8]] *= noise_scale
+    if return_scale:
+        return gt_boxes, roi_boxes, points, noise_scale
+    return gt_boxes, roi_boxes, points
+
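The new `enable` / `noise_rotation` arguments exist so the random decision sampled while augmenting `gt_boxes` can be replayed verbatim on the cached `roi_boxes`; this mirrors the call sites added to `data_augmentor.py` below. A minimal sketch of the pattern, with dummy arrays standing in for real data:

```python
import numpy as np
from pcdet.datasets.augmentor import augmentor_utils

gt_boxes = np.random.randn(4, 9).astype(np.float32)    # dummy (N, 9) boxes
roi_boxes = np.random.randn(12, 9).astype(np.float32)  # dummy flattened ROI boxes
points = np.random.randn(100, 4).astype(np.float32)    # dummy point cloud

# First call samples the flip decision internally and returns it ...
gt_boxes, points, enable = augmentor_utils.random_flip_along_x(
    gt_boxes, points, return_flip=True
)
# ... second call replays the same decision on the ROI boxes, passing a
# throwaway point array because only the boxes matter here.
roi_boxes, _, _ = augmentor_utils.random_flip_along_x(
    roi_boxes, np.zeros([1, 3]), return_flip=True, enable=enable
)
```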
def random_image_flip_horizontal(image, depth_map, gt_boxes, calib): """ diff --git a/pcdet/datasets/augmentor/data_augmentor.py b/pcdet/datasets/augmentor/data_augmentor.py index ba8be5e83..ac8ed5655 100644 --- a/pcdet/datasets/augmentor/data_augmentor.py +++ b/pcdet/datasets/augmentor/data_augmentor.py @@ -50,6 +50,12 @@ def random_world_flip(self, data_dict=None, config=None): gt_boxes, points, return_flip=True ) data_dict['flip_%s'%cur_axis] = enable + if 'roi_boxes' in data_dict.keys(): + num_frame, num_rois,dim = data_dict['roi_boxes'].shape + roi_boxes, _, _ = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)( + data_dict['roi_boxes'].reshape(-1,dim), np.zeros([1,3]), return_flip=True, enable=enable + ) + data_dict['roi_boxes'] = roi_boxes.reshape(num_frame, num_rois,dim) data_dict['gt_boxes'] = gt_boxes data_dict['points'] = points @@ -64,6 +70,11 @@ def random_world_rotation(self, data_dict=None, config=None): gt_boxes, points, noise_rot = augmentor_utils.global_rotation( data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range, return_rot=True ) + if 'roi_boxes' in data_dict.keys(): + num_frame, num_rois,dim = data_dict['roi_boxes'].shape + roi_boxes, _, _ = augmentor_utils.global_rotation( + data_dict['roi_boxes'].reshape(-1, dim), np.zeros([1, 3]), rot_range=rot_range, return_rot=True, noise_rotation=noise_rot) + data_dict['roi_boxes'] = roi_boxes.reshape(num_frame, num_rois,dim) data_dict['gt_boxes'] = gt_boxes data_dict['points'] = points @@ -73,9 +84,16 @@ def random_world_rotation(self, data_dict=None, config=None): def random_world_scaling(self, data_dict=None, config=None): if data_dict is None: return partial(self.random_world_scaling, config=config) - gt_boxes, points, noise_scale = augmentor_utils.global_scaling( - data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True - ) + + if 'roi_boxes' in data_dict.keys(): + gt_boxes, roi_boxes, points, noise_scale = augmentor_utils.global_scaling_with_roi_boxes( + data_dict['gt_boxes'], data_dict['roi_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True + ) + data_dict['roi_boxes'] = roi_boxes + else: + gt_boxes, points, noise_scale = augmentor_utils.global_scaling( + data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True + ) data_dict['gt_boxes'] = gt_boxes data_dict['points'] = points @@ -115,6 +133,10 @@ def random_world_translation(self, data_dict=None, config=None): gt_boxes, points = data_dict['gt_boxes'], data_dict['points'] points[:, :3] += noise_translate gt_boxes[:, :3] += noise_translate + + if 'roi_boxes' in data_dict.keys(): + data_dict['roi_boxes'][:, :3] += noise_translate + data_dict['gt_boxes'] = gt_boxes data_dict['points'] = points return data_dict diff --git a/pcdet/datasets/augmentor/database_sampler.py b/pcdet/datasets/augmentor/database_sampler.py index b610db76c..105708a60 100644 --- a/pcdet/datasets/augmentor/database_sampler.py +++ b/pcdet/datasets/augmentor/database_sampler.py @@ -36,7 +36,7 @@ def __init__(self, root_path, sampler_cfg, class_names, logger=None): sampler_cfg.DB_DATA_PATH[0] = sampler_cfg.BACKUP_DB_INFO['DB_DATA_PATH'] db_info_path = self.root_path.resolve() / sampler_cfg.DB_INFO_PATH[0] sampler_cfg.NUM_POINT_FEATURES = sampler_cfg.BACKUP_DB_INFO['NUM_POINT_FEATURES'] - + with open(str(db_info_path), 'rb') as f: infos = pickle.load(f) [self.db_infos[cur_class].extend(infos[cur_class]) for cur_class in class_names] @@ -391,10 +391,14 @@ def add_sampled_boxes_to_scene(self, 
data_dict, sampled_gt_boxes, total_valid_sa obj_points = copy.deepcopy(gt_database_data[start_offset:end_offset]) else: file_path = self.root_path / info['path'] + obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape( [-1, self.sampler_cfg.NUM_POINT_FEATURES]) + if obj_points.shape[0] != info['num_points_in_gt']: + obj_points = np.fromfile(str(file_path), dtype=np.float64).reshape(-1, self.sampler_cfg.NUM_POINT_FEATURES) - obj_points[:, :3] += info['box3d_lidar'][:3] + assert obj_points.shape[0] == info['num_points_in_gt'] + obj_points[:, :3] += info['box3d_lidar'][:3].astype(np.float32) if self.sampler_cfg.get('USE_ROAD_PLANE', False): # mv height @@ -417,7 +421,7 @@ def add_sampled_boxes_to_scene(self, data_dict, sampled_gt_boxes, total_valid_sa else: assert obj_points.shape[-1] == points.shape[-1] + 1 # transform multi-frame GT points to single-frame GT points - min_time = max_time = 0.0 + min_time = max_time = 0.0 time_mask = np.logical_and(obj_points[:, -1] < max_time + 1e-6, obj_points[:, -1] > min_time - 1e-6) obj_points = obj_points[time_mask] diff --git a/pcdet/datasets/dataset.py b/pcdet/datasets/dataset.py index cd7009d33..6e802b65e 100644 --- a/pcdet/datasets/dataset.py +++ b/pcdet/datasets/dataset.py @@ -72,7 +72,7 @@ def generate_prediction_dicts(self, batch_dict, pred_dicts, class_names, output_ Returns: """ - + def get_template_prediction(num_samples): box_dim = 9 if self.dataset_cfg.get('TRAIN_WITH_SPEED', False) else 7 ret_dict = { @@ -216,6 +216,21 @@ def collate_batch(batch_list, _unused=False): for k in range(batch_size): batch_gt_boxes3d[k, :val[k].__len__(), :] = val[k] ret[key] = batch_gt_boxes3d + + elif key in ['roi_boxes']: + max_gt = max([x.shape[1] for x in val]) + batch_gt_boxes3d = np.zeros((batch_size, val[0].shape[0], max_gt, val[0].shape[-1]), dtype=np.float32) + for k in range(batch_size): + batch_gt_boxes3d[k,:, :val[k].shape[1], :] = val[k] + ret[key] = batch_gt_boxes3d + + elif key in ['roi_scores', 'roi_labels']: + max_gt = max([x.shape[1] for x in val]) + batch_gt_boxes3d = np.zeros((batch_size, val[0].shape[0], max_gt), dtype=np.float32) + for k in range(batch_size): + batch_gt_boxes3d[k,:, :val[k].shape[1]] = val[k] + ret[key] = batch_gt_boxes3d + elif key in ['gt_boxes2d']: max_boxes = 0 max_boxes = max([len(x) for x in val]) diff --git a/pcdet/datasets/waymo/waymo_dataset.py b/pcdet/datasets/waymo/waymo_dataset.py index f8d6b449f..c59c70e09 100644 --- a/pcdet/datasets/waymo/waymo_dataset.py +++ b/pcdet/datasets/waymo/waymo_dataset.py @@ -1,7 +1,7 @@ # OpenPCDet PyTorch Dataloader and Evaluation Tools for Waymo Open Dataset # Reference https://github.com/open-mmlab/OpenPCDet # Written by Shaoshuai Shi, Chaoxu Guo -# All Rights Reserved 2019-2020. +# All Rights Reserved. 
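The `roi_boxes` / `roi_scores` / `roi_labels` branches added to `collate_batch` above zero-pad the per-sample ROI tensors, whose ROI counts differ across the batch, up to the largest count. A self-contained sketch of the `roi_boxes` branch with made-up shapes:

```python
import numpy as np

# Two samples, each with ROIs for 4 frames but different ROI counts (3 vs. 5).
val = [np.ones((4, 3, 9), dtype=np.float32),
       np.ones((4, 5, 9), dtype=np.float32)]
batch_size = len(val)

max_gt = max([x.shape[1] for x in val])
batch_roi_boxes = np.zeros((batch_size, val[0].shape[0], max_gt, val[0].shape[-1]), dtype=np.float32)
for k in range(batch_size):
    batch_roi_boxes[k, :, :val[k].shape[1], :] = val[k]

assert batch_roi_boxes.shape == (2, 4, 5, 9)  # sample 0 zero-padded from 3 to 5 ROIs
```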
 import os
 import pickle
 
@@ -38,6 +38,13 @@ def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None):
         self.shared_memory_file_limit = self.dataset_cfg.get('SHARED_MEMORY_FILE_LIMIT', 0x7FFFFFFF)
         self.load_data_to_shared_memory()
 
+        if self.dataset_cfg.get('USE_PREDBOX', False):
+            self.pred_boxes_dict = self.load_pred_boxes_to_dict(
+                pred_boxes_path=self.dataset_cfg.ROI_BOXES_PATH[self.mode]
+            )
+        else:
+            self.pred_boxes_dict = {}
+
     def set_split(self, split):
         super().__init__(
             dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training,
@@ -84,6 +91,26 @@ def include_waymo_data(self, mode):
             seq_name_to_infos = None
         return seq_name_to_infos
 
+    def load_pred_boxes_to_dict(self, pred_boxes_path):
+        self.logger.info(f'Loading and reorganizing pred_boxes to dict from path: {pred_boxes_path}')
+        with open(pred_boxes_path, 'rb') as f:
+            pred_dicts = pickle.load(f)
+
+        pred_boxes_dict = {}
+        for index, box_dict in enumerate(pred_dicts):
+            seq_name = box_dict['frame_id'][:-4].replace('training_', '').replace('validation_', '')
+            sample_idx = int(box_dict['frame_id'][-3:])
+
+            if seq_name not in pred_boxes_dict:
+                pred_boxes_dict[seq_name] = {}
+
+            pred_labels = np.array([self.class_names.index(box_dict['name'][k]) + 1 for k in range(box_dict['name'].shape[0])])
+            pred_boxes = np.concatenate((box_dict['boxes_lidar'], box_dict['score'][:, np.newaxis], pred_labels[:, np.newaxis]), axis=-1)
+            pred_boxes_dict[seq_name][sample_idx] = pred_boxes
+
+        self.logger.info(f'Predicted boxes have been loaded, total sequences: {len(pred_boxes_dict)}')
+        return pred_boxes_dict
+
     def load_data_to_shared_memory(self):
         self.logger.info(f'Loading training data to shared memory (file limit={self.shared_memory_file_limit})')
 
@@ -176,7 +203,47 @@ def get_lidar(self, sequence_name, sample_idx):
             points_all[:, 3] = np.tanh(points_all[:, 3])
         return points_all
 
-    def get_sequence_data(self, info, points, sequence_name, sample_idx, sequence_cfg):
+    @staticmethod
+    def transform_prebox_to_current(pred_boxes3d, pose_pre, pose_cur):
+        """
+
+        Args:
+            pred_boxes3d (N, 9 or 11): [x, y, z, dx, dy, dz, heading, (vx, vy,) score, label]
+            pose_pre (4, 4):
+            pose_cur (4, 4):
+        Returns:
+
+        """
+        assert pred_boxes3d.shape[-1] in [9, 11]
+        pred_boxes3d = pred_boxes3d.copy()
+        expand_bboxes = np.concatenate([pred_boxes3d[:, :3], np.ones((pred_boxes3d.shape[0], 1))], axis=-1)
+
+        bboxes_global = np.dot(expand_bboxes, pose_pre.T)[:, :3]
+        expand_bboxes_global = np.concatenate([bboxes_global[:, :3], np.ones((bboxes_global.shape[0], 1))], axis=-1)
+        bboxes_pre2cur = np.dot(expand_bboxes_global, np.linalg.inv(pose_cur.T))[:, :3]
+        pred_boxes3d[:, 0:3] = bboxes_pre2cur
+
+        if pred_boxes3d.shape[-1] == 11:
+            expand_vels = np.concatenate([pred_boxes3d[:, 7:9], np.zeros((pred_boxes3d.shape[0], 1))], axis=-1)
+            vels_global = np.dot(expand_vels, pose_pre[:3, :3].T)
+            vels_pre2cur = np.dot(vels_global, np.linalg.inv(pose_cur[:3, :3].T))[:, :2]
+            pred_boxes3d[:, 7:9] = vels_pre2cur
+
+        pred_boxes3d[:, 6] = pred_boxes3d[..., 6] + np.arctan2(pose_pre[..., 1, 0], pose_pre[..., 0, 0])
+        pred_boxes3d[:, 6] = pred_boxes3d[..., 6] - np.arctan2(pose_cur[..., 1, 0], pose_cur[..., 0, 0])
+        return pred_boxes3d
+
+    @staticmethod
+    def reorder_rois_for_refining(pred_bboxes):
+        num_max_rois = max([len(bbox) for bbox in pred_bboxes])
+        num_max_rois = max(1, num_max_rois)  # keep at least one fake roi to avoid errors
+        ordered_bboxes = np.zeros([len(pred_bboxes), num_max_rois, pred_bboxes[0].shape[-1]], dtype=np.float32)
+
+        for bs_idx in range(ordered_bboxes.shape[0]):
+            ordered_bboxes[bs_idx, :len(pred_bboxes[bs_idx])] = pred_bboxes[bs_idx]
+        return ordered_bboxes
+
+    def get_sequence_data(self, info, points, sequence_name, sample_idx, sequence_cfg, load_pred_boxes=False):
         """
         Args:
             info:
@@ -191,10 +258,21 @@ def remove_ego_points(points, center_radius=1.0):
             mask = ~((np.abs(points[:, 0]) < center_radius) & (np.abs(points[:, 1]) < center_radius))
             return points[mask]
 
+        def load_pred_boxes_from_dict(sequence_name, sample_idx):
+            """
+            boxes: (N, 11)  [x, y, z, dx, dy, dz, heading, vx, vy, score, label]
+            """
+            sequence_name = sequence_name.replace('training_', '').replace('validation_', '')
+            load_boxes = self.pred_boxes_dict[sequence_name][sample_idx]
+            assert load_boxes.shape[-1] == 11
+            load_boxes[:, 7:9] = -0.1 * load_boxes[:, 7:9]  # convert speed to negative motion from t to t-1
+            return load_boxes
+
         pose_cur = info['pose'].reshape((4, 4))
         num_pts_cur = points.shape[0]
-        sample_idx_pre_list = np.clip(sample_idx + np.arange(
-            sequence_cfg.SAMPLE_OFFSET[0], sequence_cfg.SAMPLE_OFFSET[1]), 0, 0x7FFFFFFF)
+        sample_idx_pre_list = np.clip(sample_idx + np.arange(sequence_cfg.SAMPLE_OFFSET[0], sequence_cfg.SAMPLE_OFFSET[1]), 0, 0x7FFFFFFF)
+        sample_idx_pre_list = sample_idx_pre_list[::-1]
+
         if sequence_cfg.get('ONEHOT_TIMESTAMP', False):
             onehot_cur = np.zeros((points.shape[0], len(sample_idx_pre_list) + 1)).astype(points.dtype)
             onehot_cur[:, 0] = 1
@@ -204,34 +282,54 @@ def remove_ego_points(points, center_radius=1.0):
 
         points_pre_all = []
         num_points_pre = []
+        pose_all = [pose_cur]
+        pred_boxes_all = []
+        if load_pred_boxes:
+            pred_boxes = load_pred_boxes_from_dict(sequence_name, sample_idx)
+            pred_boxes_all.append(pred_boxes)
+
         sequence_info = self.seq_name_to_infos[sequence_name]
 
-        for i, sample_idx_pre in enumerate(sample_idx_pre_list):
-            if sample_idx == sample_idx_pre:
-                continue
+        for idx, sample_idx_pre in enumerate(sample_idx_pre_list):
 
             points_pre = self.get_lidar(sequence_name, sample_idx_pre)
             pose_pre = sequence_info[sample_idx_pre]['pose'].reshape((4, 4))
             expand_points_pre = np.concatenate([points_pre[:, :3], np.ones((points_pre.shape[0], 1))], axis=-1)
             points_pre_global = np.dot(expand_points_pre, pose_pre.T)[:, :3]
-            expand_points_pre_global = np.concatenate([points_pre_global,
-                np.ones((points_pre_global.shape[0], 1))], axis=-1)
+            expand_points_pre_global = np.concatenate([points_pre_global, np.ones((points_pre_global.shape[0], 1))], axis=-1)
             points_pre2cur = np.dot(expand_points_pre_global, np.linalg.inv(pose_cur.T))[:, :3]
             points_pre = np.concatenate([points_pre2cur, points_pre[:, 3:]], axis=-1)
             if sequence_cfg.get('ONEHOT_TIMESTAMP', False):
                 onehot_vector = np.zeros((points_pre.shape[0], len(sample_idx_pre_list) + 1))
-                onehot_vector[:, i + 1] = 1
+                onehot_vector[:, idx + 1] = 1
                 points_pre = np.hstack([points_pre, onehot_vector])
             else:
                 # add timestamp
-                points_pre = np.hstack([points_pre, 0.1 * (sample_idx - sample_idx_pre)
-                    * np.ones((points_pre.shape[0], 1)).astype(points_pre.dtype)])  # one frame 0.1s
+                points_pre = np.hstack([points_pre, 0.1 * (sample_idx - sample_idx_pre) * np.ones((points_pre.shape[0], 1)).astype(points_pre.dtype)])  # one frame 0.1s
             points_pre = remove_ego_points(points_pre, 1.0)
             points_pre_all.append(points_pre)
             num_points_pre.append(points_pre.shape[0])
+            pose_all.append(pose_pre)
+
+            if load_pred_boxes:
+                pose_pre = sequence_info[sample_idx_pre]['pose'].reshape((4, 4))
+                pred_boxes = load_pred_boxes_from_dict(sequence_name, sample_idx_pre)
+                pred_boxes = 
self.transform_prebox_to_current(pred_boxes, pose_pre, pose_cur) + pred_boxes_all.append(pred_boxes) + points = np.concatenate([points] + points_pre_all, axis=0).astype(np.float32) num_points_all = np.array([num_pts_cur] + num_points_pre).astype(np.int32) - return points, num_points_all, sample_idx_pre_list + poses = np.concatenate(pose_all, axis=0).astype(np.float32) + + if load_pred_boxes: + temp_pred_boxes = self.reorder_rois_for_refining(pred_boxes_all) + pred_boxes = temp_pred_boxes[:, :, 0:9] + pred_scores = temp_pred_boxes[:, :, 9] + pred_labels = temp_pred_boxes[:, :, 10] + else: + pred_boxes = pred_scores = pred_labels = None + + return points, num_points_all, sample_idx_pre_list, poses, pred_boxes, pred_scores, pred_labels def __len__(self): if self._merge_all_iters_to_one_epoch: @@ -247,7 +345,9 @@ def __getitem__(self, index): pc_info = info['point_cloud'] sequence_name = pc_info['lidar_sequence'] sample_idx = pc_info['sample_idx'] - + input_dict = { + 'sample_idx': sample_idx + } if self.use_shared_memory and index < self.shared_memory_file_limit: sa_key = f'{sequence_name}___{sample_idx}' points = SharedArray.attach(f"shm://{sa_key}").copy() @@ -255,14 +355,22 @@ def __getitem__(self, index): points = self.get_lidar(sequence_name, sample_idx) if self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED: - points, num_points_all, sample_idx_pre_list = self.get_sequence_data( - info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG + points, num_points_all, sample_idx_pre_list, poses, pred_boxes, pred_scores, pred_labels = self.get_sequence_data( + info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG, + load_pred_boxes=self.dataset_cfg.get('USE_PREDBOX', False) ) - - input_dict = { + input_dict['poses'] = poses + if self.dataset_cfg.get('USE_PREDBOX', False): + input_dict.update({ + 'roi_boxes': pred_boxes, + 'roi_scores': pred_scores, + 'roi_labels': pred_labels, + }) + + input_dict.update({ 'points': points, 'frame_id': info['frame_id'], - } + }) if 'annos' in info: annos = info['annos'] @@ -448,7 +556,7 @@ def create_groundtruth_database(self, info_path, save_path, used_classes=None, s stacked_gt_points = np.concatenate(stacked_gt_points, axis=0) np.save(db_data_save_path, stacked_gt_points) - def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None, + def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None, total_samples=0, use_cuda=False, crop_gt_with_tail=False): info, info_idx = info_with_idx print('gt_database sample: %d/%d' % (info_idx, total_samples)) @@ -484,7 +592,7 @@ def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=N num_obj = gt_boxes.shape[0] if num_obj == 0: return {} - + if use_sequence_data and crop_gt_with_tail: assert gt_boxes.shape[1] == 9 speed = gt_boxes[:, 7:9] @@ -496,11 +604,11 @@ def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=N latest_center = gt_boxes[:, 0:2] oldest_center = latest_center - speed * (num_frames - 1) * 0.1 new_center = (latest_center + oldest_center) * 0.5 - new_length = gt_boxes[:, 3] + np.linalg.norm(latest_center - oldest_center, axis=-1) + new_length = gt_boxes[:, 3] + np.linalg.norm(latest_center - oldest_center, axis=-1) gt_boxes_crop = gt_boxes.copy() gt_boxes_crop[:, 0:2] = new_center - gt_boxes_crop[:, 3] = new_length - + gt_boxes_crop[:, 3] 
= new_length + else: gt_boxes_crop = gt_boxes @@ -534,7 +642,7 @@ def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=N db_path = str(filepath.relative_to(self.root_path)) # gt_database/xxxxx.bin db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name, 'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i], - 'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i], + 'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i], 'box3d_crop': gt_boxes_crop[i]} if names[i] in all_db_infos: @@ -543,7 +651,7 @@ def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=N all_db_infos[names[i]] = [db_info] return all_db_infos - def create_groundtruth_database_parallel(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10, + def create_groundtruth_database_parallel(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10, processed_data_tag=None, num_workers=16, crop_gt_with_tail=False): use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED if use_sequence_data: @@ -565,7 +673,7 @@ def create_groundtruth_database_parallel(self, info_path, save_path, used_classe create_gt_database_of_single_scene = partial( self.create_gt_database_of_single_scene, use_sequence_data=use_sequence_data, database_save_path=database_save_path, - used_classes=used_classes, total_samples=len(infos), use_cuda=False, + used_classes=used_classes, total_samples=len(infos), use_cuda=False, crop_gt_with_tail=crop_gt_with_tail ) # create_gt_database_of_single_scene((infos[300], 0)) diff --git a/pcdet/datasets/waymo/waymo_eval.py b/pcdet/datasets/waymo/waymo_eval.py index 84f0ae93e..20c32823b 100644 --- a/pcdet/datasets/waymo/waymo_eval.py +++ b/pcdet/datasets/waymo/waymo_eval.py @@ -68,7 +68,7 @@ def boxes3d_kitti_fakelidar_to_lidar(boxes3d_lidar): num_boxes = len(info['boxes_lidar']) difficulty.append([0] * num_boxes) score.append(info['score']) - boxes3d.append(np.array(info['boxes_lidar'])) + boxes3d.append(np.array(info['boxes_lidar'][:, :7])) box_name = info['name'] if boxes3d[-1].shape[-1] == 9: boxes3d[-1] = boxes3d[-1][:, 0:7] diff --git a/pcdet/models/backbones_3d/__init__.py b/pcdet/models/backbones_3d/__init__.py index 7beaddd38..f58b4f9cc 100644 --- a/pcdet/models/backbones_3d/__init__.py +++ b/pcdet/models/backbones_3d/__init__.py @@ -9,5 +9,5 @@ 'PointNet2Backbone': PointNet2Backbone, 'PointNet2MSG': PointNet2MSG, 'VoxelResBackBone8x': VoxelResBackBone8x, - 'VoxelBackBone8xFocal': VoxelBackBone8xFocal, + 'VoxelBackBone8xFocal': VoxelBackBone8xFocal } diff --git a/pcdet/models/detectors/__init__.py b/pcdet/models/detectors/__init__.py index 09b24f35a..2e9d949a6 100644 --- a/pcdet/models/detectors/__init__.py +++ b/pcdet/models/detectors/__init__.py @@ -9,6 +9,8 @@ from .voxel_rcnn import VoxelRCNN from .centerpoint import CenterPoint from .pv_rcnn_plusplus import PVRCNNPlusPlus +from .mppnet import MPPNet +from .mppnet_e2e import MPPNetE2E __all__ = { 'Detector3DTemplate': Detector3DTemplate, @@ -21,7 +23,9 @@ 'CaDDN': CaDDN, 'VoxelRCNN': VoxelRCNN, 'CenterPoint': CenterPoint, - 'PVRCNNPlusPlus': PVRCNNPlusPlus + 'PVRCNNPlusPlus': PVRCNNPlusPlus, + 'MPPNet': MPPNet, + 'MPPNetE2E': MPPNetE2E } diff --git a/pcdet/models/detectors/detector3d_template.py b/pcdet/models/detectors/detector3d_template.py index 456b1dee5..862ceeb19 100644 --- a/pcdet/models/detectors/detector3d_template.py +++ 
b/pcdet/models/detectors/detector3d_template.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn - +import numpy as np from ...ops.iou3d_nms import iou3d_nms_utils from ...utils.spconv_utils import find_all_spconv_keys from .. import backbones_2d, backbones_3d, dense_heads, roi_heads @@ -163,7 +163,7 @@ def build_roi_head(self, model_info_dict): point_head_module = roi_heads.__all__[self.model_cfg.ROI_HEAD.NAME]( model_cfg=self.model_cfg.ROI_HEAD, input_channels=model_info_dict['num_point_features'], - backbone_channels=model_info_dict['backbone_channels'], + backbone_channels= model_info_dict.get('backbone_channels', None), point_cloud_range=model_info_dict['point_cloud_range'], voxel_size=model_info_dict['voxel_size'], num_class=self.num_class if not self.model_cfg.ROI_HEAD.CLASS_AGNOSTIC else 1, @@ -206,7 +206,7 @@ def post_processing(self, batch_dict): box_preds = batch_dict['batch_box_preds'][batch_mask] src_box_preds = box_preds - + if not isinstance(batch_dict['batch_cls_preds'], list): cls_preds = batch_dict['batch_cls_preds'][batch_mask] @@ -253,7 +253,7 @@ def post_processing(self, batch_dict): label_key = 'roi_labels' if 'roi_labels' in batch_dict else 'batch_pred_labels' label_preds = batch_dict[label_key][index] else: - label_preds = label_preds + 1 + label_preds = label_preds + 1 selected, selected_scores = model_nms_utils.class_agnostic_nms( box_scores=cls_preds, box_preds=box_preds, nms_config=post_process_cfg.NMS_CONFIG, @@ -267,12 +267,12 @@ def post_processing(self, batch_dict): final_scores = selected_scores final_labels = label_preds[selected] final_boxes = box_preds[selected] - + recall_dict = self.generate_recall_record( box_preds=final_boxes if 'rois' not in batch_dict else src_box_preds, recall_dict=recall_dict, batch_index=index, data_dict=batch_dict, thresh_list=post_process_cfg.RECALL_THRESH_LIST - ) + ) record_dict = { 'pred_boxes': final_boxes, @@ -358,7 +358,7 @@ def _load_state_dict(self, model_state_disk, *, strict=True): self.load_state_dict(state_dict) return state_dict, update_model_state - def load_params_from_file(self, filename, logger, to_cpu=False): + def load_params_from_file(self, filename, logger, to_cpu=False, pre_trained_path=None): if not os.path.isfile(filename): raise FileNotFoundError @@ -366,7 +366,11 @@ def load_params_from_file(self, filename, logger, to_cpu=False): loc_type = torch.device('cpu') if to_cpu else None checkpoint = torch.load(filename, map_location=loc_type) model_state_disk = checkpoint['model_state'] - + if not pre_trained_path is None: + pretrain_checkpoint = torch.load(pre_trained_path, map_location=loc_type) + pretrain_model_state_disk = pretrain_checkpoint['model_state'] + model_state_disk.update(pretrain_model_state_disk) + version = checkpoint.get("version", None) if version is not None: logger.info('==> Checkpoint trained from version: %s' % version) diff --git a/pcdet/models/detectors/mppnet.py b/pcdet/models/detectors/mppnet.py new file mode 100644 index 000000000..10eeb6873 --- /dev/null +++ b/pcdet/models/detectors/mppnet.py @@ -0,0 +1,181 @@ +import torch +from .detector3d_template import Detector3DTemplate +from pcdet.ops.iou3d_nms import iou3d_nms_utils +import os +import numpy as np +import time +from ...utils import common_utils +from ..model_utils import model_nms_utils +from pcdet.datasets.augmentor import augmentor_utils, database_sampler + + +class MPPNet(Detector3DTemplate): + def __init__(self, model_cfg, num_class, dataset): + super().__init__(model_cfg=model_cfg, num_class=num_class, 
dataset=dataset) + self.module_list = self.build_networks() + + def forward(self, batch_dict): + batch_dict['proposals_list'] = batch_dict['roi_boxes'] + for cur_module in self.module_list[:]: + batch_dict = cur_module(batch_dict) + + if self.training: + loss, tb_dict, disp_dict = self.get_training_loss() + + ret_dict = { + 'loss': loss + } + + return ret_dict, tb_dict, disp_dict + else: + + pred_dicts, recall_dicts = self.post_processing(batch_dict) + + return pred_dicts, recall_dicts + + def get_training_loss(self): + disp_dict = {} + tb_dict ={} + loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict) + loss = loss_rcnn + + return loss, tb_dict, disp_dict + + def post_processing(self, batch_dict): + """ + Args: + batch_dict: + batch_size: + batch_cls_preds: (B, num_boxes, num_classes | 1) or (N1+N2+..., num_classes | 1) + or [(B, num_boxes, num_class1), (B, num_boxes, num_class2) ...] + multihead_label_mapping: [(num_class1), (num_class2), ...] + batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C) + cls_preds_normalized: indicate whether batch_cls_preds is normalized + batch_index: optional (N1+N2+...) + has_class_labels: True/False + roi_labels: (B, num_rois) 1 .. num_classes + batch_pred_labels: (B, num_boxes, 1) + Returns: + + """ + post_process_cfg = self.model_cfg.POST_PROCESSING + batch_size = batch_dict['batch_size'] + recall_dict = {} + pred_dicts = [] + for index in range(batch_size): + if batch_dict.get('batch_index', None) is not None: + assert batch_dict['batch_box_preds'].shape.__len__() == 2 + batch_mask = (batch_dict['batch_index'] == index) + else: + assert batch_dict['batch_box_preds'].shape.__len__() == 3 + batch_mask = index + + box_preds = batch_dict['batch_box_preds'][batch_mask] + src_box_preds = box_preds + if not isinstance(batch_dict['batch_cls_preds'], list): + cls_preds = batch_dict['batch_cls_preds'][batch_mask] + + src_cls_preds = cls_preds + assert cls_preds.shape[1] in [1, self.num_class] + + if not batch_dict['cls_preds_normalized']: + cls_preds = torch.sigmoid(cls_preds) + else: + cls_preds = [x[batch_mask] for x in batch_dict['batch_cls_preds']] + src_cls_preds = cls_preds + if not batch_dict['cls_preds_normalized']: + cls_preds = [torch.sigmoid(x) for x in cls_preds] + + if post_process_cfg.NMS_CONFIG.MULTI_CLASSES_NMS: + if not isinstance(cls_preds, list): + cls_preds = [cls_preds] + multihead_label_mapping = [torch.arange(1, self.num_class, device=cls_preds[0].device)] + else: + multihead_label_mapping = batch_dict['multihead_label_mapping'] + + cur_start_idx = 0 + pred_scores, pred_labels, pred_boxes = [], [], [] + for cur_cls_preds, cur_label_mapping in zip(cls_preds, multihead_label_mapping): + assert cur_cls_preds.shape[1] == len(cur_label_mapping) + cur_box_preds = box_preds[cur_start_idx: cur_start_idx + cur_cls_preds.shape[0]] + cur_pred_scores, cur_pred_labels, cur_pred_boxes = model_nms_utils.multi_classes_nms( + cls_scores=cur_cls_preds, box_preds=cur_box_preds, + nms_config=post_process_cfg.NMS_CONFIG, + score_thresh=post_process_cfg.SCORE_THRESH + ) + cur_pred_labels = cur_label_mapping[cur_pred_labels] + pred_scores.append(cur_pred_scores) + pred_labels.append(cur_pred_labels) + pred_boxes.append(cur_pred_boxes) + cur_start_idx += cur_cls_preds.shape[0] + + final_scores = torch.cat(pred_scores, dim=0) + final_labels = torch.cat(pred_labels, dim=0) + final_boxes = torch.cat(pred_boxes, dim=0) + else: + try: + cls_preds, label_preds = torch.max(cls_preds, dim=-1) + except: + record_dict = { + 'pred_boxes': torch.tensor([]), + 
'pred_scores': torch.tensor([]),
+                        'pred_labels': torch.tensor([])
+                    }
+                    pred_dicts.append(record_dict)
+                    continue
+
+                if batch_dict.get('has_class_labels', False):
+                    label_key = 'roi_labels' if 'roi_labels' in batch_dict else 'batch_pred_labels'
+                    label_preds = batch_dict[label_key][index]
+                else:
+                    label_preds = label_preds + 1
+
+                selected, selected_scores = model_nms_utils.class_agnostic_nms(
+                    box_scores=cls_preds, box_preds=box_preds,
+                    nms_config=post_process_cfg.NMS_CONFIG,
+                    score_thresh=post_process_cfg.SCORE_THRESH
+                )
+
+                if post_process_cfg.OUTPUT_RAW_SCORE:
+                    max_cls_preds, _ = torch.max(src_cls_preds, dim=-1)
+                    selected_scores = max_cls_preds[selected]
+
+                final_scores = selected_scores
+                final_labels = label_preds[selected]
+                final_boxes = box_preds[selected]
+
+            ######### Cars do NOT use NMS ######
+            if post_process_cfg.get('NOT_APPLY_NMS_FOR_VEL', False):
+
+                pedcyc_mask = final_labels != 1
+                final_scores_pedcyc = final_scores[pedcyc_mask]
+                final_labels_pedcyc = final_labels[pedcyc_mask]
+                final_boxes_pedcyc = final_boxes[pedcyc_mask]
+
+                car_mask = (label_preds == 1) & (cls_preds > post_process_cfg.SCORE_THRESH)
+                final_scores_car = cls_preds[car_mask]
+                final_labels_car = label_preds[car_mask]
+                final_boxes_car = box_preds[car_mask]
+
+                final_scores = torch.cat([final_scores_car, final_scores_pedcyc], 0)
+                final_labels = torch.cat([final_labels_car, final_labels_pedcyc], 0)
+                final_boxes = torch.cat([final_boxes_car, final_boxes_pedcyc], 0)
+
+            ######### Cars do NOT use NMS ######
+
+            recall_dict = self.generate_recall_record(
+                box_preds=final_boxes if 'rois' not in batch_dict else src_box_preds,
+                recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
+                thresh_list=post_process_cfg.RECALL_THRESH_LIST
+            )
+
+
+            record_dict = {
+                'pred_boxes': final_boxes[:, :7],
+                'pred_scores': final_scores,
+                'pred_labels': final_labels
+            }
+            pred_dicts.append(record_dict)
+
+        return pred_dicts, recall_dict
+
diff --git a/pcdet/models/detectors/mppnet_e2e.py b/pcdet/models/detectors/mppnet_e2e.py
new file mode 100644
index 000000000..7561c9b9a
--- /dev/null
+++ b/pcdet/models/detectors/mppnet_e2e.py
@@ -0,0 +1,222 @@
+import torch
+import os
+import numpy as np
+import copy
+from ...utils import common_utils
+from ..model_utils import model_nms_utils
+from .detector3d_template import Detector3DTemplate
+from pcdet.ops.iou3d_nms import iou3d_nms_utils
+from pcdet.datasets.augmentor import augmentor_utils, database_sampler
+
+
+class MPPNetE2E(Detector3DTemplate):
+    def __init__(self, model_cfg, num_class, dataset):
+        super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
+        self.module_list = self.build_networks()
+
+        self.module_topology = [
+            'vfe', 'backbone_3d', 'map_to_bev_module',
+            'backbone_2d', 'dense_head', 'roi_head'
+        ]
+
+        self.num_frames = self.model_cfg.ROI_HEAD.Transformer.num_frames
+
+    def reset_memorybank(self):
+        self.memory_rois = None
+        self.memory_labels = None
+        self.memory_scores = None
+        self.memory_feature = None
+
+    def forward(self, batch_dict):
+
+        if batch_dict['sample_idx'][0] == 0:
+            self.reset_memorybank()
+            batch_dict['memory_bank'] = {}
+        else:
+            batch_dict['memory_bank'] = {'feature_bank': self.memory_feature}
+
+        if self.num_frames == 16:
+            batch_dict['points_backup'] = batch_dict['points'].clone()
+            time_mask = batch_dict['points'][:, -1] < 0.31  # the CenterPoint RPN only uses 4 frames
+            batch_dict['points'] = batch_dict['points'][time_mask]
+
+        for idx, cur_module in enumerate(self.module_list):
+            batch_dict = cur_module(batch_dict)
+            
if self.module_topology[idx] == 'dense_head': + + if self.memory_rois is None: + self.memory_rois = [batch_dict['rois']]*self.num_frames + self.memory_labels = [batch_dict['roi_labels'][:,:,None]]*self.num_frames + self.memory_scores = [batch_dict['roi_scores'][:,:,None]]*self.num_frames + else: + self.memory_rois.pop() + self.memory_rois.insert(0,batch_dict['rois']) + self.memory_labels.pop() + self.memory_labels.insert(0,batch_dict['roi_labels'][:,:,None]) + self.memory_scores.pop() + self.memory_scores.insert(0,batch_dict['roi_scores'][:,:,None]) + + + batch_dict['memory_bank'].update({'rois': self.memory_rois, + 'roi_labels': self.memory_labels, + 'roi_scores': self.memory_scores}) + + + if self.module_topology[idx] == 'roi_head': + if self.memory_feature is None: + self.memory_feature = [batch_dict['geometory_feature_memory'][:,:64]]*self.num_frames + + else: + self.memory_feature.pop() + self.memory_feature.insert(0,batch_dict['geometory_feature_memory'][:,:64]) + + + if self.training: + loss, tb_dict, disp_dict = self.get_training_loss() + + ret_dict = { + 'loss': loss + } + return ret_dict, tb_dict, disp_dict + else: + pred_dicts, recall_dicts = self.post_processing(batch_dict) + + return pred_dicts, recall_dicts + + + def get_training_loss(self): + disp_dict = {} + + loss_rpn, tb_dict = self.dense_head.get_loss() + tb_dict = { + 'loss_rpn': loss_rpn.item(), + **tb_dict + } + + loss = loss_rpn + return loss, tb_dict, disp_dict + + + def post_processing(self, batch_dict): + + post_process_cfg = self.model_cfg.POST_PROCESSING + batch_size = batch_dict['batch_size'] + recall_dict = {} + pred_dicts = [] + for index in range(batch_size): + if batch_dict.get('batch_index', None) is not None: + assert batch_dict['batch_box_preds'].shape.__len__() == 2 + batch_mask = (batch_dict['batch_index'] == index) + else: + assert batch_dict['batch_box_preds'].shape.__len__() == 3 + batch_mask = index + + box_preds = batch_dict['batch_box_preds'][batch_mask] + src_box_preds = box_preds + if not isinstance(batch_dict['batch_cls_preds'], list): + cls_preds = batch_dict['batch_cls_preds'][batch_mask] + + src_cls_preds = cls_preds + assert cls_preds.shape[1] in [1, self.num_class] + + if not batch_dict['cls_preds_normalized']: + cls_preds = torch.sigmoid(cls_preds) + else: + cls_preds = [x[batch_mask] for x in batch_dict['batch_cls_preds']] + src_cls_preds = cls_preds + if not batch_dict['cls_preds_normalized']: + cls_preds = [torch.sigmoid(x) for x in cls_preds] + + if post_process_cfg.NMS_CONFIG.MULTI_CLASSES_NMS: + if not isinstance(cls_preds, list): + cls_preds = [cls_preds] + multihead_label_mapping = [torch.arange(1, self.num_class, device=cls_preds[0].device)] + else: + multihead_label_mapping = batch_dict['multihead_label_mapping'] + + cur_start_idx = 0 + pred_scores, pred_labels, pred_boxes = [], [], [] + for cur_cls_preds, cur_label_mapping in zip(cls_preds, multihead_label_mapping): + assert cur_cls_preds.shape[1] == len(cur_label_mapping) + cur_box_preds = box_preds[cur_start_idx: cur_start_idx + cur_cls_preds.shape[0]] + cur_pred_scores, cur_pred_labels, cur_pred_boxes = model_nms_utils.multi_classes_nms( + cls_scores=cur_cls_preds, box_preds=cur_box_preds, + nms_config=post_process_cfg.NMS_CONFIG, + score_thresh=post_process_cfg.SCORE_THRESH + ) + cur_pred_labels = cur_label_mapping[cur_pred_labels] + pred_scores.append(cur_pred_scores) + pred_labels.append(cur_pred_labels) + pred_boxes.append(cur_pred_boxes) + cur_start_idx += cur_cls_preds.shape[0] + + final_scores = 
torch.cat(pred_scores, dim=0)
+                final_labels = torch.cat(pred_labels, dim=0)
+                final_boxes = torch.cat(pred_boxes, dim=0)
+            else:
+                try:
+                    cls_preds, label_preds = torch.max(cls_preds, dim=-1)
+                except:
+                    record_dict = {
+                        'pred_boxes': torch.tensor([]),
+                        'pred_scores': torch.tensor([]),
+                        'pred_labels': torch.tensor([])
+                    }
+                    pred_dicts.append(record_dict)
+                    continue
+
+                if batch_dict.get('has_class_labels', False):
+                    label_key = 'roi_labels' if 'roi_labels' in batch_dict else 'batch_pred_labels'
+                    label_preds = batch_dict[label_key][index]
+                else:
+                    label_preds = label_preds + 1
+
+                selected, selected_scores = model_nms_utils.class_agnostic_nms(
+                    box_scores=cls_preds, box_preds=box_preds,
+                    nms_config=post_process_cfg.NMS_CONFIG,
+                    score_thresh=post_process_cfg.SCORE_THRESH
+                )
+
+                if post_process_cfg.OUTPUT_RAW_SCORE:
+                    max_cls_preds, _ = torch.max(src_cls_preds, dim=-1)
+                    selected_scores = max_cls_preds[selected]
+
+                final_scores = selected_scores
+                final_labels = label_preds[selected]
+                final_boxes = box_preds[selected]
+
+            ######### Cars do NOT use NMS ######
+            if post_process_cfg.get('NOT_APPLY_NMS_FOR_VEL', False):
+
+                pedcyc_mask = final_labels != 1
+                final_scores_pedcyc = final_scores[pedcyc_mask]
+                final_labels_pedcyc = final_labels[pedcyc_mask]
+                final_boxes_pedcyc = final_boxes[pedcyc_mask]
+
+                car_mask = (label_preds == 1) & (cls_preds > post_process_cfg.SCORE_THRESH)
+                final_scores_car = cls_preds[car_mask]
+                final_labels_car = label_preds[car_mask]
+                final_boxes_car = box_preds[car_mask]
+
+                final_scores = torch.cat([final_scores_car, final_scores_pedcyc], 0)
+                final_labels = torch.cat([final_labels_car, final_labels_pedcyc], 0)
+                final_boxes = torch.cat([final_boxes_car, final_boxes_pedcyc], 0)
+
+            ######### Cars do NOT use NMS ######
+
+            recall_dict = self.generate_recall_record(
+                box_preds=final_boxes if 'rois' not in batch_dict else src_box_preds,
+                recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
+                thresh_list=post_process_cfg.RECALL_THRESH_LIST
+            )
+
+
+            record_dict = {
+                'pred_boxes': final_boxes[:, :7],
+                'pred_scores': final_scores,
+                'pred_labels': final_labels
+            }
+            pred_dicts.append(record_dict)
+
+        return pred_dicts, recall_dict
+
diff --git a/pcdet/models/model_utils/mppnet_utils.py b/pcdet/models/model_utils/mppnet_utils.py
new file mode 100644
index 000000000..10641ad3a
--- /dev/null
+++ b/pcdet/models/model_utils/mppnet_utils.py
@@ -0,0 +1,419 @@
+import torch.nn as nn
+import torch
+import numpy as np
+import torch.nn.functional as F
+from typing import Optional, List
+from torch import Tensor
+from torch.nn.init import xavier_uniform_, zeros_, kaiming_normal_
+
+
+class PointNetfeat(nn.Module):
+    def __init__(self, input_dim, x=1, outchannel=512):
+        super(PointNetfeat, self).__init__()
+        if outchannel == 256:
+            self.output_channel = 256
+        else:
+            self.output_channel = 512 * x
+        self.conv1 = torch.nn.Conv1d(input_dim, 64 * x, 1)
+        self.conv2 = torch.nn.Conv1d(64 * x, 128 * x, 1)
+        self.conv3 = torch.nn.Conv1d(128 * x, 256 * x, 1)
+        self.conv4 = torch.nn.Conv1d(256 * x, self.output_channel, 1)
+        self.bn1 = nn.BatchNorm1d(64 * x)
+        self.bn2 = nn.BatchNorm1d(128 * x)
+        self.bn3 = nn.BatchNorm1d(256 * x)
+        self.bn4 = nn.BatchNorm1d(self.output_channel)
+
+    def forward(self, x):
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = F.relu(self.bn3(self.conv3(x)))
+        x_ori = self.bn4(self.conv4(x))
+
+        x = torch.max(x_ori, 2, keepdim=True)[0]
+
+        x = x.view(-1, self.output_channel)
+        return x, x_ori
+
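For orientation, a quick shape check of `PointNetfeat`: it consumes `(batch, input_dim, num_points)` tensors and returns a max-pooled global descriptor together with the per-point features. The sizes below are illustrative, not values from any MPPNet config:

```python
import torch
from pcdet.models.model_utils.mppnet_utils import PointNetfeat

feat = PointNetfeat(input_dim=30, x=1, outchannel=256)
feat.eval()  # eval mode so BatchNorm1d also accepts tiny batches

pts = torch.randn(2, 30, 128)        # 2 proposals, 30 feature channels, 128 points
x, x_ori = feat(pts)
assert x.shape == (2, 256)           # global descriptor after max pooling over points
assert x_ori.shape == (2, 256, 128)  # per-point features kept for later stages
```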
+class PointNet(nn.Module): + def __init__(self, input_dim, joint_feat=False,model_cfg=None): + super(PointNet, self).__init__() + self.joint_feat = joint_feat + channels = model_cfg.TRANS_INPUT + + times=1 + self.feat = PointNetfeat(input_dim, 1) + + self.fc1 = nn.Linear(512, 256 ) + self.fc2 = nn.Linear(256, channels) + + self.pre_bn = nn.BatchNorm1d(input_dim) + self.bn1 = nn.BatchNorm1d(256) + self.bn2 = nn.BatchNorm1d(channels) + self.relu = nn.ReLU() + + self.fc_s1 = nn.Linear(channels*times, 256) + self.fc_s2 = nn.Linear(256, 3, bias=False) + self.fc_ce1 = nn.Linear(channels*times, 256) + self.fc_ce2 = nn.Linear(256, 3, bias=False) + self.fc_hr1 = nn.Linear(channels*times, 256) + self.fc_hr2 = nn.Linear(256, 1, bias=False) + + def forward(self, x, feat=None): + + if self.joint_feat: + if len(feat.shape) > 2: + feat = torch.max(feat, 2, keepdim=True)[0] + x = feat.view(-1, self.output_channel) + x = F.relu(self.bn1(self.fc1(x))) + feat = F.relu(self.bn2(self.fc2(x))) + else: + feat = feat + feat_traj = None + else: + x, feat_traj = self.feat(self.pre_bn(x)) + x = F.relu(self.bn1(self.fc1(x))) + feat = F.relu(self.bn2(self.fc2(x))) + + x = F.relu(self.fc_ce1(feat)) + centers = self.fc_ce2(x) + + x = F.relu(self.fc_s1(feat)) + sizes = self.fc_s2(x) + + x = F.relu(self.fc_hr1(feat)) + headings = self.fc_hr2(x) + + return torch.cat([centers, sizes, headings],-1),feat,feat_traj + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear): + kaiming_normal_(m.weight.data) + if m.bias is not None: + zeros_(m.bias) + +class MLP(nn.Module): + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +class SpatialMixerBlock(nn.Module): + + def __init__(self,hidden_dim,grid_size,channels,config=None,dropout=0.0): + super().__init__() + + + self.mixer_x = MLP(input_dim = grid_size, hidden_dim = hidden_dim, output_dim = grid_size, num_layers = 3) + self.mixer_y = MLP(input_dim = grid_size, hidden_dim = hidden_dim, output_dim = grid_size, num_layers = 3) + self.mixer_z = MLP(input_dim = grid_size, hidden_dim = hidden_dim, output_dim = grid_size, num_layers = 3) + self.norm_x = nn.LayerNorm(channels) + self.norm_y = nn.LayerNorm(channels) + self.norm_z = nn.LayerNorm(channels) + self.norm_channel = nn.LayerNorm(channels) + self.ffn = nn.Sequential( + nn.Linear(channels, 2*channels), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(2*channels, channels), + ) + self.config = config + self.grid_size = grid_size + + def forward(self, src): + + src_3d = src.permute(1,2,0).contiguous().view(src.shape[1],src.shape[2], + self.grid_size,self.grid_size,self.grid_size) + src_3d = src_3d.permute(0,1,4,3,2).contiguous() + mixed_x = self.mixer_x(src_3d) + mixed_x = src_3d + mixed_x + mixed_x = self.norm_x(mixed_x.permute(0,2,3,4,1)).permute(0,4,1,2,3).contiguous() + + mixed_y = self.mixer_y(mixed_x.permute(0,1,2,4,3)).permute(0,1,2,4,3).contiguous() + mixed_y = mixed_x + mixed_y + mixed_y = self.norm_y(mixed_y.permute(0,2,3,4,1)).permute(0,4,1,2,3).contiguous() + + mixed_z = self.mixer_z(mixed_y.permute(0,1,4,3,2)).permute(0,1,4,3,2).contiguous() + + mixed_z = mixed_y + mixed_z + mixed_z = 
self.norm_z(mixed_z.permute(0,2,3,4,1)).permute(0,4,1,2,3).contiguous() + + src_mixer = mixed_z.view(src.shape[1],src.shape[2],-1).permute(2,0,1) + src_mixer = src_mixer + self.ffn(src_mixer) + src_mixer = self.norm_channel(src_mixer) + + return src_mixer + +class Transformer(nn.Module): + + def __init__(self, config, d_model=512, nhead=8, num_encoder_layers=6, + dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=False, + num_lidar_points=None,num_proxy_points=None, share_head=True,num_groups=None, + sequence_stride=None,num_frames=None): + super().__init__() + + self.config = config + self.share_head = share_head + self.num_frames = num_frames + self.nhead = nhead + self.sequence_stride = sequence_stride + self.num_groups = num_groups + self.num_proxy_points = num_proxy_points + self.num_lidar_points = num_lidar_points + self.d_model = d_model + self.nhead = nhead + encoder_layer = [TransformerEncoderLayer(self.config, d_model, nhead, dim_feedforward,dropout, activation, + normalize_before, num_lidar_points,num_groups=num_groups) for i in range(num_encoder_layers)] + + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm,self.config) + + self.token = nn.Parameter(torch.zeros(self.num_groups, 1, d_model)) + + + if self.num_frames >4: + + self.group_length = self.num_frames // self.num_groups + self.fusion_all_group = MLP(input_dim = self.config.hidden_dim*self.group_length, + hidden_dim = self.config.hidden_dim, output_dim = self.config.hidden_dim, num_layers = 4) + + self.fusion_norm = FFN(d_model, dim_feedforward) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, pos=None): + + BS, N, C = src.shape + if not pos is None: + pos = pos.permute(1, 0, 2) + + if self.num_frames == 16: + token_list = [self.token[i:(i+1)].repeat(BS,1,1) for i in range(self.num_groups)] + if self.sequence_stride ==1: + src_groups = src.view(src.shape[0],src.shape[1]//self.num_groups ,-1).chunk(4,dim=1) + + elif self.sequence_stride ==4: + src_groups = [] + + for i in range(self.num_groups): + groups = [] + for j in range(self.group_length): + points_index_start = (i+j*self.sequence_stride)*self.num_proxy_points + points_index_end = points_index_start + self.num_proxy_points + groups.append(src[:,points_index_start:points_index_end]) + + groups = torch.cat(groups,-1) + src_groups.append(groups) + + else: + raise NotImplementedError + + src_merge = torch.cat(src_groups,1) + src = self.fusion_norm(src[:,:self.num_groups*self.num_proxy_points],self.fusion_all_group(src_merge)) + src = [torch.cat([token_list[i],src[:,i*self.num_proxy_points:(i+1)*self.num_proxy_points]],dim=1) for i in range(self.num_groups)] + src = torch.cat(src,dim=0) + + else: + token_list = [self.token[i:(i+1)].repeat(BS,1,1) for i in range(self.num_groups)] + src = [torch.cat([token_list[i],src[:,i*self.num_proxy_points:(i+1)*self.num_proxy_points]],dim=1) for i in range(self.num_groups)] + src = torch.cat(src,dim=0) + + src = src.permute(1, 0, 2) + memory,tokens = self.encoder(src,pos=pos) + + memory = torch.cat(memory[0:1].chunk(4,dim=1),0) + return memory, tokens + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None,config=None): + super().__init__() + self.layers = nn.ModuleList(encoder_layer) + self.num_layers = num_layers + self.norm = norm + self.config = config + 
+ def forward(self, src, + pos: Optional[Tensor] = None): + + token_list = [] + output = src + for layer in self.layers: + output,tokens = layer(output,pos=pos) + token_list.append(tokens) + if self.norm is not None: + output = self.norm(output) + + return output,token_list + + +class TransformerEncoderLayer(nn.Module): + count = 0 + def __init__(self, config, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False,num_points=None,num_groups=None): + super().__init__() + TransformerEncoderLayer.count += 1 + self.layer_count = TransformerEncoderLayer.count + self.config = config + self.num_point = num_points + self.num_groups= num_groups + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + if self.layer_count <= self.config.enc_layers-1: + self.cross_attn_layers = nn.ModuleList() + for _ in range(self.num_groups): + self.cross_attn_layers.append(nn.MultiheadAttention(d_model, nhead, dropout=dropout)) + + self.ffn = FFN(d_model, dim_feedforward) + self.fusion_all_groups = MLP(input_dim = d_model*4, hidden_dim = d_model, output_dim = d_model, num_layers = 4) + + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self.mlp_mixer_3d = SpatialMixerBlock(self.config.use_mlp_mixer.hidden_dim,self.config.use_mlp_mixer.get('grid_size', 4),self.config.hidden_dim, self.config.use_mlp_mixer) + + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + pos: Optional[Tensor] = None): + + src_intra_group_fusion = self.mlp_mixer_3d(src[1:]) + src = torch.cat([src[:1],src_intra_group_fusion],0) + + token = src[:1] + + if not pos is None: + key = self.with_pos_embed(src_intra_group_fusion, pos[1:]) + else: + key = src_intra_group_fusion + + src_summary = self.self_attn(token, key, value=src_intra_group_fusion)[0] + token = token + self.dropout1(src_summary) + token = self.norm1(token) + src_summary = self.linear2(self.dropout(self.activation(self.linear1(token)))) + token = token + self.dropout2(src_summary) + token = self.norm2(token) + src = torch.cat([token,src[1:]],0) + + if self.layer_count <= self.config.enc_layers-1: + + src_all_groups = src[1:].view((src.shape[0]-1)*4,-1,src.shape[-1]) + src_groups_list = src_all_groups.chunk(self.num_groups,0) + + src_all_groups = torch.cat(src_groups_list,-1) + src_all_groups_fusion = self.fusion_all_groups(src_all_groups) + + key = self.with_pos_embed(src_all_groups_fusion, pos[1:]) + query_list = [self.with_pos_embed(query, pos[1:]) for query in src_groups_list] + + inter_group_fusion_list = [] + for i in range(self.num_groups): + inter_group_fusion = self.cross_attn_layers[i](query_list[i], key, value=src_all_groups_fusion)[0] + inter_group_fusion = self.ffn(src_groups_list[i],inter_group_fusion) + inter_group_fusion_list.append(inter_group_fusion) + + src_inter_group_fusion = torch.cat(inter_group_fusion_list,1) + + src = torch.cat([src[:1],src_inter_group_fusion],0) + + return src, torch.cat(src[:1].chunk(4,1),0) + + def forward_pre(self, src, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, 
value=src2)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + pos: Optional[Tensor] = None): + + if self.normalize_before: + return self.forward_pre(src, pos) + return self.forward_post(src, pos) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +class FFN(nn.Module): + def __init__(self, d_model, dim_feedforward=2048, dropout=0.1,dout=None, + activation="relu", normalize_before=False): + super().__init__() + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def forward(self, tgt,tgt_input): + tgt = tgt + self.dropout2(tgt_input) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + + return tgt + +def build_transformer(args): + return Transformer( + config = args, + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + normalize_before=args.pre_norm, + num_lidar_points = args.num_lidar_points, + num_proxy_points = args.num_proxy_points, + num_frames = args.num_frames, + sequence_stride = args.get('sequence_stride',1), + num_groups=args.num_groups, + ) + diff --git a/pcdet/models/roi_heads/__init__.py b/pcdet/models/roi_heads/__init__.py index e9170e077..693cec426 100644 --- a/pcdet/models/roi_heads/__init__.py +++ b/pcdet/models/roi_heads/__init__.py @@ -4,7 +4,8 @@ from .second_head import SECONDHead from .voxelrcnn_head import VoxelRCNNHead from .roi_head_template import RoIHeadTemplate - +from .mppnet_head import MPPNetHead +from .mppnet_memory_bank_e2e import MPPNetHeadE2E __all__ = { 'RoIHeadTemplate': RoIHeadTemplate, @@ -12,5 +13,7 @@ 'PVRCNNHead': PVRCNNHead, 'SECONDHead': SECONDHead, 'PointRCNNHead': PointRCNNHead, - 'VoxelRCNNHead': VoxelRCNNHead + 'VoxelRCNNHead': VoxelRCNNHead, + 'MPPNetHead': MPPNetHead, + 'MPPNetHeadE2E': MPPNetHeadE2E, } diff --git a/pcdet/models/roi_heads/mppnet_head.py b/pcdet/models/roi_heads/mppnet_head.py new file mode 100644 index 000000000..7f9911408 --- /dev/null +++ b/pcdet/models/roi_heads/mppnet_head.py @@ -0,0 +1,992 @@ +from typing import ValuesView +import torch.nn as nn +import torch +import numpy as np +import copy +import torch.nn.functional as F +from pcdet.ops.iou3d_nms import iou3d_nms_utils +from ...utils import common_utils, loss_utils +from .roi_head_template import RoIHeadTemplate +from ..model_utils.mppnet_utils import build_transformer, PointNet, MLP +from .target_assigner.proposal_target_layer import ProposalTargetLayer +from pcdet.ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_stack_modules + + +class ProposalTargetLayerMPPNet(ProposalTargetLayer): + def __init__(self, roi_sampler_cfg): + 
super().__init__(roi_sampler_cfg = roi_sampler_cfg) + + def forward(self, batch_dict): + """ + Args: + batch_dict: + batch_size: + rois: (B, num_rois, 7 + C) + roi_scores: (B, num_rois) + gt_boxes: (B, N, 7 + C + 1) + roi_labels: (B, num_rois) + Returns: + batch_dict: + rois: (B, M, 7 + C) + gt_of_rois: (B, M, 7 + C) + gt_iou_of_rois: (B, M) + roi_scores: (B, M) + roi_labels: (B, M) + reg_valid_mask: (B, M) + rcnn_cls_labels: (B, M) + """ + + batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels, \ + batch_trajectory_rois,batch_valid_length = self.sample_rois_for_mppnet(batch_dict=batch_dict) + + # regression valid mask + reg_valid_mask = (batch_roi_ious > self.roi_sampler_cfg.REG_FG_THRESH).long() + + # classification label + if self.roi_sampler_cfg.CLS_SCORE_TYPE == 'cls': + batch_cls_labels = (batch_roi_ious > self.roi_sampler_cfg.CLS_FG_THRESH).long() + ignore_mask = (batch_roi_ious > self.roi_sampler_cfg.CLS_BG_THRESH) & \ + (batch_roi_ious < self.roi_sampler_cfg.CLS_FG_THRESH) + batch_cls_labels[ignore_mask > 0] = -1 + elif self.roi_sampler_cfg.CLS_SCORE_TYPE == 'roi_iou': + iou_bg_thresh = self.roi_sampler_cfg.CLS_BG_THRESH + iou_fg_thresh = self.roi_sampler_cfg.CLS_FG_THRESH + fg_mask = batch_roi_ious > iou_fg_thresh + bg_mask = batch_roi_ious < iou_bg_thresh + interval_mask = (fg_mask == 0) & (bg_mask == 0) + + batch_cls_labels = (fg_mask > 0).float() + batch_cls_labels[interval_mask] = \ + (batch_roi_ious[interval_mask] - iou_bg_thresh) / (iou_fg_thresh - iou_bg_thresh) + else: + raise NotImplementedError + + + targets_dict = {'rois': batch_rois, 'gt_of_rois': batch_gt_of_rois, + 'gt_iou_of_rois': batch_roi_ious,'roi_scores': batch_roi_scores, + 'roi_labels': batch_roi_labels,'reg_valid_mask': reg_valid_mask, + 'rcnn_cls_labels': batch_cls_labels,'trajectory_rois':batch_trajectory_rois, + 'valid_length': batch_valid_length, + } + + return targets_dict + + def sample_rois_for_mppnet(self, batch_dict): + """ + Args: + batch_dict: + batch_size: + rois: (B, num_rois, 7 + C) + roi_scores: (B, num_rois) + gt_boxes: (B, N, 7 + C + 1) + roi_labels: (B, num_rois) + Returns: + """ + cur_frame_idx = 0 + batch_size = batch_dict['batch_size'] + rois = batch_dict['trajectory_rois'][:,cur_frame_idx,:,:] + roi_scores = batch_dict['roi_scores'][:,:,cur_frame_idx] + roi_labels = batch_dict['roi_labels'] + gt_boxes = batch_dict['gt_boxes'] + + code_size = rois.shape[-1] + batch_rois = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE, code_size) + batch_gt_of_rois = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE, gt_boxes.shape[-1]) + batch_roi_ious = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE) + batch_roi_scores = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE) + batch_roi_labels = rois.new_zeros((batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE), dtype=torch.long) + + + + trajectory_rois = batch_dict['trajectory_rois'] + batch_trajectory_rois = rois.new_zeros(batch_size, trajectory_rois.shape[1],self.roi_sampler_cfg.ROI_PER_IMAGE,trajectory_rois.shape[-1]) + + valid_length = batch_dict['valid_length'] + batch_valid_length = rois.new_zeros((batch_size, batch_dict['trajectory_rois'].shape[1], self.roi_sampler_cfg.ROI_PER_IMAGE)) + + for index in range(batch_size): + + cur_trajectory_rois = trajectory_rois[index] + + cur_roi, cur_gt, cur_roi_labels, cur_roi_scores = rois[index],gt_boxes[index], roi_labels[index], roi_scores[index] + + if 'valid_length' in batch_dict.keys(): + cur_valid_length = valid_length[index] 
+ + + + k = cur_gt.__len__() - 1 + while k > 0 and cur_gt[k].sum() == 0: + k -= 1 + + cur_gt = cur_gt[:k + 1] + cur_gt = cur_gt.new_zeros((1, cur_gt.shape[1])) if len(cur_gt) == 0 else cur_gt + + if self.roi_sampler_cfg.get('SAMPLE_ROI_BY_EACH_CLASS', False): + max_overlaps, gt_assignment = self.get_max_iou_with_same_class( + rois=cur_roi, roi_labels=cur_roi_labels, + gt_boxes=cur_gt[:, 0:7], gt_labels=cur_gt[:, -1].long() + ) + + else: + iou3d = iou3d_nms_utils.boxes_iou3d_gpu(cur_roi, cur_gt[:, 0:7]) # (M, N) + max_overlaps, gt_assignment = torch.max(iou3d, dim=1) + + sampled_inds,fg_inds, bg_inds = self.subsample_rois(max_overlaps=max_overlaps) + + batch_roi_labels[index] = cur_roi_labels[sampled_inds.long()] + + + if self.roi_sampler_cfg.get('USE_ROI_AUG',False): + + fg_rois, fg_iou3d = self.aug_roi_by_noise_torch(cur_roi[fg_inds], cur_gt[gt_assignment[fg_inds]], + max_overlaps[fg_inds], aug_times=self.roi_sampler_cfg.ROI_FG_AUG_TIMES) + bg_rois = cur_roi[bg_inds] + bg_iou3d = max_overlaps[bg_inds] + + batch_rois[index] = torch.cat([fg_rois,bg_rois],0) + batch_roi_ious[index] = torch.cat([fg_iou3d,bg_iou3d],0) + batch_gt_of_rois[index] = cur_gt[gt_assignment[sampled_inds]] + + else: + batch_rois[index] = cur_roi[sampled_inds] + batch_roi_ious[index] = max_overlaps[sampled_inds] + batch_gt_of_rois[index] = cur_gt[gt_assignment[sampled_inds]] + + + batch_roi_scores[index] = cur_roi_scores[sampled_inds] + + if 'valid_length' in batch_dict.keys(): + batch_valid_length[index] = cur_valid_length[:,sampled_inds] + + if self.roi_sampler_cfg.USE_TRAJ_AUG.ENABLED: + batch_trajectory_rois_list = [] + for idx in range(0,batch_dict['num_frames']): + if idx== cur_frame_idx: + batch_trajectory_rois_list.append(cur_trajectory_rois[cur_frame_idx:cur_frame_idx+1,sampled_inds]) + continue + fg_trajs, _ = self.aug_roi_by_noise_torch(cur_trajectory_rois[idx,fg_inds], cur_trajectory_rois[idx,fg_inds][:,:8], max_overlaps[fg_inds], \ + aug_times=self.roi_sampler_cfg.ROI_FG_AUG_TIMES,pos_thresh=self.roi_sampler_cfg.USE_TRAJ_AUG.THRESHOD) + bg_trajs = cur_trajectory_rois[idx,bg_inds] + batch_trajectory_rois_list.append(torch.cat([fg_trajs,bg_trajs],0)[None,:,:]) + batch_trajectory_rois[index] = torch.cat(batch_trajectory_rois_list,0) + else: + batch_trajectory_rois[index] = cur_trajectory_rois[:,sampled_inds] + + return batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels, batch_trajectory_rois,batch_valid_length + + def subsample_rois(self, max_overlaps): + # sample fg, easy_bg, hard_bg + fg_rois_per_image = int(np.round(self.roi_sampler_cfg.FG_RATIO * self.roi_sampler_cfg.ROI_PER_IMAGE)) + fg_thresh = min(self.roi_sampler_cfg.REG_FG_THRESH, self.roi_sampler_cfg.CLS_FG_THRESH) + + fg_inds = ((max_overlaps >= fg_thresh)).nonzero().view(-1) + easy_bg_inds = ((max_overlaps < self.roi_sampler_cfg.CLS_BG_THRESH_LO)).nonzero().view(-1) + hard_bg_inds = ((max_overlaps < self.roi_sampler_cfg.REG_FG_THRESH) & + (max_overlaps >= self.roi_sampler_cfg.CLS_BG_THRESH_LO)).nonzero().view(-1) + + fg_num_rois = fg_inds.numel() + bg_num_rois = hard_bg_inds.numel() + easy_bg_inds.numel() + + if fg_num_rois > 0 and bg_num_rois > 0: + # sampling fg + fg_rois_per_this_image = min(fg_rois_per_image, fg_num_rois) + + rand_num = torch.from_numpy(np.random.permutation(fg_num_rois)).type_as(max_overlaps).long() + fg_inds = fg_inds[rand_num[:fg_rois_per_this_image]] + + # sampling bg + bg_rois_per_this_image = self.roi_sampler_cfg.ROI_PER_IMAGE - fg_rois_per_this_image + bg_inds = self.sample_bg_inds( + 
hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, self.roi_sampler_cfg.HARD_BG_RATIO + ) + + elif fg_num_rois > 0 and bg_num_rois == 0: + # sampling fg + rand_num = np.floor(np.random.rand(self.roi_sampler_cfg.ROI_PER_IMAGE) * fg_num_rois) + rand_num = torch.from_numpy(rand_num).type_as(max_overlaps).long() + fg_inds = fg_inds[rand_num] + bg_inds = torch.tensor([]).type_as(fg_inds) + + elif bg_num_rois > 0 and fg_num_rois == 0: + # sampling bg + bg_rois_per_this_image = self.roi_sampler_cfg.ROI_PER_IMAGE + bg_inds = self.sample_bg_inds( + hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, self.roi_sampler_cfg.HARD_BG_RATIO + ) + else: + print('maxoverlaps:(min=%f, max=%f)' % (max_overlaps.min().item(), max_overlaps.max().item())) + print('ERROR: FG=%d, BG=%d' % (fg_num_rois, bg_num_rois)) + raise NotImplementedError + + sampled_inds = torch.cat((fg_inds, bg_inds), dim=0) + return sampled_inds.long(), fg_inds.long(), bg_inds.long() + + def aug_roi_by_noise_torch(self,roi_boxes3d, gt_boxes3d, iou3d_src, aug_times=10, pos_thresh=None): + iou_of_rois = torch.zeros(roi_boxes3d.shape[0]).type_as(gt_boxes3d) + if pos_thresh is None: + pos_thresh = min(self.roi_sampler_cfg.REG_FG_THRESH, self.roi_sampler_cfg.CLS_FG_THRESH) + + for k in range(roi_boxes3d.shape[0]): + temp_iou = cnt = 0 + roi_box3d = roi_boxes3d[k] + + gt_box3d = gt_boxes3d[k].view(1, gt_boxes3d.shape[-1]) + aug_box3d = roi_box3d + keep = True + while temp_iou < pos_thresh and cnt < aug_times: + if np.random.rand() <= self.roi_sampler_cfg.RATIO: + aug_box3d = roi_box3d # p=RATIO to keep the original roi box + keep = True + else: + aug_box3d = self.random_aug_box3d(roi_box3d) + keep = False + aug_box3d = aug_box3d.view((1, aug_box3d.shape[-1])) + iou3d = iou3d_nms_utils.boxes_iou3d_gpu(aug_box3d[:,:7], gt_box3d[:,:7]) + temp_iou = iou3d[0][0] + cnt += 1 + roi_boxes3d[k] = aug_box3d.view(-1) + if cnt == 0 or keep: + iou_of_rois[k] = iou3d_src[k] + else: + iou_of_rois[k] = temp_iou + return roi_boxes3d, iou_of_rois + + def random_aug_box3d(self,box3d): + """ + :param box3d: (7) [x, y, z, h, w, l, ry] + random shift, scale, orientation + """ + + if self.roi_sampler_cfg.REG_AUG_METHOD == 'single': + pos_shift = (torch.rand(3, device=box3d.device) - 0.5) # [-0.5 ~ 0.5] + hwl_scale = (torch.rand(3, device=box3d.device) - 0.5) / (0.5 / 0.15) + 1.0 # + angle_rot = (torch.rand(1, device=box3d.device) - 0.5) / (0.5 / (np.pi / 12)) # [-pi/12 ~ pi/12] + aug_box3d = torch.cat([box3d[0:3] + pos_shift, box3d[3:6] * hwl_scale, box3d[6:7] + angle_rot, box3d[7:]], dim=0) + return aug_box3d + elif self.roi_sampler_cfg.REG_AUG_METHOD == 'multiple': + # pos_range, hwl_range, angle_range, mean_iou + range_config = [[0.2, 0.1, np.pi / 12, 0.7], + [0.3, 0.15, np.pi / 12, 0.6], + [0.5, 0.15, np.pi / 9, 0.5], + [0.8, 0.15, np.pi / 6, 0.3], + [1.0, 0.15, np.pi / 3, 0.2]] + idx = torch.randint(low=0, high=len(range_config), size=(1,))[0].long() + + pos_shift = ((torch.rand(3, device=box3d.device) - 0.5) / 0.5) * range_config[idx][0] + hwl_scale = ((torch.rand(3, device=box3d.device) - 0.5) / 0.5) * range_config[idx][1] + 1.0 + angle_rot = ((torch.rand(1, device=box3d.device) - 0.5) / 0.5) * range_config[idx][2] + + aug_box3d = torch.cat([box3d[0:3] + pos_shift, box3d[3:6] * hwl_scale, box3d[6:7] + angle_rot], dim=0) + return aug_box3d + elif self.roi_sampler_cfg.REG_AUG_METHOD == 'normal': + x_shift = np.random.normal(loc=0, scale=0.3) + y_shift = np.random.normal(loc=0, scale=0.2) + z_shift = np.random.normal(loc=0, scale=0.3) + h_shift = 
np.random.normal(loc=0, scale=0.25)
+            w_shift = np.random.normal(loc=0, scale=0.15)
+            l_shift = np.random.normal(loc=0, scale=0.5)
+            # np.random.rand() replaces the original torch.rand(), which requires a
+            # size argument and would raise a TypeError if this branch were ever hit;
+            # the shipped configs set REG_AUG_METHOD: single, so 'normal' is unused there.
+            ry_shift = ((np.random.rand() - 0.5) / 0.5) * np.pi / 12
+
+            aug_box3d = np.array([box3d[0] + x_shift, box3d[1] + y_shift, box3d[2] + z_shift, box3d[3] + h_shift,
+                                  box3d[4] + w_shift, box3d[5] + l_shift, box3d[6] + ry_shift], dtype=np.float32)
+            aug_box3d = torch.from_numpy(aug_box3d).type_as(box3d)
+            return aug_box3d
+        else:
+            raise NotImplementedError
+
+
+class MPPNetHead(RoIHeadTemplate):
+    def __init__(self, model_cfg, num_class=1, **kwargs):
+        super().__init__(num_class=num_class, model_cfg=model_cfg)
+        self.model_cfg = model_cfg
+        self.proposal_target_layer = ProposalTargetLayerMPPNet(roi_sampler_cfg=self.model_cfg.TARGET_CONFIG)
+        self.use_time_stamp = self.model_cfg.get('USE_TIMESTAMP', None)
+        self.num_lidar_points = self.model_cfg.Transformer.num_lidar_points
+        self.avg_stage1_score = self.model_cfg.get('AVG_STAGE1_SCORE', None)
+
+        self.nhead = model_cfg.Transformer.nheads
+        self.num_enc_layer = model_cfg.Transformer.enc_layers
+        hidden_dim = model_cfg.TRANS_INPUT
+        self.hidden_dim = model_cfg.TRANS_INPUT
+        self.num_groups = model_cfg.Transformer.num_groups
+
+        self.grid_size = model_cfg.ROI_GRID_POOL.GRID_SIZE
+        self.num_proxy_points = model_cfg.Transformer.num_proxy_points
+        self.seqboxembed = PointNet(8, model_cfg=self.model_cfg)
+        self.jointembed = MLP(self.hidden_dim * (self.num_groups + 1), model_cfg.Transformer.hidden_dim, self.box_coder.code_size * self.num_class, 4)
+
+        num_radius = len(self.model_cfg.ROI_GRID_POOL.POOL_RADIUS)
+        self.up_dimension_geometry = MLP(input_dim=29, hidden_dim=64, output_dim=hidden_dim // num_radius, num_layers=3)
+        self.up_dimension_motion = MLP(input_dim=30, hidden_dim=64, output_dim=hidden_dim, num_layers=3)
+
+        self.transformer = build_transformer(model_cfg.Transformer)
+
+        self.roi_grid_pool_layer = pointnet2_stack_modules.StackSAModuleMSG(
+            radii=self.model_cfg.ROI_GRID_POOL.POOL_RADIUS,
+            nsamples=self.model_cfg.ROI_GRID_POOL.NSAMPLE,
+            mlps=self.model_cfg.ROI_GRID_POOL.MLPS,
+            use_xyz=True,
+            pool_method=self.model_cfg.ROI_GRID_POOL.POOL_METHOD,
+        )
+
+        self.class_embed = nn.ModuleList()
+        self.class_embed.append(nn.Linear(model_cfg.Transformer.hidden_dim, 1))
+
+        self.bbox_embed = nn.ModuleList()
+        for _ in range(self.num_groups):
+            self.bbox_embed.append(MLP(model_cfg.Transformer.hidden_dim, model_cfg.Transformer.hidden_dim, self.box_coder.code_size * self.num_class, 4))
+
+        if self.model_cfg.Transformer.use_grid_pos.enabled:
+            if self.model_cfg.Transformer.use_grid_pos.init_type == 'index':
+                self.grid_index = torch.cat([i.reshape(-1, 1) for i in torch.meshgrid(torch.arange(self.grid_size), torch.arange(self.grid_size), torch.arange(self.grid_size))], 1).float().cuda()
+                self.grid_pos_embeded = MLP(input_dim=3, hidden_dim=256, output_dim=hidden_dim, num_layers=2)
+            else:
+                # note: self.num_grid_points is not defined anywhere in this head;
+                # the shipped configs use init_type == 'index', so this fallback is unused.
+                self.pos = nn.Parameter(torch.zeros(1, self.num_grid_points, 256))
+
+    def init_weights(self, weight_init='xavier'):
+        if weight_init == 'kaiming':
+            init_func = nn.init.kaiming_normal_
+        elif weight_init == 'xavier':
+            init_func = nn.init.xavier_normal_
+        elif weight_init == 'normal':
+            init_func = nn.init.normal_
+        else:
+            raise NotImplementedError
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
+                if weight_init == 'normal':
+                    init_func(m.weight, mean=0, std=0.001)
+                else:
+                    init_func(m.weight)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
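+        # note: self.bbox_embed is an nn.ModuleList of MLPs, so the statement below
+        # assumes a single MLP exposing `.layers`; if init_weights() is ever called,
+        # a loop over the list (hypothetical sketch) may be what is intended:
+        #     for mlp in self.bbox_embed:
+        #         nn.init.normal_(mlp.layers[-1].weight, mean=0, std=0.001)
+        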
nn.init.normal_(self.bbox_embed.layers[-1].weight, mean=0, std=0.001) + + def get_corner_points_of_roi(self, rois): + rois = rois.view(-1, rois.shape[-1]) + batch_size_rcnn = rois.shape[0] + + local_roi_grid_points = self.get_corner_points(rois, batch_size_rcnn) + local_roi_grid_points = common_utils.rotate_points_along_z( + local_roi_grid_points.clone(), rois[:, 6] + ).squeeze(dim=1) + global_center = rois[:, 0:3].clone() + + global_roi_grid_points = local_roi_grid_points + global_center.unsqueeze(dim=1) + return global_roi_grid_points, local_roi_grid_points + + @staticmethod + def get_dense_grid_points(rois, batch_size_rcnn, grid_size): + faked_features = rois.new_ones((grid_size, grid_size, grid_size)) + dense_idx = faked_features.nonzero() + dense_idx = dense_idx.repeat(batch_size_rcnn, 1, 1).float() + + local_roi_size = rois.view(batch_size_rcnn, -1)[:, 3:6] + roi_grid_points = (dense_idx + 0.5) / grid_size * local_roi_size.unsqueeze(dim=1) \ + - (local_roi_size.unsqueeze(dim=1) / 2) + return roi_grid_points + + @staticmethod + def get_corner_points(rois, batch_size_rcnn): + faked_features = rois.new_ones((2, 2, 2)) + + dense_idx = faked_features.nonzero() + dense_idx = dense_idx.repeat(batch_size_rcnn, 1, 1).float() + + local_roi_size = rois.view(batch_size_rcnn, -1)[:, 3:6] + roi_grid_points = dense_idx * local_roi_size.unsqueeze(dim=1) \ + - (local_roi_size.unsqueeze(dim=1) / 2) + return roi_grid_points + + def roi_grid_pool(self, batch_size, rois, point_coords, point_features,batch_dict=None,batch_cnt=None): + + num_frames = batch_dict['num_frames'] + num_rois = rois.shape[2]*rois.shape[1] + + global_roi_proxy_points, local_roi_proxy_points = self.get_proxy_points_of_roi( + rois.permute(0,2,1,3).contiguous(), grid_size=self.grid_size + ) + + global_roi_proxy_points = global_roi_proxy_points.view(batch_size, -1, 3) + + + point_coords = point_coords.view(point_coords.shape[0]*num_frames,point_coords.shape[1]//num_frames,point_coords.shape[-1]) + xyz = point_coords[:, :, 0:3].view(-1,3) + + + num_points = point_coords.shape[1] + num_proxy_points = self.num_proxy_points + + if batch_cnt is None: + xyz_batch_cnt = torch.tensor([num_points]*num_rois*batch_size).cuda().int() + else: + xyz_batch_cnt = torch.tensor(batch_cnt).cuda().int() + + new_xyz_batch_cnt = torch.tensor([num_proxy_points]*num_rois*batch_size).cuda().int() + new_xyz = global_roi_proxy_points.view(-1, 3) + + _, pooled_features = self.roi_grid_pool_layer( + xyz=xyz.contiguous(), + xyz_batch_cnt=xyz_batch_cnt, + new_xyz=new_xyz, + new_xyz_batch_cnt=new_xyz_batch_cnt, + features=point_features.view(-1,point_features.shape[-1]).contiguous(), + ) + + features = pooled_features.view( + point_features.shape[0], num_frames*self.num_proxy_points, + pooled_features.shape[-1]).contiguous() + + return features,global_roi_proxy_points.view(batch_size*rois.shape[2], num_frames*num_proxy_points,3).contiguous() + + def get_proxy_points_of_roi(self, rois, grid_size): + rois = rois.view(-1, rois.shape[-1]) + batch_size_rcnn = rois.shape[0] + + local_roi_grid_points = self.get_dense_grid_points(rois, batch_size_rcnn, grid_size) + local_roi_grid_points = common_utils.rotate_points_along_z(local_roi_grid_points.clone(), rois[:, 6]).squeeze(dim=1) + global_center = rois[:, 0:3].clone() + global_roi_grid_points = local_roi_grid_points + global_center.unsqueeze(dim=1) + return global_roi_grid_points, local_roi_grid_points + + def spherical_coordinate(self, src, diag_dist): + assert (src.shape[-1] == 27) + device = src.device + indices_x = 
torch.LongTensor([0,3,6,9,12,15,18,21,24]).to(device) # + indices_y = torch.LongTensor([1,4,7,10,13,16,19,22,25]).to(device) # + indices_z = torch.LongTensor([2,5,8,11,14,17,20,23,26]).to(device) + src_x = torch.index_select(src, -1, indices_x) + src_y = torch.index_select(src, -1, indices_y) + src_z = torch.index_select(src, -1, indices_z) + dis = (src_x ** 2 + src_y ** 2 + src_z ** 2) ** 0.5 + phi = torch.atan(src_y / (src_x + 1e-5)) + the = torch.acos(src_z / (dis + 1e-5)) + dis = dis / (diag_dist + 1e-5) + src = torch.cat([dis, phi, the], dim = -1) + return src + + def crop_current_frame_points(self, src, batch_size,trajectory_rois,num_rois,batch_dict): + + for bs_idx in range(batch_size): + cur_batch_boxes = trajectory_rois[bs_idx,0,:,:7].view(-1,7) + cur_radiis = torch.sqrt((cur_batch_boxes[:,3]/2) ** 2 + (cur_batch_boxes[:,4]/2) ** 2) * 1.1 + cur_points = batch_dict['points'][(batch_dict['points'][:, 0] == bs_idx)][:,1:] + dis = torch.norm((cur_points[:,:2].unsqueeze(0) - cur_batch_boxes[:,:2].unsqueeze(1).repeat(1,cur_points.shape[0],1)), dim = 2) + point_mask = (dis <= cur_radiis.unsqueeze(-1)) + + + sampled_idx = torch.topk(point_mask.float(),128)[1] + sampled_idx_buffer = sampled_idx[:, 0:1].repeat(1, 128) + roi_idx = torch.arange(num_rois)[:, None].repeat(1, 128) + sampled_mask = point_mask[roi_idx, sampled_idx] + sampled_idx_buffer[sampled_mask] = sampled_idx[sampled_mask] + + src[bs_idx] = cur_points[sampled_idx_buffer][:,:,:5] + empty_flag = sampled_mask.sum(-1)==0 + src[bs_idx,empty_flag] = 0 + + src = src.repeat([1,1,trajectory_rois.shape[1],1]) + + return src + + def crop_previous_frame_points(self,src,batch_size,trajectory_rois,num_rois,valid_length,batch_dict): + for bs_idx in range(batch_size): + + cur_points = batch_dict['points'][(batch_dict['points'][:, 0] == bs_idx)][:,1:] + + + for idx in range(1,trajectory_rois.shape[1]): + + time_mask = (cur_points[:,-1] - idx*0.1).abs() < 1e-3 + cur_time_points = cur_points[time_mask] + cur_batch_boxes = trajectory_rois[bs_idx,idx,:,:7].view(-1,7) + + cur_radiis = torch.sqrt((cur_batch_boxes[:,3]/2) ** 2 + (cur_batch_boxes[:,4]/2) ** 2) * 1.1 + if not self.training and cur_batch_boxes.shape[0] > 32: + length_iter= cur_batch_boxes.shape[0]//32 + dis_list = [] + for i in range(length_iter+1): + dis = torch.norm((cur_time_points[:,:2].unsqueeze(0) - \ + cur_batch_boxes[32*i:32*(i+1),:2].unsqueeze(1).repeat(1,cur_time_points.shape[0],1)), dim = 2) + dis_list.append(dis) + dis = torch.cat(dis_list,0) + else: + dis = torch.norm((cur_time_points[:,:2].unsqueeze(0) - \ + cur_batch_boxes[:,:2].unsqueeze(1).repeat(1,cur_time_points.shape[0],1)), dim = 2) + + point_mask = (dis <= cur_radiis.unsqueeze(-1)).view(trajectory_rois.shape[2],-1) + + for roi_box_idx in range(0, num_rois): + + if not valid_length[bs_idx,idx,roi_box_idx]: + continue + + cur_roi_points = cur_time_points[point_mask[roi_box_idx]] + + if cur_roi_points.shape[0] > self.num_lidar_points: + np.random.seed(0) + choice = np.random.choice(cur_roi_points.shape[0], self.num_lidar_points, replace=True) + cur_roi_points_sample = cur_roi_points[choice] + + elif cur_roi_points.shape[0] == 0: + cur_roi_points_sample = cur_roi_points.new_zeros(self.num_lidar_points, 6) + + else: + empty_num = self.num_lidar_points - cur_roi_points.shape[0] + add_zeros = cur_roi_points.new_zeros(empty_num, 6) + add_zeros = cur_roi_points[0].repeat(empty_num, 1) + cur_roi_points_sample = torch.cat([cur_roi_points, add_zeros], dim = 0) + + if not self.use_time_stamp: + cur_roi_points_sample = 
cur_roi_points_sample[:,:-1] + + src[bs_idx, roi_box_idx, self.num_lidar_points*idx:self.num_lidar_points*(idx+1), :] = cur_roi_points_sample + + + return src + + + def get_proposal_aware_geometry_feature(self,src, batch_size,trajectory_rois,num_rois,batch_dict): + proposal_aware_feat_list = [] + for i in range(trajectory_rois.shape[1]): + + corner_points, _ = self.get_corner_points_of_roi(trajectory_rois[:,i,:,:].contiguous()) + + corner_points = corner_points.view(batch_size, num_rois, -1, corner_points.shape[-1]) + corner_points = corner_points.view(batch_size * num_rois, -1) + trajectory_roi_center = trajectory_rois[:,i,:,:].contiguous().reshape(batch_size * num_rois, -1)[:,:3] + corner_add_center_points = torch.cat([corner_points, trajectory_roi_center], dim = -1) + proposal_aware_feat = src[:,i*self.num_lidar_points:(i+1)*self.num_lidar_points,:3].repeat(1,1,9) - \ + corner_add_center_points.unsqueeze(1).repeat(1,self.num_lidar_points,1) + + lwh = trajectory_rois[:,i,:,:].reshape(batch_size * num_rois, -1)[:,3:6].unsqueeze(1).repeat(1,proposal_aware_feat.shape[1],1) + diag_dist = (lwh[:,:,0]**2 + lwh[:,:,1]**2 + lwh[:,:,2]**2) ** 0.5 + proposal_aware_feat = self.spherical_coordinate(proposal_aware_feat, diag_dist = diag_dist.unsqueeze(-1)) + proposal_aware_feat_list.append(proposal_aware_feat) + + proposal_aware_feat = torch.cat(proposal_aware_feat_list,dim=1) + proposal_aware_feat = torch.cat([proposal_aware_feat, src[:,:,3:]], dim = -1) + src_gemoetry = self.up_dimension_geometry(proposal_aware_feat) + proxy_point_geometry, proxy_points = self.roi_grid_pool(batch_size,trajectory_rois,src,src_gemoetry,batch_dict,batch_cnt=None) + return proxy_point_geometry,proxy_points + + + + def get_proposal_aware_motion_feature(self,proxy_point,batch_size,trajectory_rois,num_rois,batch_dict): + + + time_stamp = torch.ones([proxy_point.shape[0],proxy_point.shape[1],1]).cuda() + padding_zero = torch.zeros([proxy_point.shape[0],proxy_point.shape[1],2]).cuda() + proxy_point_time_padding = torch.cat([padding_zero,time_stamp],-1) + + num_frames = trajectory_rois.shape[1] + + for i in range(num_frames): + proxy_point_time_padding[:,i*self.num_proxy_points:(i+1)*self.num_proxy_points,-1] = i*0.1 + + + corner_points, _ = self.get_corner_points_of_roi(trajectory_rois[:,0,:,:].contiguous()) + corner_points = corner_points.view(batch_size, num_rois, -1, corner_points.shape[-1]) + corner_points = corner_points.view(batch_size * num_rois, -1) + trajectory_roi_center = trajectory_rois[:,0,:,:].reshape(batch_size * num_rois, -1)[:,:3] + corner_add_center_points = torch.cat([corner_points, trajectory_roi_center], dim = -1) + + proposal_aware_feat = proxy_point[:,:,:3].repeat(1,1,9) - corner_add_center_points.unsqueeze(1) + + lwh = trajectory_rois[:,0,:,:].reshape(batch_size * num_rois, -1)[:,3:6].unsqueeze(1).repeat(1,proxy_point.shape[1],1) + diag_dist = (lwh[:,:,0]**2 + lwh[:,:,1]**2 + lwh[:,:,2]**2) ** 0.5 + proposal_aware_feat = self.spherical_coordinate(proposal_aware_feat, diag_dist = diag_dist.unsqueeze(-1)) + + + proposal_aware_feat = torch.cat([proposal_aware_feat,proxy_point_time_padding],-1) + proxy_point_motion_feat = self.up_dimension_motion(proposal_aware_feat) + + return proxy_point_motion_feat + + def trajectories_auxiliary_branch(self,trajectory_rois): + + time_stamp = torch.ones([trajectory_rois.shape[0],trajectory_rois.shape[1],trajectory_rois.shape[2],1]).cuda() + for i in range(time_stamp.shape[1]): + time_stamp[:,i,:] = i*0.1 + + box_seq = 
torch.cat([trajectory_rois[:,:,:,:7],time_stamp],-1) + + box_seq[:, :, :,0:3] = box_seq[:, :, :,0:3] - box_seq[:, 0:1, :, 0:3] + + roi_ry = box_seq[:,:,:,6] % (2 * np.pi) + roi_ry_t0 = roi_ry[:,0] + roi_ry_t0 = roi_ry_t0.repeat(1,box_seq.shape[1]) + + + box_seq = common_utils.rotate_points_along_z( + points=box_seq.view(-1, 1, box_seq.shape[-1]), angle=-roi_ry_t0.view(-1) + ).view(box_seq.shape[0],box_seq.shape[1], -1, box_seq.shape[-1]) + + box_seq[:, :, :, 6] = 0 + + batch_rcnn = box_seq.shape[0]*box_seq.shape[2] + + box_reg, box_feat, _ = self.seqboxembed(box_seq.permute(0,2,3,1).contiguous().view(batch_rcnn,box_seq.shape[-1],box_seq.shape[1])) + + return box_reg, box_feat + + def generate_trajectory(self,cur_batch_boxes,proposals_list,batch_dict): + + trajectory_rois = cur_batch_boxes[:,None,:,:].repeat(1,batch_dict['rois'].shape[-2],1,1) + trajectory_rois[:,0,:,:]= cur_batch_boxes + valid_length = torch.zeros([batch_dict['batch_size'],batch_dict['rois'].shape[-2],trajectory_rois.shape[2]]) + valid_length[:,0] = 1 + num_frames = batch_dict['rois'].shape[-2] + for i in range(1,num_frames): + frame = torch.zeros_like(cur_batch_boxes) + frame[:,:,0:2] = trajectory_rois[:,i-1,:,0:2] + trajectory_rois[:,i-1,:,7:9] + frame[:,:,2:] = trajectory_rois[:,i-1,:,2:] + + for bs_idx in range( batch_dict['batch_size']): + iou3d = iou3d_nms_utils.boxes_iou3d_gpu(frame[bs_idx,:,:7], proposals_list[bs_idx,i,:,:7]) + max_overlaps, traj_assignment = torch.max(iou3d, dim=1) + + fg_inds = ((max_overlaps >= 0.5)).nonzero().view(-1) + + valid_length[bs_idx,i,fg_inds] = 1 + + trajectory_rois[bs_idx,i,fg_inds,:] = proposals_list[bs_idx,i,traj_assignment[fg_inds]] + + batch_dict['valid_length'] = valid_length + + return trajectory_rois,valid_length + + def forward(self, batch_dict): + """ + :param input_data: input dict + :return: + """ + + batch_dict['rois'] = batch_dict['proposals_list'].permute(0,2,1,3) + num_rois = batch_dict['rois'].shape[1] + batch_dict['num_frames'] = batch_dict['rois'].shape[2] + batch_dict['roi_scores'] = batch_dict['roi_scores'].permute(0,2,1) + batch_dict['roi_labels'] = batch_dict['roi_labels'][:,0,:].long() + proposals_list = batch_dict['proposals_list'] + batch_size = batch_dict['batch_size'] + cur_batch_boxes = copy.deepcopy(batch_dict['rois'].detach())[:,:,0] + batch_dict['cur_frame_idx'] = 0 + + trajectory_rois,valid_length = self.generate_trajectory(cur_batch_boxes,proposals_list,batch_dict) + + batch_dict['traj_memory'] = trajectory_rois + batch_dict['has_class_labels'] = True + batch_dict['trajectory_rois'] = trajectory_rois + + if self.training: + targets_dict = self.assign_targets(batch_dict) + batch_dict['rois'] = targets_dict['rois'] + batch_dict['roi_scores'] = targets_dict['roi_scores'] + batch_dict['roi_labels'] = targets_dict['roi_labels'] + targets_dict['trajectory_rois'][:,batch_dict['cur_frame_idx'],:,:] = batch_dict['rois'] + trajectory_rois = targets_dict['trajectory_rois'] + valid_length = targets_dict['valid_length'] + empty_mask = batch_dict['rois'][:,:,:6].sum(-1)==0 + + else: + empty_mask = batch_dict['rois'][:,:,0,:6].sum(-1)==0 + batch_dict['valid_traj_mask'] = ~empty_mask + + rois = batch_dict['rois'] + num_rois = batch_dict['rois'].shape[1] + num_sample = self.num_lidar_points + src = rois.new_zeros(batch_size, num_rois, num_sample, 5) + + src = self.crop_current_frame_points(src, batch_size, trajectory_rois, num_rois,batch_dict) + + src = self.crop_previous_frame_points(src, batch_size,trajectory_rois, num_rois,valid_length,batch_dict) + + src = 
src.view(batch_size * num_rois, -1, src.shape[-1]) + + src_geometry_feature,proxy_points = self.get_proposal_aware_geometry_feature(src,batch_size,trajectory_rois,num_rois,batch_dict) + + src_motion_feature = self.get_proposal_aware_motion_feature(proxy_points,batch_size,trajectory_rois,num_rois,batch_dict) + + src = src_geometry_feature + src_motion_feature + + box_reg, feat_box = self.trajectories_auxiliary_branch(trajectory_rois) + + if self.model_cfg.get('USE_TRAJ_EMPTY_MASK',None): + src[empty_mask.view(-1)] = 0 + + if self.model_cfg.Transformer.use_grid_pos.init_type == 'index': + pos = self.grid_pos_embeded(self.grid_index.cuda())[None,:,:] + pos = torch.cat([torch.zeros(1,1,self.hidden_dim).cuda(),pos],1) + else: + pos=None + + hs, tokens = self.transformer(src,pos=pos) + point_cls_list = [] + point_reg_list = [] + + for i in range(3): + point_cls_list.append(self.class_embed[0](tokens[i][0])) + + for i in range(hs.shape[0]): + for j in range(3): + point_reg_list.append(self.bbox_embed[i](tokens[j][i])) + + point_cls = torch.cat(point_cls_list,0) + + point_reg = torch.cat(point_reg_list,0) + hs = hs.permute(1,0,2).reshape(hs.shape[1],-1) + + joint_reg = self.jointembed(torch.cat([hs,feat_box],-1)) + + rcnn_cls = point_cls + rcnn_reg = joint_reg + + if not self.training: + batch_dict['rois'] = batch_dict['rois'][:,:,0].contiguous() + rcnn_cls = rcnn_cls[-rcnn_cls.shape[0]//self.num_enc_layer:] + batch_cls_preds, batch_box_preds = self.generate_predicted_boxes( + batch_size=batch_dict['batch_size'], rois=batch_dict['rois'], cls_preds=rcnn_cls, box_preds=rcnn_reg + ) + + batch_dict['batch_box_preds'] = batch_box_preds + + batch_dict['cls_preds_normalized'] = False + if self.avg_stage1_score: + stage1_score = batch_dict['roi_scores'][:,:,:1] + batch_cls_preds = F.sigmoid(batch_cls_preds) + if self.model_cfg.get('IOU_WEIGHT', None): + batch_box_preds_list = [] + roi_labels_list = [] + batch_cls_preds_list = [] + for bs_idx in range(batch_size): + car_mask = batch_dict['roi_labels'][bs_idx] ==1 + batch_cls_preds_car = batch_cls_preds[bs_idx].pow(self.model_cfg.IOU_WEIGHT[0])* \ + stage1_score[bs_idx].pow(1-self.model_cfg.IOU_WEIGHT[0]) + batch_cls_preds_car = batch_cls_preds_car[car_mask][None] + batch_cls_preds_pedcyc = batch_cls_preds[bs_idx].pow(self.model_cfg.IOU_WEIGHT[1])* \ + stage1_score[bs_idx].pow(1-self.model_cfg.IOU_WEIGHT[1]) + batch_cls_preds_pedcyc = batch_cls_preds_pedcyc[~car_mask][None] + cls_preds = torch.cat([batch_cls_preds_car,batch_cls_preds_pedcyc],1) + box_preds = torch.cat([batch_dict['batch_box_preds'][bs_idx][car_mask], + batch_dict['batch_box_preds'][bs_idx][~car_mask]],0)[None] + roi_labels = torch.cat([batch_dict['roi_labels'][bs_idx][car_mask], + batch_dict['roi_labels'][bs_idx][~car_mask]],0)[None] + batch_box_preds_list.append(box_preds) + roi_labels_list.append(roi_labels) + batch_cls_preds_list.append(cls_preds) + batch_dict['batch_box_preds'] = torch.cat(batch_box_preds_list,0) + batch_dict['roi_labels'] = torch.cat(roi_labels_list,0) + batch_cls_preds = torch.cat(batch_cls_preds_list,0) + + else: + batch_cls_preds = torch.sqrt(batch_cls_preds*stage1_score) + batch_dict['cls_preds_normalized'] = True + + batch_dict['batch_cls_preds'] = batch_cls_preds + + + else: + targets_dict['batch_size'] = batch_size + targets_dict['rcnn_cls'] = rcnn_cls + targets_dict['rcnn_reg'] = rcnn_reg + targets_dict['box_reg'] = box_reg + targets_dict['point_reg'] = point_reg + targets_dict['point_cls'] = point_cls + self.forward_ret_dict = targets_dict + + return 
batch_dict + + def get_loss(self, tb_dict=None): + tb_dict = {} if tb_dict is None else tb_dict + rcnn_loss = 0 + rcnn_loss_cls, cls_tb_dict = self.get_box_cls_layer_loss(self.forward_ret_dict) + rcnn_loss += rcnn_loss_cls + tb_dict.update(cls_tb_dict) + + rcnn_loss_reg, reg_tb_dict = self.get_box_reg_layer_loss(self.forward_ret_dict) + rcnn_loss += rcnn_loss_reg + tb_dict.update(reg_tb_dict) + tb_dict['rcnn_loss'] = rcnn_loss.item() + return rcnn_loss, tb_dict + + def get_box_reg_layer_loss(self, forward_ret_dict): + loss_cfgs = self.model_cfg.LOSS_CONFIG + code_size = self.box_coder.code_size + reg_valid_mask = forward_ret_dict['reg_valid_mask'].view(-1) + batch_size = forward_ret_dict['batch_size'] + + gt_boxes3d_ct = forward_ret_dict['gt_of_rois'][..., 0:code_size] + gt_of_rois_src = forward_ret_dict['gt_of_rois_src'][..., 0:code_size].view(-1, code_size) + + rcnn_reg = forward_ret_dict['rcnn_reg'] + + roi_boxes3d = forward_ret_dict['rois'] + + rcnn_batch_size = gt_boxes3d_ct.view(-1, code_size).shape[0] + + fg_mask = (reg_valid_mask > 0) + fg_sum = fg_mask.long().sum().item() + + tb_dict = {} + + if loss_cfgs.REG_LOSS == 'smooth-l1': + + rois_anchor = roi_boxes3d.clone().detach()[:,:,:7].contiguous().view(-1, code_size) + rois_anchor[:, 0:3] = 0 + rois_anchor[:, 6] = 0 + reg_targets = self.box_coder.encode_torch( + gt_boxes3d_ct.view(rcnn_batch_size, code_size), rois_anchor + ) + rcnn_loss_reg = self.reg_loss_func( + rcnn_reg.view(rcnn_batch_size, -1).unsqueeze(dim=0), + reg_targets.unsqueeze(dim=0), + ) # [B, M, 7] + rcnn_loss_reg = (rcnn_loss_reg.view(rcnn_batch_size, -1) * fg_mask.unsqueeze(dim=-1).float()).sum() / max(fg_sum, 1) + rcnn_loss_reg = rcnn_loss_reg * loss_cfgs.LOSS_WEIGHTS['rcnn_reg_weight']*loss_cfgs.LOSS_WEIGHTS['traj_reg_weight'][0] + + tb_dict['rcnn_loss_reg'] = rcnn_loss_reg.item() + + if self.model_cfg.USE_AUX_LOSS: + point_reg = forward_ret_dict['point_reg'] + + groups = point_reg.shape[0]//reg_targets.shape[0] + if groups != 1 : + point_loss_regs = 0 + slice = reg_targets.shape[0] + for i in range(groups): + point_loss_reg = self.reg_loss_func( + point_reg[i*slice:(i+1)*slice].view(slice, -1).unsqueeze(dim=0),reg_targets.unsqueeze(dim=0),) + point_loss_reg = (point_loss_reg.view(slice, -1) * fg_mask.unsqueeze(dim=-1).float()).sum() / max(fg_sum, 1) + point_loss_reg = point_loss_reg * loss_cfgs.LOSS_WEIGHTS['rcnn_reg_weight']*loss_cfgs.LOSS_WEIGHTS['traj_reg_weight'][2] + + point_loss_regs += point_loss_reg + point_loss_regs = point_loss_regs / groups + tb_dict['point_loss_reg'] = point_loss_regs.item() + rcnn_loss_reg += point_loss_regs + + else: + point_loss_reg = self.reg_loss_func(point_reg.view(rcnn_batch_size, -1).unsqueeze(dim=0),reg_targets.unsqueeze(dim=0),) + point_loss_reg = (point_loss_reg.view(rcnn_batch_size, -1) * fg_mask.unsqueeze(dim=-1).float()).sum() / max(fg_sum, 1) + point_loss_reg = point_loss_reg * loss_cfgs.LOSS_WEIGHTS['rcnn_reg_weight']*loss_cfgs.LOSS_WEIGHTS['traj_reg_weight'][2] + tb_dict['point_loss_reg'] = point_loss_reg.item() + rcnn_loss_reg += point_loss_reg + + seqbox_reg = forward_ret_dict['box_reg'] + seqbox_loss_reg = self.reg_loss_func(seqbox_reg.view(rcnn_batch_size, -1).unsqueeze(dim=0),reg_targets.unsqueeze(dim=0),) + seqbox_loss_reg = (seqbox_loss_reg.view(rcnn_batch_size, -1) * fg_mask.unsqueeze(dim=-1).float()).sum() / max(fg_sum, 1) + seqbox_loss_reg = seqbox_loss_reg * loss_cfgs.LOSS_WEIGHTS['rcnn_reg_weight']*loss_cfgs.LOSS_WEIGHTS['traj_reg_weight'][1] + tb_dict['seqbox_loss_reg'] = seqbox_loss_reg.item() + 
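+            # all three smooth-l1 terms (joint head, PointNet seqbox, per-group point
+            # heads) share the same reg_targets; LOSS_WEIGHTS['traj_reg_weight']
+            # indexes them as [joint, seqbox, point], e.g. [2.0, 2.0, 2.0] doubles
+            # each term before it is folded into rcnn_loss_reg below.
+            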
rcnn_loss_reg += seqbox_loss_reg + + if loss_cfgs.CORNER_LOSS_REGULARIZATION and fg_sum > 0: + + fg_rcnn_reg = rcnn_reg.view(rcnn_batch_size, -1)[fg_mask] + fg_roi_boxes3d = roi_boxes3d[:,:,:7].contiguous().view(-1, code_size)[fg_mask] + + fg_roi_boxes3d = fg_roi_boxes3d.view(1, -1, code_size) + batch_anchors = fg_roi_boxes3d.clone().detach() + roi_ry = fg_roi_boxes3d[:, :, 6].view(-1) + roi_xyz = fg_roi_boxes3d[:, :, 0:3].view(-1, 3) + batch_anchors[:, :, 0:3] = 0 + rcnn_boxes3d = self.box_coder.decode_torch( + fg_rcnn_reg.view(batch_anchors.shape[0], -1, code_size), batch_anchors + ).view(-1, code_size) + + rcnn_boxes3d = common_utils.rotate_points_along_z( + rcnn_boxes3d.unsqueeze(dim=1), roi_ry + ).squeeze(dim=1) + rcnn_boxes3d[:, 0:3] += roi_xyz + + corner_loss_func = loss_utils.get_corner_loss_lidar + + loss_corner = corner_loss_func( + rcnn_boxes3d[:, 0:7], + gt_of_rois_src[fg_mask][:, 0:7]) + + loss_corner = loss_corner.mean() + loss_corner = loss_corner * loss_cfgs.LOSS_WEIGHTS['rcnn_corner_weight'] + + rcnn_loss_reg += loss_corner + tb_dict['rcnn_loss_corner'] = loss_corner.item() + + else: + raise NotImplementedError + + return rcnn_loss_reg, tb_dict + + def get_box_cls_layer_loss(self, forward_ret_dict): + loss_cfgs = self.model_cfg.LOSS_CONFIG + rcnn_cls = forward_ret_dict['rcnn_cls'] + rcnn_cls_labels = forward_ret_dict['rcnn_cls_labels'].view(-1) + + if loss_cfgs.CLS_LOSS == 'BinaryCrossEntropy': + + rcnn_cls_flat = rcnn_cls.view(-1) + + groups = rcnn_cls_flat.shape[0] // rcnn_cls_labels.shape[0] + if groups != 1: + rcnn_loss_cls = 0 + slice = rcnn_cls_labels.shape[0] + for i in range(groups): + batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rcnn_cls_flat[i*slice:(i+1)*slice]), + rcnn_cls_labels.float(), reduction='none') + + cls_valid_mask = (rcnn_cls_labels >= 0).float() + rcnn_loss_cls = rcnn_loss_cls + (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(cls_valid_mask.sum(), min=1.0) + + rcnn_loss_cls = rcnn_loss_cls / groups + + else: + + batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rcnn_cls_flat), rcnn_cls_labels.float(), reduction='none') + cls_valid_mask = (rcnn_cls_labels >= 0).float() + rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(cls_valid_mask.sum(), min=1.0) + + + elif loss_cfgs.CLS_LOSS == 'CrossEntropy': + batch_loss_cls = F.cross_entropy(rcnn_cls, rcnn_cls_labels, reduction='none', ignore_index=-1) + cls_valid_mask = (rcnn_cls_labels >= 0).float() + rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(cls_valid_mask.sum(), min=1.0) + + else: + raise NotImplementedError + + rcnn_loss_cls = rcnn_loss_cls * loss_cfgs.LOSS_WEIGHTS['rcnn_cls_weight'] + + tb_dict = {'rcnn_loss_cls': rcnn_loss_cls.item()} + return rcnn_loss_cls, tb_dict + + + def generate_predicted_boxes(self, batch_size, rois, cls_preds=None, box_preds=None): + """ + Args: + batch_size: + rois: (B, N, 7) + cls_preds: (BN, num_class) + box_preds: (BN, code_size) + Returns: + """ + code_size = self.box_coder.code_size + if cls_preds is not None: + batch_cls_preds = cls_preds.view(batch_size, -1, cls_preds.shape[-1]) + else: + batch_cls_preds = None + batch_box_preds = box_preds.view(batch_size, -1, code_size) + + roi_ry = rois[:, :, 6].view(-1) + roi_xyz = rois[:, :, 0:3].view(-1, 3) + local_rois = rois.clone().detach() + local_rois[:, :, 0:3] = 0 + + batch_box_preds = self.box_coder.decode_torch(batch_box_preds, local_rois).view(-1, code_size) + + batch_box_preds = common_utils.rotate_points_along_z( + batch_box_preds.unsqueeze(dim=1), 
roi_ry + ).squeeze(dim=1) + + batch_box_preds[:, 0:3] += roi_xyz + batch_box_preds = batch_box_preds.view(batch_size, -1, code_size) + batch_box_preds = torch.cat([batch_box_preds,rois[:,:,7:]],-1) + return batch_cls_preds, batch_box_preds diff --git a/pcdet/models/roi_heads/mppnet_memory_bank_e2e.py b/pcdet/models/roi_heads/mppnet_memory_bank_e2e.py new file mode 100644 index 000000000..095f1e29a --- /dev/null +++ b/pcdet/models/roi_heads/mppnet_memory_bank_e2e.py @@ -0,0 +1,594 @@ +from typing import ValuesView +import torch.nn as nn +import torch +import numpy as np +import copy +import torch.nn.functional as F +from pcdet.ops.iou3d_nms import iou3d_nms_utils +from ...utils import common_utils, loss_utils +from .roi_head_template import RoIHeadTemplate +from ..model_utils.mppnet_utils import build_transformer, PointNet, MLP +from .target_assigner.proposal_target_layer import ProposalTargetLayer +from pcdet.ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_stack_modules + + +class MPPNetHeadE2E(RoIHeadTemplate): + def __init__(self,model_cfg, num_class=1,**kwargs): + super().__init__(num_class=num_class, model_cfg=model_cfg) + self.model_cfg = model_cfg + self.use_time_stamp = self.model_cfg.get('USE_TIMESTAMP',None) + self.num_lidar_points = self.model_cfg.Transformer.num_lidar_points + self.avg_stage1_score = self.model_cfg.get('AVG_STAGE1_SCORE', None) + + self.nhead = model_cfg.Transformer.nheads + self.num_enc_layer = model_cfg.Transformer.enc_layers + hidden_dim = model_cfg.TRANS_INPUT + self.hidden_dim = model_cfg.TRANS_INPUT + self.num_groups = model_cfg.Transformer.num_groups + + self.grid_size = model_cfg.ROI_GRID_POOL.GRID_SIZE + self.num_proxy_points = model_cfg.Transformer.num_proxy_points + + self.seqboxembed = PointNet(8,model_cfg=self.model_cfg) + self.jointembed = MLP(self.hidden_dim*(self.num_groups+1), model_cfg.Transformer.hidden_dim, self.box_coder.code_size * self.num_class, 4) + + num_radius = len(self.model_cfg.ROI_GRID_POOL.POOL_RADIUS) + self.up_dimension_geometry = MLP(input_dim = 29, hidden_dim = 64, output_dim =hidden_dim//num_radius, num_layers = 3) + self.up_dimension_motion = MLP(input_dim = 30, hidden_dim = 64, output_dim = hidden_dim, num_layers = 3) + + self.transformer = build_transformer(model_cfg.Transformer) + + self.roi_grid_pool_layer = pointnet2_stack_modules.StackSAModuleMSG( + radii=self.model_cfg.ROI_GRID_POOL.POOL_RADIUS, + nsamples=self.model_cfg.ROI_GRID_POOL.NSAMPLE, + mlps=self.model_cfg.ROI_GRID_POOL.MLPS, + use_xyz=True, + pool_method=self.model_cfg.ROI_GRID_POOL.POOL_METHOD, + ) + + self.class_embed = nn.ModuleList() + self.class_embed.append(nn.Linear(model_cfg.Transformer.hidden_dim, 1)) + + self.bbox_embed = nn.ModuleList() + for _ in range(self.num_groups): + self.bbox_embed.append(MLP(model_cfg.Transformer.hidden_dim, model_cfg.Transformer.hidden_dim, self.box_coder.code_size * self.num_class, 4)) + + if self.model_cfg.Transformer.use_grid_pos.enabled: + if self.model_cfg.Transformer.use_grid_pos.init_type == 'index': + self.grid_index = torch.cat([i.reshape(-1,1)for i in torch.meshgrid(torch.arange(self.grid_size), torch.arange(self.grid_size), torch.arange(self.grid_size))],1).float().cuda() + self.grid_pos_embeded = MLP(input_dim = 3, hidden_dim = 256, output_dim = hidden_dim, num_layers = 2) + else: + self.pos = nn.Parameter(torch.zeros(1, self.num_grid_points, 256)) + + def init_weights(self, weight_init='xavier'): + if weight_init == 'kaiming': + init_func = nn.init.kaiming_normal_ + elif weight_init == 
'xavier': + init_func = nn.init.xavier_normal_ + elif weight_init == 'normal': + init_func = nn.init.normal_ + else: + raise NotImplementedError + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): + if weight_init == 'normal': + init_func(m.weight, mean=0, std=0.001) + else: + init_func(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + nn.init.normal_(self.bbox_embed.layers[-1].weight, mean=0, std=0.001) + + def get_corner_points_of_roi(self, rois): + rois = rois.view(-1, rois.shape[-1]) + batch_size_rcnn = rois.shape[0] + + local_roi_grid_points = self.get_corner_points(rois, batch_size_rcnn) + local_roi_grid_points = common_utils.rotate_points_along_z( + local_roi_grid_points.clone(), rois[:, 6] + ).squeeze(dim=1) + global_center = rois[:, 0:3].clone() + + + global_roi_grid_points = local_roi_grid_points + global_center.unsqueeze(dim=1) + return global_roi_grid_points, local_roi_grid_points + + @staticmethod + def get_dense_grid_points(rois, batch_size_rcnn, grid_size): + if isinstance(grid_size,list): + faked_features = rois.new_ones((grid_size[0], grid_size[1], grid_size[2])) + grid_size = torch.tensor(grid_size).float().cuda() + else: + faked_features = rois.new_ones((grid_size, grid_size, grid_size)) + dense_idx = faked_features.nonzero() + dense_idx = dense_idx.repeat(batch_size_rcnn, 1, 1).float() + + local_roi_size = rois.view(batch_size_rcnn, -1)[:, 3:6] + roi_grid_points = torch.div((dense_idx + 0.5), grid_size) * local_roi_size.unsqueeze(dim=1) - (local_roi_size.unsqueeze(dim=1) / 2) + return roi_grid_points + + @staticmethod + def get_corner_points(rois, batch_size_rcnn): + faked_features = rois.new_ones((2, 2, 2)) + + dense_idx = faked_features.nonzero() + dense_idx = dense_idx.repeat(batch_size_rcnn, 1, 1).float() + + local_roi_size = rois.view(batch_size_rcnn, -1)[:, 3:6] + roi_grid_points = dense_idx * local_roi_size.unsqueeze(dim=1) \ + - (local_roi_size.unsqueeze(dim=1) / 2) + return roi_grid_points + + def get_proxy_points_of_roi(self, rois, grid_size): + rois = rois.view(-1, rois.shape[-1]) + batch_size_rcnn = rois.shape[0] + + local_roi_grid_points = self.get_dense_grid_points(rois, batch_size_rcnn, grid_size) + local_roi_grid_points = common_utils.rotate_points_along_z(local_roi_grid_points.clone(), rois[:, 6]).squeeze(dim=1) + global_center = rois[:, 0:3].clone() + global_roi_grid_points = local_roi_grid_points + global_center.unsqueeze(dim=1) + return global_roi_grid_points, local_roi_grid_points + + def roi_grid_pool(self, batch_size, rois, point_coords, point_features,batch_dict=None,batch_cnt=None): + """ + Args: + batch_dict: + batch_size: + rois: (B, num_rois, 7 + C) + point_coords: (num_points, 4) [bs_idx, x, y, z] + point_features: (num_points, C) + point_cls_scores: (N1 + N2 + N3 + ..., 1) + point_part_offset: (N1 + N2 + N3 + ..., 3) + Returns: + """ + + num_frames = batch_dict['num_frames'] + num_rois = rois.shape[2]*rois.shape[1] + + global_roi_proxy_points, local_roi_proxy_points = self.get_proxy_points_of_roi( + rois.permute(0,2,1,3).contiguous(), grid_size=self.grid_size + ) + + num_points = point_coords.shape[1] + num_proxy_points = self.num_proxy_points + + xyz = point_coords[:, :, 0:3].view(-1,3) + if batch_cnt is None: + xyz_batch_cnt = torch.tensor([num_points]*rois.shape[2]*batch_size).cuda().int() + else: + xyz_batch_cnt = torch.tensor(batch_cnt).cuda().int() + new_xyz = torch.cat([i[0] for i in global_roi_proxy_points.chunk(rois.shape[2],0)],0) + new_xyz_batch_cnt = 
torch.tensor([self.num_proxy_points]*rois.shape[2]*batch_size).cuda().int() + + _, pooled_features = self.roi_grid_pool_layer( + xyz=xyz.contiguous(), + xyz_batch_cnt=xyz_batch_cnt, + new_xyz=new_xyz, + new_xyz_batch_cnt=new_xyz_batch_cnt, + features=point_features.view(-1,point_features.shape[-1]).contiguous(), + ) + + features = pooled_features.view( + point_features.shape[0], self.num_proxy_points, + pooled_features.shape[-1] + ).contiguous() + + return features,global_roi_proxy_points.view(batch_size*rois.shape[2], num_frames*num_proxy_points,3).contiguous() + + def spherical_coordinate(self, src, diag_dist): + + assert (src.shape[-1] == 27) + device = src.device + indices_x = torch.LongTensor([0,3,6,9,12,15,18,21,24]).to(device) # + indices_y = torch.LongTensor([1,4,7,10,13,16,19,22,25]).to(device) # + indices_z = torch.LongTensor([2,5,8,11,14,17,20,23,26]).to(device) + src_x = torch.index_select(src, -1, indices_x) + src_y = torch.index_select(src, -1, indices_y) + src_z = torch.index_select(src, -1, indices_z) + dis = (src_x ** 2 + src_y ** 2 + src_z ** 2) ** 0.5 + phi = torch.atan(src_y / (src_x + 1e-5)) + the = torch.acos(src_z / (dis + 1e-5)) + dis = dis / (diag_dist + 1e-5) + src = torch.cat([dis, phi, the], dim = -1) + return src + + def crop_current_frame_points(self, src, batch_size,trajectory_rois,num_rois,num_sample, batch_dict): + + for bs_idx in range(batch_size): + + cur_batch_boxes = trajectory_rois[bs_idx,0,:,:7].view(-1,7) + cur_radiis = torch.sqrt((cur_batch_boxes[:,3]/2) ** 2 + (cur_batch_boxes[:,4]/2) ** 2) * 1.1 + cur_points = batch_dict['points'][(batch_dict['points'][:, 0] == bs_idx)][:,1:] + time_mask = cur_points[:,-1].abs() < 1e-3 + cur_points = cur_points[time_mask] + dis = torch.norm((cur_points[:,:2].unsqueeze(0) - cur_batch_boxes[:,:2].unsqueeze(1).repeat(1,cur_points.shape[0],1)), dim = 2) + point_mask = (dis <= cur_radiis.unsqueeze(-1)) + + mask = point_mask + sampled_idx = torch.topk(mask.float(),128)[1] + sampled_idx_buffer = sampled_idx[:, 0:1].repeat(1, 128) + roi_idx = torch.arange(num_rois)[:, None].repeat(1, 128) + sampled_mask = mask[roi_idx, sampled_idx] + sampled_idx_buffer[sampled_mask] = sampled_idx[sampled_mask] + + src[bs_idx] = cur_points[sampled_idx_buffer][:,:,:5] + empty_flag = sampled_mask.sum(-1)==0 + src[bs_idx,empty_flag] = 0 + + return src + + def trajectories_auxiliary_branch(self,trajectory_rois): + + time_stamp = torch.ones([trajectory_rois.shape[0],trajectory_rois.shape[1],trajectory_rois.shape[2],1]).cuda() + for i in range(time_stamp.shape[1]): + time_stamp[:,i,:] = i*0.1 + + box_seq = torch.cat([trajectory_rois[:,:,:,:7],time_stamp],-1) + # box_seq_time = box_seq + + if self.model_cfg.USE_BOX_ENCODING.NORM_T0: + # canonical transformation + box_seq[:, :, :,0:3] = box_seq[:, :, :,0:3] - box_seq[:, 0:1, :, 0:3] + + + roi_ry = box_seq[:,:,:,6] % (2 * np.pi) + roi_ry_t0 = roi_ry[:,0] + roi_ry_t0 = roi_ry_t0.repeat(1,box_seq.shape[1]) + + # transfer LiDAR coords to local coords + box_seq = common_utils.rotate_points_along_z( + points=box_seq.view(-1, 1, box_seq.shape[-1]), angle=-roi_ry_t0.view(-1) + ).view(box_seq.shape[0],box_seq.shape[1], -1, box_seq.shape[-1]) + + if self.model_cfg.USE_BOX_ENCODING.ALL_YAW_T0: + box_seq[:, :, :, 6] = 0 + + else: + box_seq[:, 0:1, :, 6] = 0 + box_seq[:, 1:, :, 6] = roi_ry[:, 1:, ] - roi_ry[:,0:1] + + + batch_rcnn = box_seq.shape[0]*box_seq.shape[2] + + + box_reg, box_feat, _ = self.seqboxembed(box_seq.permute(0,2,3,1).contiguous().view(batch_rcnn,box_seq.shape[-1],box_seq.shape[1])) + + 
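+        # the (x, y, z, dx, dy, dz, yaw, t) sequence is reshaped to
+        # (batch_rcnn, 8, num_frames) so the PointNet reads time steps as points:
+        # box_feat is the per-trajectory embedding later consumed by jointembed,
+        # box_reg an auxiliary box regression from the same features.
+        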
return box_reg, box_feat + + def get_proposal_aware_motion_feature(self,proxy_point,batch_size,trajectory_rois,num_rois,batch_dict): + + time_stamp = torch.ones([proxy_point.shape[0],proxy_point.shape[1],1]).cuda() + padding_zero = torch.zeros([proxy_point.shape[0],proxy_point.shape[1],2]).cuda() + proxy_point_padding = torch.cat([padding_zero,time_stamp],-1) + + num_time_coding = trajectory_rois.shape[1] + + for i in range(num_time_coding): + proxy_point_padding[:,i*self.num_proxy_points:(i+1)*self.num_proxy_points,-1] = i*0.1 + + + ######### use T0 Norm ######## + corner_points, _ = self.get_corner_points_of_roi(trajectory_rois[:,0,:,:].contiguous()) + corner_points = corner_points.view(batch_size, num_rois, -1, corner_points.shape[-1]) + corner_points = corner_points.view(batch_size * num_rois, -1) + corner_add_center_points = torch.cat([corner_points, trajectory_rois[:,0,:,:].reshape(batch_size * num_rois, -1)[:,:3]], dim = -1) + + pos_fea = proxy_point[:,:,:3].repeat(1,1,9) - corner_add_center_points.unsqueeze(1) + + lwh = trajectory_rois[:,0,:,:].reshape(batch_size * num_rois, -1)[:,3:6].unsqueeze(1).repeat(1,proxy_point.shape[1],1) + diag_dist = (lwh[:,:,0]**2 + lwh[:,:,1]**2 + lwh[:,:,2]**2) ** 0.5 + pos_fea = self.spherical_coordinate(pos_fea, diag_dist = diag_dist.unsqueeze(-1)) + ######### use T0 Norm ######## + + proxy_point_padding = torch.cat([pos_fea,proxy_point_padding],-1) + proxy_point_motion_feat = self.up_dimension_motion(proxy_point_padding) + + return proxy_point_motion_feat + + def get_proposal_aware_geometry_feature(self,src, batch_size,trajectory_rois,num_rois,batch_dict): + + i = 0 # only current frame + corner_points, _ = self.get_corner_points_of_roi(trajectory_rois[:,i,:,:].contiguous()) + + corner_points = corner_points.view(batch_size, num_rois, -1, corner_points.shape[-1]) + corner_points = corner_points.view(batch_size * num_rois, -1) + trajectory_roi_center = trajectory_rois[:,i,:,:].contiguous().reshape(batch_size * num_rois, -1)[:,:3] + corner_add_center_points = torch.cat([corner_points, trajectory_roi_center], dim = -1) + proposal_aware_feat = src[:,i*self.num_lidar_points:(i+1)*self.num_lidar_points,:3].repeat(1,1,9) - \ + corner_add_center_points.unsqueeze(1).repeat(1,self.num_lidar_points,1) + + lwh = trajectory_rois[:,i,:,:].reshape(batch_size * num_rois, -1)[:,3:6].unsqueeze(1).repeat(1,proposal_aware_feat.shape[1],1) + diag_dist = (lwh[:,:,0]**2 + lwh[:,:,1]**2 + lwh[:,:,2]**2) ** 0.5 + proposal_aware_feat = self.spherical_coordinate(proposal_aware_feat, diag_dist = diag_dist.unsqueeze(-1)) + + proposal_aware_feat = torch.cat([proposal_aware_feat, src[:,:,3:]], dim = -1) + src_gemoetry = self.up_dimension_geometry(proposal_aware_feat) + proxy_point_geometry, proxy_points = self.roi_grid_pool(batch_size,trajectory_rois,src,src_gemoetry,batch_dict,batch_cnt=None) + return proxy_point_geometry,proxy_points + + @staticmethod + def reorder_rois_for_refining(pred_bboxes): + + num_max_rois = max([len(bbox) for bbox in pred_bboxes]) + num_max_rois = max(1, num_max_rois) # at least one faked rois to avoid error + ordered_bboxes = torch.zeros([len(pred_bboxes),num_max_rois,pred_bboxes[0].shape[-1]]).cuda() + + for bs_idx in range(ordered_bboxes.shape[0]): + ordered_bboxes[bs_idx,:len(pred_bboxes[bs_idx])] = pred_bboxes[bs_idx] + return ordered_bboxes + + def transform_prebox_to_current_vel(self,pred_boxes3d,pose_pre,pose_cur): + + expand_bboxes = np.concatenate([pred_boxes3d[:,:3], np.ones((pred_boxes3d.shape[0], 1))], axis=-1) + expand_vels = 
np.concatenate([pred_boxes3d[:,7:9], np.zeros((pred_boxes3d.shape[0], 1))], axis=-1) + bboxes_global = np.dot(expand_bboxes, pose_pre.T)[:, :3] + vels_global = np.dot(expand_vels, pose_pre[:3,:3].T) + moved_bboxes_global = copy.deepcopy(bboxes_global) + moved_bboxes_global[:,:2] = moved_bboxes_global[:,:2] - 0.1*vels_global[:,:2] + + expand_bboxes_global = np.concatenate([bboxes_global[:,:3],np.ones((bboxes_global.shape[0], 1))], axis=-1) + expand_moved_bboxes_global = np.concatenate([moved_bboxes_global[:,:3],np.ones((bboxes_global.shape[0], 1))], axis=-1) + bboxes_pre2cur = np.dot(expand_bboxes_global, np.linalg.inv(pose_cur.T))[:, :3] + + moved_bboxes_pre2cur = np.dot(expand_moved_bboxes_global, np.linalg.inv(pose_cur.T))[:, :3] + vels_pre2cur = np.dot(vels_global, np.linalg.inv(pose_cur[:3,:3].T))[:,:2] + bboxes_pre2cur = np.concatenate([bboxes_pre2cur, pred_boxes3d[:,3:7],vels_pre2cur],axis=-1) + bboxes_pre2cur[:,6] = bboxes_pre2cur[..., 6] + np.arctan2(pose_pre[1, 0], pose_pre[0,0]) + bboxes_pre2cur[:,6] = bboxes_pre2cur[..., 6] - np.arctan2(pose_cur[1, 0], pose_cur[0,0]) + bboxes_pre2cur[:,7:9] = moved_bboxes_pre2cur[:,:2] - bboxes_pre2cur[:,:2] + return bboxes_pre2cur[None,:,:] + + def generate_trajectory(self,cur_batch_boxes,proposals_list,batch_dict): + + trajectory_rois = cur_batch_boxes[:,None,:,:].repeat(1,batch_dict['rois'].shape[-2],1,1) + trajectory_rois[:,0,:,:]= cur_batch_boxes + valid_length = torch.zeros([batch_dict['batch_size'],batch_dict['rois'].shape[-2],trajectory_rois.shape[2]]) + valid_length[:,0] = 1 + num_frames = batch_dict['rois'].shape[-2] + matching_table = (trajectory_rois.new_ones([trajectory_rois.shape[1],trajectory_rois.shape[2]]) * -1).long() + + for i in range(1,num_frames): + frame = torch.zeros_like(cur_batch_boxes) + frame[:,:,0:2] = trajectory_rois[:,i-1,:,0:2] + trajectory_rois[:,i-1,:,7:9] + frame[:,:,2:] = trajectory_rois[:,i-1,:,2:] + + for bs_idx in range( batch_dict['batch_size']): + iou3d = iou3d_nms_utils.boxes_iou3d_gpu(frame[bs_idx,:,:7], proposals_list[bs_idx,i,:,:7]) + max_overlaps, traj_assignment = torch.max(iou3d, dim=1) + + fg_inds = ((max_overlaps >= 0.5)).nonzero().view(-1) + + valid_length[bs_idx,i,fg_inds] = 1 + matching_table[i,fg_inds] = traj_assignment[fg_inds] + + trajectory_rois[bs_idx,i,fg_inds,:] = proposals_list[bs_idx,i,traj_assignment[fg_inds]] + + batch_dict['valid_length'] = valid_length + + return trajectory_rois,valid_length, matching_table + + def forward(self, batch_dict): + """ + :param input_data: input dict + :return: + """ + + if 'memory_bank' in batch_dict.keys(): + + rois_list = [] + memory_list = copy.deepcopy(batch_dict['memory_bank']) + + for idx in range(len(memory_list['rois'])): + + rois = torch.cat([batch_dict['memory_bank']['rois'][idx][0], + batch_dict['memory_bank']['roi_scores'][idx][0], + batch_dict['memory_bank']['roi_labels'][idx][0]],-1) + + rois_list.append(rois) + + + batch_rois = self.reorder_rois_for_refining(rois_list) + batch_dict['roi_scores'] = batch_rois[None,:,:,9] + batch_dict['roi_labels'] = batch_rois[None,:,:,10] + + proposals_list = [] + + for i in range(self.model_cfg.Transformer.num_frames): + pose_pre = batch_dict['poses'][0,i*4:(i+1)*4,:] + pred2cur = self.transform_prebox_to_current_vel(batch_rois[i,:,:9].cpu().numpy(),pose_pre=pose_pre.cpu().numpy(), + pose_cur=batch_dict['poses'][0,:4,:].cpu().numpy()) + proposals_list.append(torch.from_numpy(pred2cur).cuda().float()) + batch_rois = torch.cat(proposals_list,0) + batch_dict['proposals_list'] = batch_rois[None,:,:,:9] + + 
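+            # each cached frame is mapped into the current ego frame before matching:
+            # transform_prebox_to_current_vel lifts boxes and velocities into the
+            # global frame via pose_pre, keeps a copy shifted by -0.1 * v (one 10 Hz
+            # step) to re-derive the velocities, then projects everything back through
+            # pose_cur^-1 so that the t-k proposals line up with the current sweep.
+            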
batch_dict['rois'] = batch_rois.unsqueeze(0).permute(0,2,1,3) + num_rois = batch_dict['rois'].shape[1] + batch_dict['num_frames'] = batch_dict['rois'].shape[2] + roi_labels_list = copy.deepcopy(batch_dict['roi_labels']) + + batch_dict['roi_scores'] = batch_dict['roi_scores'].permute(0,2,1) + batch_dict['roi_labels'] = batch_dict['roi_labels'][:,0,:].long() + proposals_list = batch_dict['proposals_list'] + batch_size = batch_dict['batch_size'] + cur_batch_boxes = copy.deepcopy(batch_dict['rois'].detach())[:,:,0] + batch_dict['cur_frame_idx'] = 0 + + else: + + batch_dict['rois'] = batch_dict['proposals_list'].permute(0,2,1,3) + assert batch_dict['rois'].shape[0] ==1 + num_rois = batch_dict['rois'].shape[1] + batch_dict['num_frames'] = batch_dict['rois'].shape[2] + roi_labels_list = copy.deepcopy(batch_dict['roi_labels']) + + batch_dict['roi_scores'] = batch_dict['roi_scores'].permute(0,2,1) + batch_dict['roi_labels'] = batch_dict['roi_labels'][:,0,:].long() + proposals_list = batch_dict['proposals_list'] + batch_size = batch_dict['batch_size'] + cur_batch_boxes = copy.deepcopy(batch_dict['rois'].detach())[:,:,0] + batch_dict['cur_frame_idx'] = 0 + + trajectory_rois,effective_length,matching_table = self.generate_trajectory(cur_batch_boxes,proposals_list,batch_dict) + + + batch_dict['has_class_labels'] = True + batch_dict['trajectory_rois'] = trajectory_rois + + + rois = batch_dict['rois'] + num_rois = batch_dict['rois'].shape[1] + + if self.model_cfg.get('USE_TRAJ_EMPTY_MASK',None): + empty_mask = batch_dict['rois'][:,:,0,:6].sum(-1)==0 + batch_dict['valid_traj_mask'] = ~empty_mask + + num_sample = self.num_lidar_points + + src = rois.new_zeros(batch_size, num_rois, num_sample, 5) + + src = self.crop_current_frame_points(src, batch_size, trajectory_rois, num_rois, num_sample, batch_dict) + + src = src.view(batch_size * num_rois, -1, src.shape[-1]) + + src_geometry_feature,proxy_points = self.get_proposal_aware_geometry_feature(src,batch_size,trajectory_rois,num_rois,batch_dict) + + src_motion_feature = self.get_proposal_aware_motion_feature(proxy_points,batch_size,trajectory_rois,num_rois,batch_dict) + + + if batch_dict['sample_idx'][0] >=1: + + src_repeat = src_geometry_feature[:,None,:self.num_proxy_points,:].repeat([1,trajectory_rois.shape[1],1,1]) + src_before = src_repeat[:,1:,:,:].clone() #[bs,traj,num_roi,C] + valid_length = batch_dict['num_frames'] -1 if batch_dict['sample_idx'][0] > batch_dict['num_frames'] -1 \ + else int(batch_dict['sample_idx'][0].item()) + num_max_rois = max(trajectory_rois.shape[2], *[i.shape[0] for i in batch_dict['memory_bank']['feature_bank']]) + feature_bank = self.reorder_memory(batch_dict['memory_bank']['feature_bank'][:valid_length],num_max_rois) + effective_length = effective_length[0,1:1+valid_length].bool() #rm dim of bs + for i in range(valid_length): + src_before[:,i][effective_length[i]] = feature_bank[i,matching_table[1+i][effective_length[i]]] + + src_geometry_feature = torch.cat([src_repeat[:,:1],src_before],1).view(src_geometry_feature.shape[0],-1, + src_geometry_feature.shape[-1]) + + else: + + src_geometry_feature = src_geometry_feature.repeat([1,trajectory_rois.shape[1],1]) + + batch_dict['geometory_feature_memory'] = src_geometry_feature[:,:self.num_proxy_points] + + + src = src_geometry_feature + src_motion_feature + + + if self.model_cfg.get('USE_TRAJ_EMPTY_MASK',None): + src[empty_mask.view(-1)] = 0 + + if self.model_cfg.Transformer.use_grid_pos.init_type == 'index': + pos = self.grid_pos_embeded(self.grid_index.cuda())[None,:,:] + pos = 
torch.cat([torch.zeros(1,1,self.hidden_dim).cuda(),pos],1) + else: + pos=None + + hs, tokens = self.transformer(src,pos=pos) + point_cls_list = [] + + for i in range(3): + point_cls_list.append(self.class_embed[0](tokens[i][0])) + + point_cls = torch.cat(point_cls_list,0) + + + hs = hs.permute(1,0,2).reshape(hs.shape[1],-1) + + _, feat_box = self.trajectories_auxiliary_branch(trajectory_rois) + + joint_reg = self.jointembed(torch.cat([hs,feat_box],-1)) + + rcnn_cls = point_cls + rcnn_reg = joint_reg + + if not self.training: + batch_dict['rois'] = batch_dict['rois'][:,:,0].contiguous() + rcnn_cls = rcnn_cls[-rcnn_cls.shape[0]//self.num_enc_layer:] + batch_cls_preds, batch_box_preds = self.generate_predicted_boxes( + batch_size=batch_dict['batch_size'], rois=batch_dict['rois'], cls_preds=rcnn_cls, box_preds=rcnn_reg + ) + + batch_dict['batch_box_preds'] = batch_box_preds + + batch_dict['cls_preds_normalized'] = False + if self.avg_stage1_score: + stage1_score = batch_dict['roi_scores'][:,:,:1] + batch_cls_preds = F.sigmoid(batch_cls_preds) + if self.model_cfg.get('IOU_WEIGHT', None): + batch_box_preds_list = [] + roi_labels_list = [] + batch_cls_preds_list = [] + for bs_idx in range(batch_size): + car_mask = batch_dict['roi_labels'][bs_idx] ==1 + batch_cls_preds_car = batch_cls_preds[bs_idx].pow(self.model_cfg.IOU_WEIGHT[0])* \ + stage1_score[bs_idx].pow(1-self.model_cfg.IOU_WEIGHT[0]) + batch_cls_preds_car = batch_cls_preds_car[car_mask][None] + batch_cls_preds_pedcyc = batch_cls_preds[bs_idx].pow(self.model_cfg.IOU_WEIGHT[1])* \ + stage1_score[bs_idx].pow(1-self.model_cfg.IOU_WEIGHT[1]) + batch_cls_preds_pedcyc = batch_cls_preds_pedcyc[~car_mask][None] + cls_preds = torch.cat([batch_cls_preds_car,batch_cls_preds_pedcyc],1) + box_preds = torch.cat([batch_dict['batch_box_preds'][bs_idx][car_mask], + batch_dict['batch_box_preds'][bs_idx][~car_mask]],0)[None] + roi_labels = torch.cat([batch_dict['roi_labels'][bs_idx][car_mask], + batch_dict['roi_labels'][bs_idx][~car_mask]],0)[None] + batch_box_preds_list.append(box_preds) + roi_labels_list.append(roi_labels) + batch_cls_preds_list.append(cls_preds) + batch_dict['batch_box_preds'] = torch.cat(batch_box_preds_list,0) + batch_dict['roi_labels'] = torch.cat(roi_labels_list,0) + batch_cls_preds = torch.cat(batch_cls_preds_list,0) + + else: + batch_cls_preds = torch.sqrt(batch_cls_preds*stage1_score) + batch_dict['cls_preds_normalized'] = True + + batch_dict['batch_cls_preds'] = batch_cls_preds + + return batch_dict + + def reorder_memory(self, memory,num_max_rois): + + ordered_memory = memory[0].new_zeros([len(memory),num_max_rois,memory[0].shape[1],memory[0].shape[2]]) + for bs_idx in range(len(memory)): + ordered_memory[bs_idx,:len(memory[bs_idx])] = memory[bs_idx] + return ordered_memory + + def generate_predicted_boxes(self, batch_size, rois, cls_preds=None, box_preds=None): + """ + Args: + batch_size: + rois: (B, N, 7) + cls_preds: (BN, num_class) + box_preds: (BN, code_size) + Returns: + """ + code_size = self.box_coder.code_size + + if cls_preds is not None: + batch_cls_preds = cls_preds.view(batch_size, -1, cls_preds.shape[-1]) + else: + batch_cls_preds = None + batch_box_preds = box_preds.view(batch_size, -1, code_size) + + roi_ry = rois[:, :, 6].view(-1) + roi_xyz = rois[:, :, 0:3].view(-1, 3) + local_rois = rois.clone().detach() + local_rois[:, :, 0:3] = 0 + + batch_box_preds = self.box_coder.decode_torch(batch_box_preds, local_rois).view(-1, code_size) + + batch_box_preds = common_utils.rotate_points_along_z( + 
diff --git a/pcdet/models/roi_heads/target_assigner/proposal_target_layer.py b/pcdet/models/roi_heads/target_assigner/proposal_target_layer.py
index b0d9644bc..49f5f0a04 100644
--- a/pcdet/models/roi_heads/target_assigner/proposal_target_layer.py
+++ b/pcdet/models/roi_heads/target_assigner/proposal_target_layer.py
@@ -71,7 +71,7 @@ def sample_rois_for_rcnn(self, batch_dict):
                 gt_boxes: (B, N, 7 + C + 1)
                 roi_labels: (B, num_rois)
         Returns:
-
+
         """
         batch_size = batch_dict['batch_size']
         rois = batch_dict['rois']
@@ -199,9 +199,9 @@ def get_max_iou_with_same_class(rois, roi_labels, gt_boxes, gt_labels):
             roi_labels: (N)
             gt_boxes: (N, )
             gt_labels:
-
+
         Returns:
-
+
         """
         """
         :param rois: (N, 7)
@@ -220,7 +220,7 @@ def get_max_iou_with_same_class(rois, roi_labels, gt_boxes, gt_labels):
                 cur_gt = gt_boxes[gt_mask]
                 original_gt_assignment = gt_mask.nonzero().view(-1)
 
-                iou3d = iou3d_nms_utils.boxes_iou3d_gpu(cur_roi, cur_gt)  # (M, N)
+                iou3d = iou3d_nms_utils.boxes_iou3d_gpu(cur_roi[:, :7], cur_gt[:, :7])  # (M, N)
                 cur_max_overlaps, cur_gt_assignment = torch.max(iou3d, dim=1)
                 max_overlaps[roi_mask] = cur_max_overlaps
                 gt_assignment[roi_mask] = original_gt_assignment[cur_gt_assignment]
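Context for the `[:, :7]` fix above: multi-frame pipelines append extra channels (e.g. `vx, vy`) to boxes, while the IoU kernel expects exactly the 7 geometric parameters `[x, y, z, dx, dy, dz, heading]`. A rough, axis-aligned stand-in (it ignores heading, unlike `boxes_iou3d_gpu`) that makes this contract explicit:

```python
import torch

def aabb_iou3d(boxes_a, boxes_b):
    """Axis-aligned 3-D IoU, a simplified stand-in for boxes_iou3d_gpu.

    Both inputs must be (N, 7) / (M, 7); boxes carrying velocity channels
    (9-dim) must be sliced with [:, :7] first, which is what the fix does.
    """
    assert boxes_a.shape[-1] == 7 and boxes_b.shape[-1] == 7, 'slice boxes to [:, :7]'

    def corners(b):
        return b[:, 0:3] - b[:, 3:6] / 2, b[:, 0:3] + b[:, 3:6] / 2

    min_a, max_a = corners(boxes_a)
    min_b, max_b = corners(boxes_b)
    lt = torch.maximum(min_a[:, None], min_b[None])   # (N, M, 3)
    rb = torch.minimum(max_a[:, None], max_b[None])
    inter = (rb - lt).clamp(min=0).prod(-1)           # (N, M)
    vol_a = boxes_a[:, 3:6].prod(-1)
    vol_b = boxes_b[:, 3:6].prod(-1)
    return inter / (vol_a[:, None] + vol_b[None] - inter).clamp(min=1e-6)

nine_dim = torch.tensor([[0., 0, 0, 4, 2, 1.5, 0.3, 5.0, 0.1]])  # with (vx, vy)
print(aabb_iou3d(nine_dim[:, :7], nine_dim[:, :7]))  # IoU of a box with itself -> 1.0
```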
diff --git a/tools/cfgs/waymo_models/mppnet_16frame.yaml b/tools/cfgs/waymo_models/mppnet_16frame.yaml
new file mode 100644
index 000000000..8b496f1d7
--- /dev/null
+++ b/tools/cfgs/waymo_models/mppnet_16frame.yaml
@@ -0,0 +1,168 @@
+CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist']
+
+DATA_CONFIG:
+    _BASE_CONFIG_: cfgs/dataset_configs/waymo_dataset.yaml
+    PROCESSED_DATA_TAG: 'waymo_processed_data_v0_5_0'
+
+    SAMPLED_INTERVAL: {
+        'train': 1,
+        'test': 1
+    }
+    FILTER_EMPTY_BOXES_FOR_TRAIN: True
+    DISABLE_NLZ_FLAG_ON_POINTS: True
+
+    SEQUENCE_CONFIG:
+        ENABLED: True
+        SAMPLE_OFFSET: [-15, 0]
+
+    USE_PREDBOX: True
+    ROI_BOXES_PATH: {
+        'train': '../output/xxxxx/train/result.pkl',  # example: predicted boxes of the RPN on the training set
+        'test': '../output/xxxxx/val/result.pkl',     # example: predicted boxes of the RPN on the evaluation set
+    }
+
+    DATA_AUGMENTOR:
+        DISABLE_AUG_LIST: ['placeholder']
+        AUG_CONFIG_LIST:
+            - NAME: random_world_flip
+              ALONG_AXIS_LIST: ['x', 'y']
+
+            - NAME: random_world_rotation
+              WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
+
+            - NAME: random_world_scaling
+              WORLD_SCALE_RANGE: [0.95, 1.05]
+
+    DATA_PROCESSOR:
+        - NAME: mask_points_and_boxes_outside_range
+          REMOVE_OUTSIDE_BOXES: True
+
+        - NAME: shuffle_points
+          SHUFFLE_ENABLED: {
+              'train': True,
+              'test': True
+          }
+
+    POINT_FEATURE_ENCODING: {
+        encoding_type: absolute_coordinates_encoding,
+        used_feature_list: ['x', 'y', 'z', 'intensity', 'elongation', 'time'],
+        src_feature_list: ['x', 'y', 'z', 'intensity', 'elongation', 'time'],
+    }
+
+
+MODEL:
+    NAME: MPPNet
+
+    ROI_HEAD:
+        NAME: MPPNetHead
+        TRANS_INPUT: 64
+        CLASS_AGNOSTIC: True
+        USE_BOX_ENCODING:
+            ENABLED: True
+        AVG_STAGE1_SCORE: True
+        USE_TRAJ_EMPTY_MASK: True
+        USE_AUX_LOSS: True
+        USE_MLP_JOINTEMB: False
+        IOU_WEIGHT: [0.5, 0.4]
+
+        ROI_GRID_POOL:
+            GRID_SIZE: 4
+            MLPS: [[64, 64]]
+            POOL_RADIUS: [0.8]
+            NSAMPLE: [16]
+            POOL_METHOD: max_pool
+
+        Transformer:
+            num_lidar_points: 128
+            num_proxy_points: 64  # must equal GRID_SIZE * GRID_SIZE * GRID_SIZE
+            pos_hidden_dim: 64
+            enc_layers: 3
+            dim_feedforward: 512
+            hidden_dim: 64  # must equal ROI_HEAD.TRANS_INPUT
+            dropout: 0.1
+            nheads: 4
+            pre_norm: False
+            num_frames: 16
+            num_groups: 4
+            sequence_stride: 4
+            use_grid_pos:
+                enabled: True
+                init_type: index
+
+            use_mlp_mixer:
+                enabled: True
+                hidden_dim: 16
+
+        TARGET_CONFIG:
+            BOX_CODER: ResidualCoder
+            ROI_PER_IMAGE: 96
+            FG_RATIO: 0.5
+            REG_AUG_METHOD: single
+            ROI_FG_AUG_TIMES: 10
+            RATIO: 0.2
+            USE_ROI_AUG: True
+            USE_TRAJ_AUG:
+                ENABLED: True
+                THRESHOD: 0.8
+            SAMPLE_ROI_BY_EACH_CLASS: True
+            CLS_SCORE_TYPE: roi_iou
+
+            CLS_FG_THRESH: 0.75
+            CLS_BG_THRESH: 0.25
+            CLS_BG_THRESH_LO: 0.1
+            HARD_BG_RATIO: 0.8
+
+            REG_FG_THRESH: 0.55
+
+        LOSS_CONFIG:
+            CLS_LOSS: BinaryCrossEntropy
+            REG_LOSS: smooth-l1
+            CORNER_LOSS_REGULARIZATION: True
+            LOSS_WEIGHTS: {
+                'rcnn_cls_weight': 1.0,
+                'rcnn_reg_weight': 1.0,
+                'rcnn_corner_weight': 2.0,
+                'traj_reg_weight': [2.0, 2.0, 2.0],
+                'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+            }
+
+    POST_PROCESSING:
+        RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
+        SCORE_THRESH: 0.1
+        OUTPUT_RAW_SCORE: False
+        SAVE_BBOX: False
+        EVAL_METRIC: waymo
+        NOT_APPLY_NMS_FOR_VEL: True
+
+        NMS_CONFIG:
+            MULTI_CLASSES_NMS: False
+            NMS_TYPE: nms_gpu
+            NMS_THRESH: 0.7
+            NMS_PRE_MAXSIZE: 4096
+            NMS_POST_MAXSIZE: 500
+
+
+OPTIMIZATION:
+    BATCH_SIZE_PER_GPU: 4
+    NUM_EPOCHS: 3
+
+    OPTIMIZER: adam_onecycle
+    LR: 0.003
+    WEIGHT_DECAY: 0.01
+    MOMENTUM: 0.9
+
+    MOMS: [0.95, 0.85]
+    PCT_START: 0.4
+    DIV_FACTOR: 10
+    DECAY_STEP_LIST: [35, 45]
+    LR_DECAY: 0.1
+    LR_CLIP: 0.0000001
+
+    LR_WARMUP: False
+    WARMUP_EPOCH: 1
+
+    GRAD_NORM_CLIP: 10
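The `ROI_BOXES_PATH` entries are placeholders for standard OpenPCDet `result.pkl` files produced by evaluating the first-stage detector on each split. A sketch of how one might inspect such a file; the key names below match OpenPCDet's usual prediction dicts but should be verified against the actual RPN output:

```python
import pickle

# hypothetical path; substitute the real RPN output for '../output/xxxxx/...'
with open('../output/xxxxx/train/result.pkl', 'rb') as f:
    results = pickle.load(f)                  # typically a list with one dict per frame

frame = results[0]
print(frame['frame_id'])                      # sequence/timestamp identifier
print(frame['boxes_lidar'].shape)             # (num_boxes, 7), or (num_boxes, 9) with velocity
print(frame['score'].shape, frame['name'].shape)  # per-box confidence and class name
```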
diff --git a/tools/cfgs/waymo_models/mppnet_4frame.yaml b/tools/cfgs/waymo_models/mppnet_4frame.yaml
new file mode 100644
index 000000000..495722871
--- /dev/null
+++ b/tools/cfgs/waymo_models/mppnet_4frame.yaml
@@ -0,0 +1,168 @@
+CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist']
+
+DATA_CONFIG:
+    _BASE_CONFIG_: cfgs/dataset_configs/waymo_dataset.yaml
+    PROCESSED_DATA_TAG: 'waymo_processed_data_v0_5_0'
+
+    SAMPLED_INTERVAL: {
+        'train': 1,
+        'test': 1
+    }
+    FILTER_EMPTY_BOXES_FOR_TRAIN: True
+    DISABLE_NLZ_FLAG_ON_POINTS: True
+
+    SEQUENCE_CONFIG:
+        ENABLED: True
+        SAMPLE_OFFSET: [-3, 0]
+
+    USE_PREDBOX: True
+    ROI_BOXES_PATH: {
+        'train': '../output/xxxxx/train/result.pkl',  # example: predicted boxes of the RPN on the training set
+        'test': '../output/xxxxx/val/result.pkl',     # example: predicted boxes of the RPN on the evaluation set
+    }
+
+    DATA_AUGMENTOR:
+        DISABLE_AUG_LIST: ['placeholder']
+        AUG_CONFIG_LIST:
+            - NAME: random_world_flip
+              ALONG_AXIS_LIST: ['x', 'y']
+
+            - NAME: random_world_rotation
+              WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
+
+            - NAME: random_world_scaling
+              WORLD_SCALE_RANGE: [0.95, 1.05]
+
+    DATA_PROCESSOR:
+        - NAME: mask_points_and_boxes_outside_range
+          REMOVE_OUTSIDE_BOXES: True
+
+        - NAME: shuffle_points
+          SHUFFLE_ENABLED: {
+              'train': True,
+              'test': True
+          }
+
+    POINT_FEATURE_ENCODING: {
+        encoding_type: absolute_coordinates_encoding,
+        used_feature_list: ['x', 'y', 'z', 'intensity', 'elongation', 'time'],
+        src_feature_list: ['x', 'y', 'z', 'intensity', 'elongation', 'time'],
+    }
+
+
+MODEL:
+    NAME: MPPNet
+
+    ROI_HEAD:
+        NAME: MPPNetHead
+        TRANS_INPUT: 256
+        CLASS_AGNOSTIC: True
+        USE_BOX_ENCODING:
+            ENABLED: True
+        AVG_STAGE1_SCORE: True
+        USE_TRAJ_EMPTY_MASK: True
+        USE_AUX_LOSS: True
+        USE_MLP_JOINTEMB: True
+        IOU_WEIGHT: [0.5, 0.4]
+
+        ROI_GRID_POOL:
+            GRID_SIZE: 4
+            MLPS: [[128, 128], [128, 128]]
+            POOL_RADIUS: [0.8, 1.6]
+            NSAMPLE: [16, 16]
+            POOL_METHOD: max_pool
+
+        Transformer:
+            num_lidar_points: 128
+            num_proxy_points: 64  # must equal GRID_SIZE * GRID_SIZE * GRID_SIZE
+            pos_hidden_dim: 64
+            enc_layers: 3
+            dim_feedforward: 512
+            hidden_dim: 256  # must equal ROI_HEAD.TRANS_INPUT
+            dropout: 0.1
+            nheads: 4
+            pre_norm: False
+            num_frames: 4
+            num_groups: 4
+            use_grid_pos:
+                enabled: True
+                init_type: index
+
+            use_mlp_mixer:
+                enabled: True
+                hidden_dim: 16
+
+        TARGET_CONFIG:
+            BOX_CODER: ResidualCoder
+            ROI_PER_IMAGE: 96
+            FG_RATIO: 0.5
+            REG_AUG_METHOD: single
+            ROI_FG_AUG_TIMES: 10
+            RATIO: 0.2
+            USE_ROI_AUG: True
+            USE_TRAJ_AUG:
+                ENABLED: True
+                THRESHOD: 0.8
+            SAMPLE_ROI_BY_EACH_CLASS: True
+            CLS_SCORE_TYPE: roi_iou
+
+            CLS_FG_THRESH: 0.75
+            CLS_BG_THRESH: 0.25
+            CLS_BG_THRESH_LO: 0.1
+            HARD_BG_RATIO: 0.8
+
+            REG_FG_THRESH: 0.55
+
+        LOSS_CONFIG:
+            CLS_LOSS: BinaryCrossEntropy
+            REG_LOSS: smooth-l1
+            CORNER_LOSS_REGULARIZATION: True
+            LOSS_WEIGHTS: {
+                'rcnn_cls_weight': 1.0,
+                'rcnn_reg_weight': 1.0,
+                'rcnn_corner_weight': 2.0,
+                'traj_reg_weight': [2.0, 2.0, 2.0],
+                'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+            }
+
+    POST_PROCESSING:
+        RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
+        SCORE_THRESH: 0.1
+        OUTPUT_RAW_SCORE: False
+        SAVE_BBOX: False
+        EVAL_METRIC: waymo
+        NOT_APPLY_NMS_FOR_VEL: True
+
+        NMS_CONFIG:
+            MULTI_CLASSES_NMS: False
+            NMS_TYPE: nms_gpu
+            NMS_THRESH: 0.7
+            NMS_PRE_MAXSIZE: 4096
+            NMS_POST_MAXSIZE: 500
+
+
+OPTIMIZATION:
+    BATCH_SIZE_PER_GPU: 4
+    NUM_EPOCHS: 3
+
+    OPTIMIZER: adam_onecycle
+    LR: 0.003
+    WEIGHT_DECAY: 0.01
+    MOMENTUM: 0.9
+
+    MOMS: [0.95, 0.85]
+    PCT_START: 0.4
+    DIV_FACTOR: 10
+    DECAY_STEP_LIST: [35, 45]
+    LR_DECAY: 0.1
+    LR_CLIP: 0.0000001
+
+    LR_WARMUP: False
+    WARMUP_EPOCH: 1
+
+    GRAD_NORM_CLIP: 10
+
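This 4-frame config differs from the 16-frame one mainly in `SAMPLE_OFFSET`, `TRANS_INPUT`/`hidden_dim`, `num_frames`, and the absence of `sequence_stride`. Two couplings flagged by the inline comments are easy to break when editing; a small validation sketch (hypothetical helper, not part of the repo):

```python
def check_mppnet_cfg(roi_head_cfg: dict) -> None:
    """Assert the invariants noted in the YAML comments above."""
    grid = roi_head_cfg['ROI_GRID_POOL']['GRID_SIZE']
    trans = roi_head_cfg['Transformer']
    assert trans['num_proxy_points'] == grid ** 3, 'num_proxy_points != GRID_SIZE^3'
    assert trans['hidden_dim'] == roi_head_cfg['TRANS_INPUT'], 'hidden_dim != TRANS_INPUT'
    assert trans['num_frames'] % trans['num_groups'] == 0, 'frames must split evenly into groups'

# values from the 4-frame config above
check_mppnet_cfg({
    'TRANS_INPUT': 256,
    'ROI_GRID_POOL': {'GRID_SIZE': 4},
    'Transformer': {'num_proxy_points': 64, 'hidden_dim': 256,
                    'num_frames': 4, 'num_groups': 4},
})
```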
diff --git a/tools/cfgs/waymo_models/mppnet_e2e_memorybank_inference.yaml b/tools/cfgs/waymo_models/mppnet_e2e_memorybank_inference.yaml
new file mode 100644
index 000000000..e22483315
--- /dev/null
+++ b/tools/cfgs/waymo_models/mppnet_e2e_memorybank_inference.yaml
@@ -0,0 +1,235 @@
+CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist']
+
+DATA_CONFIG:
+    _BASE_CONFIG_: cfgs/dataset_configs/waymo_dataset.yaml
+    PROCESSED_DATA_TAG: 'waymo_processed_data_v0_5_0'
+
+    SAMPLED_INTERVAL: {
+        'train': 1,
+        'test': 1
+    }
+    FILTER_EMPTY_BOXES_FOR_TRAIN: True
+    DISABLE_NLZ_FLAG_ON_POINTS: True
+
+    SEQUENCE_CONFIG:
+        ENABLED: True
+        USE_SPEED: True
+        SAMPLE_OFFSET: [-3, 0]  # the 16-frame model uses [-15, 0]
+
+    POINT_FEATURE_ENCODING: {
+        encoding_type: absolute_coordinates_encoding,
+        used_feature_list: ['x', 'y', 'z', 'intensity', 'elongation', 'time'],
+        src_feature_list: ['x', 'y', 'z', 'intensity', 'elongation', 'time'],
+    }
+
+    DATA_AUGMENTOR:
+        DISABLE_AUG_LIST: ['placeholder']
+        AUG_CONFIG_LIST:
+            - NAME: random_world_flip
+              ALONG_AXIS_LIST: ['x', 'y']
+
+            - NAME: random_world_rotation
+              WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
+
+            - NAME: random_world_scaling
+              WORLD_SCALE_RANGE: [0.95, 1.05]
+
+    DATA_PROCESSOR:
+        - NAME: mask_points_and_boxes_outside_range
+          REMOVE_OUTSIDE_BOXES: True
+
+        - NAME: shuffle_points
+          SHUFFLE_ENABLED: {
+              'train': True,
+              'test': True
+          }
+
+        - NAME: transform_points_to_voxels
+          VOXEL_SIZE: [0.1, 0.1, 0.15]
+          MAX_POINTS_PER_VOXEL: 5
+          MAX_NUMBER_OF_VOXELS: {
+              'train': 150000,
+              'test': 150000
+          }
+
+
+MODEL:
+    NAME: MPPNetE2E
+
+    VFE:
+        NAME: DynMeanVFE
+
+    BACKBONE_3D:
+        NAME: VoxelResBackBone8x
+
+    MAP_TO_BEV:
+        NAME: HeightCompression
+        NUM_BEV_FEATURES: 256
+
+    BACKBONE_2D:
+        NAME: BaseBEVBackbone
+        NUM_FRAME: 2
+        LAYER_NUMS: [5, 5]
+        LAYER_STRIDES: [1, 2]
+        NUM_FILTERS: [128, 256]
+        UPSAMPLE_STRIDES: [1, 2]
+        NUM_UPSAMPLE_FILTERS: [256, 256]
+
+    DENSE_HEAD:
+        NAME: CenterHead
+        CLASS_AGNOSTIC: False
+
+        CLASS_NAMES_EACH_HEAD: [
+            ['Vehicle', 'Pedestrian', 'Cyclist']
+        ]
+
+        SHARED_CONV_CHANNEL: 64
+        USE_BIAS_BEFORE_NORM: True
+        NUM_HM_CONV: 2
+        SEPARATE_HEAD_CFG:
+            HEAD_ORDER: ['center', 'center_z', 'dim', 'rot', 'vel']
+            HEAD_DICT: {
+                'center': {'out_channels': 2, 'num_conv': 2},
+                'center_z': {'out_channels': 1, 'num_conv': 2},
+                'dim': {'out_channels': 3, 'num_conv': 2},
+                'rot': {'out_channels': 2, 'num_conv': 2},
+                'vel': {'out_channels': 2, 'num_conv': 2},
+            }
+
+        TARGET_ASSIGNER_CONFIG:
+            FEATURE_MAP_STRIDE: 8
+            NUM_MAX_OBJS: 500
+            GAUSSIAN_OVERLAP: 0.1
+            MIN_RADIUS: 2
+
+        LOSS_CONFIG:
+            LOSS_WEIGHTS: {
+                'cls_weight': 1.0,
+                'loc_weight': 2.0,
+                'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
+            }
+
+        POST_PROCESSING:
+            SCORE_THRESH: 0.1
+            POST_CENTER_LIMIT_RANGE: [-75.2, -75.2, -2, 75.2, 75.2, 4]
+            MAX_OBJ_PER_SAMPLE: 500
+            NMS_CONFIG:
+                NMS_TYPE: nms_gpu
+                NMS_THRESH: 0.7
+                NMS_PRE_MAXSIZE: 4096
+                NMS_POST_MAXSIZE: 500
+
+    ROI_HEAD:
+        NAME: MPPNetHeadE2E
+        TRANS_INPUT: 256
+        CLASS_AGNOSTIC: True
+        USE_BOX_ENCODING:
+            ENABLED: True
+        NORM_T0: True
+        ALL_YAW_T0: True
+        AVG_STAGE_1: True
+        USE_TRAJ_EMPTY_MASK: True
+        USE_AUX_LOSS: True
+        USE_MLP_JOINTEMB: True
+        IOU_WEIGHT: [0.5, 0.4]
+
+        ROI_GRID_POOL:
+            GRID_SIZE: 4
+            MLPS: [[128, 128], [128, 128]]
+            POOL_RADIUS: [0.8, 1.6]
+            NSAMPLE: [16, 16]
+            POOL_METHOD: max_pool
+
+        Transformer:
+            num_lidar_points: 128
+            num_proxy_points: 64
+            pos_hidden_dim: 64
+            enc_layers: 3
+            dim_feedforward: 512
+            hidden_dim: 256
+            dropout: 0.1
+            nheads: 4
+            pre_norm: False
+            num_frames: 4  # the 16-frame model uses 16
+            num_groups: 4
+            sequence_stride: 1  # the 16-frame model uses 4
+            use_grid_pos:
+                enabled: True
+                init_type: index
+            use_mlp_mixer:
+                enabled: True
+                hidden_dim: 16
+
+        TARGET_CONFIG:
+            BOX_CODER: ResidualCoder
+            ROI_PER_IMAGE: 96
+            FG_RATIO: 0.5
+            REG_AUG_METHOD: single
+            ROI_FG_AUG_TIMES: 10
+            RATIO: 0.2
+            USE_ROI_AUG: True
+            USE_TRAJ_AUG:
+                ENABLED: True
+                THRESHOD: 0.8
+            SAMPLE_ROI_BY_EACH_CLASS: True
+            CLS_SCORE_TYPE: roi_iou
+
+            CLS_FG_THRESH: 0.75
+            CLS_BG_THRESH: 0.25
+            CLS_BG_THRESH_LO: 0.1
+            HARD_BG_RATIO: 0.8
+
+            REG_FG_THRESH: 0.55
+
+        LOSS_CONFIG:
+            CLS_LOSS: BinaryCrossEntropy
+            REG_LOSS: smooth-l1
+            CORNER_LOSS_REGULARIZATION: True
+            LOSS_WEIGHTS: {
+                'rcnn_cls_weight': 1.0,
+                'rcnn_reg_weight': 1.0,
+                'rcnn_corner_weight': 2.0,
+                'traj_reg_weight': [2.0, 2.0, 2.0],
+                'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+            }
+
+    POST_PROCESSING:
+        RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
+        SCORE_THRESH: 0.1
+        OUTPUT_RAW_SCORE: False
+        SAVE_BBOX: False
+        EVAL_METRIC: waymo
+        NOT_APPLY_NMS_FOR_VEL: True
+
+        NMS_CONFIG:
+            MULTI_CLASSES_NMS: False
+            NMS_TYPE: nms_gpu
+            NMS_THRESH: 0.7
+            NMS_PRE_MAXSIZE: 4096
+            NMS_POST_MAXSIZE: 500
+
+
+OPTIMIZATION:
+    BATCH_SIZE_PER_GPU: 4
+    NUM_EPOCHS: 36
+
+    OPTIMIZER: adam_onecycle
+    LR: 0.003
+    WEIGHT_DECAY: 0.01
+    MOMENTUM: 0.9
+
+    MOMS: [0.95, 0.85]
+    PCT_START: 0.4
+    DIV_FACTOR: 10
+    DECAY_STEP_LIST: [35, 45]
+    LR_DECAY: 0.1
+    LR_CLIP: 0.0000001
+
+    LR_WARMUP: False
+    WARMUP_EPOCH: 1
+
+    GRAD_NORM_CLIP: 10
\ No newline at end of file
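`MPPNetHeadE2E` streams frames at inference and reads past ROI features from `batch_dict['memory_bank']['feature_bank']` (see `reorder_memory` and the `matching_table` indexing in the head diff above). A minimal sketch of such a bank as a bounded per-frame store, ordered so that index 0 is the previous frame, which appears to be how `feature_bank[:valid_length]` is consumed; the class name and capacity policy are assumptions for illustration, not the repo's implementation:

```python
from collections import deque
import torch

class FeatureBank:
    """Bounded store of per-frame ROI features, newest first."""

    def __init__(self, max_frames: int = 16):
        self.frames = deque(maxlen=max_frames)  # oldest frame is dropped when full

    def push(self, roi_features: torch.Tensor) -> None:
        # roi_features: (num_rois, num_proxy_points, C) for the newest frame
        self.frames.appendleft(roi_features.detach())

    def get(self, valid_length: int) -> list:
        # the most recent `valid_length` frames, mirroring feature_bank[:valid_length]
        return list(self.frames)[:valid_length]

bank = FeatureBank(max_frames=4)
for _ in range(6):                       # stream more frames than the capacity
    bank.push(torch.randn(12, 64, 256))  # 12 ROIs, 64 proxy points, C = 256
print(len(bank.get(3)), bank.get(3)[0].shape)  # -> 3 torch.Size([12, 64, 256])
```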
diff --git a/tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet_2frames.yaml b/tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet_2frames.yaml
index 64826fa30..955a64727 100644
--- a/tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet_2frames.yaml
+++ b/tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet_2frames.yaml
@@ -124,7 +124,7 @@ MODEL:
                 NAME: VectorPoolAggregationModuleMSG
                 NUM_GROUPS: 2
                 LOCAL_AGGREGATION_TYPE: local_interpolation
-                NUM_REDUCED_CHANNELS: 2
+                NUM_REDUCED_CHANNELS: 3
                 NUM_CHANNELS_OF_LOCAL_AGGREGATION: 32
                 MSG_POST_MLPS: [ 32 ]
                 FILTER_NEIGHBOR_WITH_ROI: True
@@ -296,7 +296,7 @@ MODEL:
 
 OPTIMIZATION:
     BATCH_SIZE_PER_GPU: 2
-    NUM_EPOCHS: 30
+    NUM_EPOCHS: 36
 
     OPTIMIZER: adam_onecycle
     LR: 0.01
@@ -313,4 +313,4 @@ OPTIMIZATION:
     LR_WARMUP: False
     WARMUP_EPOCH: 1
 
-    GRAD_NORM_CLIP: 10
\ No newline at end of file
+    GRAD_NORM_CLIP: 10
diff --git a/tools/eval_utils/eval_utils.py b/tools/eval_utils/eval_utils.py
index 64503bffb..8e4b0fe3a 100644
--- a/tools/eval_utils/eval_utils.py
+++ b/tools/eval_utils/eval_utils.py
@@ -63,6 +63,7 @@ def eval_one_epoch(cfg, args, model, dataloader, epoch_id, logger, dist_test=Fal
 
         with torch.no_grad():
             pred_dicts, ret_dict = model(batch_dict)
+        disp_dict = {}
 
         if getattr(args, 'infer_time', False):
diff --git a/tools/test.py b/tools/test.py
index 5f856be7a..51b7178c6 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -26,6 +26,7 @@ def parse_config():
     parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader')
     parser.add_argument('--extra_tag', type=str, default='default', help='extra tag for this experiment')
    parser.add_argument('--ckpt', type=str, default=None, help='checkpoint to start from')
+    parser.add_argument('--pretrained_model', type=str, default=None, help='path to a pretrained model to load before evaluation')
     parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none')
     parser.add_argument('--tcp_port', type=int, default=18888, help='tcp port for distrbuted training')
     parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
@@ -56,9 +57,10 @@ def parse_config():
 def eval_single_ckpt(model, test_loader, args, eval_output_dir, logger, epoch_id, dist_test=False):
     # load checkpoint
-    model.load_params_from_file(filename=args.ckpt, logger=logger, to_cpu=dist_test)
+    model.load_params_from_file(filename=args.ckpt, logger=logger, to_cpu=dist_test,
+                                pre_trained_path=args.pretrained_model)
     model.cuda()
-
+
     # start evaluation
     eval_utils.eval_one_epoch(
         cfg, args, model, test_loader, epoch_id, logger, dist_test=dist_test,