From 1574d4a494968d51e42e46a4bc21e65f377c547e Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Sun, 17 Jul 2022 16:40:12 +0800 Subject: [PATCH 1/5] Add custom dataset support For kitti-like custom datasets, this will help. --- pcdet/datasets/__init__.py | 4 +- pcdet/datasets/custom/__init__.py | 0 pcdet/datasets/custom/custom_dataset.py | 371 ++++++++++++++++++++++++ pcdet/datasets/custom/custom_utils.py | 0 pcdet/utils/object3d_custom.py | 83 ++++++ 5 files changed, 457 insertions(+), 1 deletion(-) create mode 100644 pcdet/datasets/custom/__init__.py create mode 100644 pcdet/datasets/custom/custom_dataset.py create mode 100644 pcdet/datasets/custom/custom_utils.py create mode 100644 pcdet/utils/object3d_custom.py diff --git a/pcdet/datasets/__init__.py b/pcdet/datasets/__init__.py index a22356197..f85c91d8b 100644 --- a/pcdet/datasets/__init__.py +++ b/pcdet/datasets/__init__.py @@ -11,6 +11,7 @@ from .waymo.waymo_dataset import WaymoDataset from .pandaset.pandaset_dataset import PandasetDataset from .lyft.lyft_dataset import LyftDataset +from .custom.custom_dataset import CustomDataset __all__ = { 'DatasetTemplate': DatasetTemplate, @@ -18,7 +19,8 @@ 'NuScenesDataset': NuScenesDataset, 'WaymoDataset': WaymoDataset, 'PandasetDataset': PandasetDataset, - 'LyftDataset': LyftDataset + 'LyftDataset': LyftDataset, + 'CustomDataset': CustomDataset } diff --git a/pcdet/datasets/custom/__init__.py b/pcdet/datasets/custom/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pcdet/datasets/custom/custom_dataset.py b/pcdet/datasets/custom/custom_dataset.py new file mode 100644 index 000000000..2fe2e2b60 --- /dev/null +++ b/pcdet/datasets/custom/custom_dataset.py @@ -0,0 +1,371 @@ +import copy +import pickle +import os + +import numpy as np +from skimage import io + +from . 
import custom_utils +from ...ops.roiaware_pool3d import roiaware_pool3d_utils +from ...utils import box_utils, common_utils, object3d_custom +from ..dataset import DatasetTemplate + +class CustomDataset(DatasetTemplate): + def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None, ext='.bin'): + """ + Args: + root_path: + dataset_cfg: + class_names: + training: + logger: + """ + super().__init__( + dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger + ) + self.split = self.dataset_cfg.DATA_SPLIT[self.mode] + self.root_split_path = os.path.join(self.root_path, ('training' if self.split != 'test' else 'testing')) + + split_dir = os.path.join(self.root_path, 'ImageSets',(self.split + '.txt')) + self.sample_id_list = [x.strip() for x in open(split_dir).readlines()] if os.path.exists(split_dir) else None + + self.custom_infos = [] + self.include_custom_data(self.mode) + self.ext = ext + + + def include_custom_data(self, mode): + if self.logger is not None: + self.logger.info('Loading Custom dataset.') + custom_infos = [] + + for info_path in self.dataset_cfg.INFO_PATH[mode]: + info_path = self.root_path / info_path + if not info_path.exists(): + continue + with open(info_path, 'rb') as f: + infos = pickle.load(f) + custom_infos.extend(infos) + + self.custom_infos.extend(custom_infos) + + if self.logger is not None: + self.logger.info('Total samples for CUSTOM dataset: %d' % (len(custom_infos))) + + + def get_infos(self, num_workers=16, has_label=True, count_inside_pts=True, sample_id_list=None): + import concurrent.futures as futures + + # Process single scene + def process_single_scene(sample_idx): + print('%s sample_idx: %s' % (self.split, sample_idx)) + info = {} + pc_info = {'num_features': 4, 'lidar_idx': sample_idx} + info['point_cloud'] = pc_info + + # no images, calibs are need to transform the labels + + type_to_id = {'Car': 1, 'Pedestrian': 2, 'Cyclist': 3} + if has_label: + obj_list = self.get_label(sample_idx) + annotations = {} + annotations['name'] = np.array([obj.cls_type for obj in obj_list]) # 1-dimension + annotations['dimensions'] = np.array([[obj.l, obj.h, obj.w] for obj in obj_list]) + annotations['location'] = np.concatenate([obj.loc.reshape(1,3) for obj in obj_list]) + annotations['rotation_y'] = np.array([obj.ry for obj in obj_list]) # 1-dimension + + num_objects = len([obj.cls_type for obj in obj_list if obj.cls_type != 'DontCare']) + num_gt = len(annotations['name']) + index = list(range(num_objects)) + [-1] * (num_gt - num_objects) + annotations['index'] = np.array(index, dtype=np.int32) + + loc = annotations['location'][:num_objects] + dims = annotations['dimensions'][:num_objects] + rots = annotations['rotation_y'][:num_objects] + loc_lidar = self.get_calib(loc) + l, h, w = dims[:, 0:1], dims[:, 1:2], dims[:, 2:3] + gt_boxes_lidar = np.concatenate([loc_lidar, l, w, h, (np.pi / 2 - rots[..., np.newaxis])], axis=1) # 2-dimension array + annotations['gt_boxes_lidar'] = gt_boxes_lidar + + info['annos'] = annotations + + return info + + sample_id_list = sample_id_list if sample_id_list is not None else self.sample_id_list + + # create a thread pool to improve the velocity + with futures.ThreadPoolExecutor(num_workers) as executor: + infos = executor.map(process_single_scene, sample_id_list) + + return list(infos) + + + def get_calib(self, loc): + """ + This calibration is different from the kitti dataset. 
+ The transform formual of labelCloud: ROOT/labelCloud/io/labels/kitti.py: import labels + if self.transformed: + centroid = centroid[2], -centroid[0], centroid[1] - 2.3 + dimensions = [float(v) for v in line_elements[8:11]] + if self.transformed: + dimensions = dimensions[2], dimensions[1], dimensions[0] + bbox = BBox(*centroid, *dimensions) + """ + loc_lidar = np.concatenate([np.array((float(loc_obj[2]), float(-loc_obj[0]), float(loc_obj[1]-2.3)), dtype=np.float32).reshape(1,3) for loc_obj in loc]) + return loc_lidar + + + def get_label(self, idx): + + label_file = self.root_split_path / 'label_2' / ('%s.txt' % idx) + assert label_file.exists() + return object3d_custom.get_objects_from_label(label_file) + + + def get_lidar(self, idx, getitem): + """ + Loads point clouds for a sample + Args: + index (int): Index of the point cloud file to get. + Returns: + np.array(N, 4): point cloud. + """ + # get lidar statistics + if getitem == True: + lidar_file = self.root_split_path + '/velodyne/' + ('%s.bin' % idx) + else: + lidar_file = self.root_split_path / 'velodyne' / ('%s.bin' % idx) + return np.fromfile(str(lidar_file), dtype=np.float32).reshape(-1, 4) + + + def set_split(self, split): + super().__init__( + dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training, root_path=self.root_path, logger=self.logger + ) + self.split = split + self.root_split_path = self.root_path / ('training' if self.split != 'test' else 'testing') + + split_dir = self.root_path / 'ImageSets' / (self.split + '.txt') + self.sample_id_list = [x.strip() for x in open(split_dir).readlines()] if split_dir.exists() else None + + + # Create gt database for data augmentation + def create_groundtruth_database(self, info_path=None, used_classes=None, split='train'): + import torch + + database_save_path = Path(self.root_path) / ('gt_database' if split == 'train' else ('gt_database_%s' % split)) + db_info_save_path = Path(self.root_path) / ('custom_dbinfos_%s.pkl' % split) + + database_save_path.mkdir(parents=True, exist_ok=True) + all_db_infos = {} + + with open(info_path, 'rb') as f: + infos = pickle.load(f) + + # For each .bin file + for k in range(len(infos)): + print('gt_database sample: %d/%d' % (k + 1, len(infos))) + info = infos[k] + sample_idx = info['point_cloud']['lidar_idx'] + points = self.get_lidar(sample_idx, False) + annos = info['annos'] + names = annos['name'] + gt_boxes = annos['gt_boxes_lidar'] + + num_obj = gt_boxes.shape[0] + point_indices = roiaware_pool3d_utils.points_in_boxes_cpu( + torch.from_numpy(points[:, 0:3]), torch.from_numpy(gt_boxes) + ).numpy() # (nboxes, npoints) + + for i in range(num_obj): + filename = '%s_%s_%d.bin' % (sample_idx, names[i], i) + filepath = database_save_path / filename + gt_points = points[point_indices[i] > 0] + + gt_points[:, :3] -= gt_boxes[i, :3] + with open(filepath, 'w') as f: + gt_points.tofile(f) + + if (used_classes is None) or names[i] in used_classes: + db_path = str(filepath.relative_to(self.root_path)) # gt_database/xxxxx.bin + db_info = {'name': names[i], 'path': db_path, 'gt_idx': i, + 'box3d_lidar': gt_boxes[i], 'num_points_in_gt': gt_points.shape[0]} + if names[i] in all_db_infos: + all_db_infos[names[i]].append(db_info) + else: + all_db_infos[names[i]] = [db_info] + + # Output the num of all classes in database + for k, v in all_db_infos.items(): + print('Database %s: %d' % (k, len(v))) + + with open(db_info_save_path, 'wb') as f: + pickle.dump(all_db_infos, f) + + @staticmethod + def generate_prediction_dicts(batch_dict, 
pred_dicts, class_names, output_path=None): + """ + Args: + batch_dict: + frame_id: + pred_dicts: list of pred_dicts + pred_boxes: (N,7), Tensor + pred_scores: (N), Tensor + pred_lables: (N), Tensor + class_names: + output_path: + + Returns: + + """ + def get_template_prediction(num_smaples): + ret_dict = { + 'name': np.zeros(num_smaples), 'alpha' : np.zeros(num_smaples), + 'dimensions': np.zeros([num_smaples, 3]), 'location': np.zeros([num_smaples, 3]), + 'rotation_y': np.zero(num_smaples), 'score': np.zeros(num_smaples), + 'boxes_lidar': np.zeros([num_smaples, 7]) + } + return ret_dict + + def generate_single_sample_dict(batch_index, box_dict): + pred_scores = box_dict['pred_scores'].cpu().numpy() + pred_boxes = box_dict['pred_boxes'].cpu().numpy() + pred_labels = box_dict['pred_labels'].cpu().numpy() + + # Define an empty template dict to store the prediction information, 'pred_scores.shape[0]' means 'num_samples' + pred_dict = get_template_prediction(pred_scores.shape[0]) + # If num_samples equals zero then return the empty dict + if pred_scores.shape[0] == 0: + return pred_dict + + # No calibration files + + pred_boxes_camera = box_utils.boxes3d_lidar_to_kitti_camera[pred_boxes] + + pred_dict['name'] = np.array(class_names)[pred_labels - 1] + pred_dict['alpha'] = -np.arctan2(-pred_boxes[:, 1], pred_boxes[:, 0]) + pred_boxes_camera[:, 6] + pred_dict['dimensions'] = pred_boxes_camera[:, 3:6] + pred_dict['location'] = pred_boxes_camera[:, 0:3] + pred_dict['rotation_y'] = pred_boxes_camera[:, 6] + pred_dict['score'] = pred_scores + pred_dict['boxes_lidar'] = pred_boxes + + return pred_dict + + annos = [] + for index, box_dict in enumerate(pred_dicts): + frame_id = batch_dict['frame_id'][index] + + single_pred_dict = generate_single_sample_dict(index, box_dict) + single_pred_dict['frame_id'] = frame_id + annos.append(single_pred_dict) + + # Output pred results to Output-path in .txt file + if output_path is not None: + cur_det_file = output_path / ('%s.txt' % frame_id) + with open(cur_det_file, 'w') as f: + bbox = single_pred_dict['bbox'] + loc = single_pred_dict['location'] + dims = single_pred_dict['dimensions'] # lhw -> hwl: lidar -> camera + + for idx in range(len(bbox)): + print('%s -1 -1 %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f' + % (single_pred_dict['name'][idx], single_pred_dict['alpha'][idx], + bbox[idx][0], bbox[idx][1], bbox[idx][2], bbox[idx][3], + dims[idx][1], dims[idx][2], dims[idx][0], loc[idx][0], + loc[idx][1], loc[idx][2], single_pred_dict['rotation_y'][idx], + single_pred_dict['score'][idx]), file=f) + return annos + + + def __len__(self): + if self._merge_all_iters_to_one_epoch: + return len(self.sample_id_list) * self.total_epochs + + return len(self.custom_infos) + + + def __getitem__(self, index): + """ + Function: + Read 'velodyne' folder as pointclouds + Read 'label_2' folder as labels + Return type 'dict' + """ + if self._merge_all_iters_to_one_epoch: + index = index % len(self.custom_infos) + + info = copy.deepcopy(self.custom_infos[index]) + + sample_idx = info['point_cloud']['lidar_idx'] + get_item_list = self.dataset_cfg.get('GET_ITEM_LIST', ['points']) + + input_dict = { + 'frame_id': self.sample_id_list[index], + } + + """ + Here infos was generated by get_infos + """ + if 'annos' in info: + annos = info['annos'] + annos = common_utils.drop_info_with_name(annos, name='DontCare') + loc, dims, rots = annos['location'], annos['dimensions'], annos['rotation_y'] + gt_names = annos['name'] + gt_boxes_lidar = annos['gt_boxes_lidar'] + + if 
'points' in get_item_list: + points = self.get_lidar(sample_idx, True) + input_dict['points'] = points + input_dict.update({ + 'gt_names': gt_names, + 'gt_boxes': gt_boxes_lidar + }) + + data_dict = self.prepare_data(data_dict=input_dict) + return data_dict + + +def create_custom_infos(dataset_cfg, class_names, data_path, save_path, workers=4): + dataset = CustomDataset(dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path, training=False) + train_split, val_split = 'train', 'val' + + # No evaluation + train_filename = save_path / ('custom_infos_%s.pkl' % train_split) + val_filenmae = save_path / ('custom_infos%s.pkl' % val_split) + trainval_filename = save_path / 'custom_infos_trainval.pkl' + test_filename = save_path / 'custom_infos_test.pkl' + + print('------------------------Start to generate data infos------------------------') + + dataset.set_split(train_split) + custom_infos_train = dataset.get_infos(num_workers=workers, has_label=True, count_inside_pts=True) + with open(train_filename, 'wb') as f: + pickle.dump(custom_infos_train, f) + print('Custom info train file is save to %s' % train_filename) + + dataset.set_split('test') + custom_infos_test = dataset.get_infos(num_workers=workers, has_label=False, count_inside_pts=False) + with open(test_filename, 'wb') as f: + pickle.dump(custom_infos_test, f) + print('Custom info test file is saved to %s' % test_filename) + + print('------------------------Start create groundtruth database for data augmentation------------------------') + dataset.set_split(train_split) + dataset.create_groundtruth_database(train_filename, split=train_split) + print('------------------------Data preparation done------------------------') + +if __name__=='__main__': + import sys + if sys.argv.__len__() > 1 and sys.argv[1] == 'create_custom_infos': + import yaml + from pathlib import Path + from easydict import EasyDict + dataset_cfg = EasyDict(yaml.safe_load(open(sys.argv[2]))) + ROOT_DIR = (Path(__file__).resolve().parent / '../../../').resolve() + create_custom_infos( + dataset_cfg=dataset_cfg, + class_names=['Car', 'Pedestrian', 'Cyclist'], + data_path=ROOT_DIR / 'data' / 'custom', + save_path=ROOT_DIR / 'data' / 'custom' + ) diff --git a/pcdet/datasets/custom/custom_utils.py b/pcdet/datasets/custom/custom_utils.py new file mode 100644 index 000000000..e69de29bb diff --git a/pcdet/utils/object3d_custom.py b/pcdet/utils/object3d_custom.py new file mode 100644 index 000000000..188d1bb0b --- /dev/null +++ b/pcdet/utils/object3d_custom.py @@ -0,0 +1,83 @@ +import numpy as np + + +def get_objects_from_label(label_file): + with open(label_file, 'r') as f: + lines = f.readlines() + objects = [Object3d(line) for line in lines] + return objects + + +def cls_type_to_id(cls_type): + type_to_id = {'Car': 1, 'Pedestrian': 2, 'Cyclist': 3, 'Van': 4} + if cls_type not in type_to_id.keys(): + return -1 + return type_to_id[cls_type] + + +class Object3d(object): + def __init__(self, line): + label = line.strip().split(' ') + self.src = line + self.cls_type = label[0] + self.cls_id = cls_type_to_id(self.cls_type) + self.truncation = float(label[1]) + self.occlusion = float(label[2]) # 0:fully visible 1:partly occluded 2:largely occluded 3:unknown + self.alpha = float(label[3]) + self.box2d = np.array((float(label[4]), float(label[5]), float(label[6]), float(label[7])), dtype=np.float32) + self.h = float(label[8]) + self.w = float(label[9]) + self.l = float(label[10]) + self.loc = np.array((float(label[11]), float(label[12]), float(label[13])), 
dtype=np.float32) + self.dis_to_cam = np.linalg.norm(self.loc) + self.ry = float(label[14]) + self.score = float(label[15]) if label.__len__() == 16 else -1.0 + self.level_str = None + self.level = self.get_custom_obj_level() + + def get_custom_obj_level(self): + height = float(self.box2d[3]) - float(self.box2d[1]) + 1 + + if height >= 40 and self.truncation <= 0.15 and self.occlusion <= 0: + self.level_str = 'Easy' + return 0 # Easy + elif height >= 25 and self.truncation <= 0.3 and self.occlusion <= 1: + self.level_str = 'Moderate' + return 1 # Moderate + elif height >= 25 and self.truncation <= 0.5 and self.occlusion <= 2: + self.level_str = 'Hard' + return 2 # Hard + else: + self.level_str = 'UnKnown' + return -1 + + def generate_corners3d(self): + """ + generate corners3d representation for this object + :return corners_3d: (8, 3) corners of box3d in camera coord + """ + l, h, w = self.l, self.h, self.w + x_corners = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2] + y_corners = [0, 0, 0, 0, -h, -h, -h, -h] + z_corners = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2] + + R = np.array([[np.cos(self.ry), 0, np.sin(self.ry)], + [0, 1, 0], + [-np.sin(self.ry), 0, np.cos(self.ry)]]) + corners3d = np.vstack([x_corners, y_corners, z_corners]) # (3, 8) + corners3d = np.dot(R, corners3d).T + corners3d = corners3d + self.loc + return corners3d + + def to_str(self): + print_str = '%s %.3f %.3f %.3f box2d: %s hwl: [%.3f %.3f %.3f] pos: %s ry: %.3f' \ + % (self.cls_type, self.truncation, self.occlusion, self.alpha, self.box2d, self.h, self.w, self.l, + self.loc, self.ry) + return print_str + + def to_custom_format(self): + custom_str = '%s %.2f %d %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f' \ + % (self.cls_type, self.truncation, int(self.occlusion), self.alpha, self.box2d[0], self.box2d[1], + self.box2d[2], self.box2d[3], self.h, self.w, self.l, self.loc[0], self.loc[1], self.loc[2], + self.ry) + return custom_str From e9b06b361c5a65b1ed73aa097ccbacc710b74ae5 Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Sun, 17 Jul 2022 17:35:04 +0800 Subject: [PATCH 2/5] add config files and README --- pcdet/datasets/custom/README.md | 37 +++ pcdet/datasets/custom/custom_dataset.py | 8 +- tools/cfgs/custom_models/pointrcnn.yaml | 161 +++++++++++ tools/cfgs/custom_models/pv_rcnn.yaml | 249 ++++++++++++++++++ .../cfgs/dataset_configs/custom_dataset.yaml | 71 +++++ 5 files changed, 519 insertions(+), 7 deletions(-) create mode 100644 pcdet/datasets/custom/README.md create mode 100644 tools/cfgs/custom_models/pointrcnn.yaml create mode 100644 tools/cfgs/custom_models/pv_rcnn.yaml create mode 100644 tools/cfgs/dataset_configs/custom_dataset.yaml diff --git a/pcdet/datasets/custom/README.md b/pcdet/datasets/custom/README.md new file mode 100644 index 000000000..398fa4c83 --- /dev/null +++ b/pcdet/datasets/custom/README.md @@ -0,0 +1,37 @@ +For Custom Dataset using +## Custom Dataset +For pure point cloud dataset, which means you don't have images generated when got point cloud data from a self-defined scene. Label those raw data and make sure label files to be kitti-like: +``` +Car 0 0 0 0 0 0 0 1.50 1.46 3.70 -5.12 1.85 4.13 1.56 +Pedestrian 0 0 0 0 0 0 0 1.54 0.57 0.41 -7.92 1.94 15.95 1.57 +DontCare 0 0 0 0 0 0 0 -1 -1 -1 -1000 -1000 -1000 -10 +``` +Some items (which is shown from the first zero to the seventh zero above) are not necessary because they are meaningless if no cameras. 
And the `image` folder, `calib` folder are both needless, which will be much more convenient for not using just official dataset. The point cloud dataset should be `.bin` format. + +Place the custom dataset: +``` +OpenPCDet +├── data +│ ├── custom +│ │ │── ImageSets +│ │ │── training +│ │ │ ├──velodyne & label_2 +│ │ │── testing +│ │ │ ├──velodyne +├── pcdet +├── tools +``` +## Calibration +Calibration rules for cameras are not need. But you need to define how to transform from KITTI coordinates to lidar coordinates. The lidar coordinates are the custom coordinates. The raw data are in lidar coordinates and the labels are in KITTI coordinates. This self-defined transform is written in `custom_dataset->get_calib (188)` which is used to get gt_boxes from labels. +## Other configurations +Possible other parameters or names that need to be check to adapt the custom scene. +- config files + ``` + CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist'] # pv_rcnn.yaml + ... +'anchor_sizes': [[3.9, 1.6, 1.56]], # pv_rcnn.yaml +... +POINT_CLOUD_RANGE: [-70.4, -40, -3, 70.4, 40, 1] # custom_dataset.yaml +... + ``` +The train, test and pred are all the same as others. \ No newline at end of file diff --git a/pcdet/datasets/custom/custom_dataset.py b/pcdet/datasets/custom/custom_dataset.py index 2fe2e2b60..1c6ff705a 100644 --- a/pcdet/datasets/custom/custom_dataset.py +++ b/pcdet/datasets/custom/custom_dataset.py @@ -103,13 +103,7 @@ def process_single_scene(sample_idx): def get_calib(self, loc): """ This calibration is different from the kitti dataset. - The transform formual of labelCloud: ROOT/labelCloud/io/labels/kitti.py: import labels - if self.transformed: - centroid = centroid[2], -centroid[0], centroid[1] - 2.3 - dimensions = [float(v) for v in line_elements[8:11]] - if self.transformed: - dimensions = dimensions[2], dimensions[1], dimensions[0] - bbox = BBox(*centroid, *dimensions) + You should check or redefine it according to your condition. 
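+        For reference, the transform below maps a label centroid (x, y, z) to
+        lidar coordinates as (z, -x, y - 2.3); the 2.3 offset comes from the
+        labelCloud export convention quoted above and is an assumption that
+        should be adjusted to your own sensor height.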
""" loc_lidar = np.concatenate([np.array((float(loc_obj[2]), float(-loc_obj[0]), float(loc_obj[1]-2.3)), dtype=np.float32).reshape(1,3) for loc_obj in loc]) return loc_lidar diff --git a/tools/cfgs/custom_models/pointrcnn.yaml b/tools/cfgs/custom_models/pointrcnn.yaml new file mode 100644 index 000000000..2df123581 --- /dev/null +++ b/tools/cfgs/custom_models/pointrcnn.yaml @@ -0,0 +1,161 @@ +CLASS_NAMES: ['Car'] +# CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist'] + +DATA_CONFIG: + _BASE_CONFIG_: ../dataset_configs/custom_dataset.yaml + + DATA_PROCESSOR: + - NAME: mask_points_and_boxes_outside_range + REMOVE_OUTSIDE_BOXES: True + + - NAME: sample_points + NUM_POINTS: { + 'train': 16384, + 'test': 16384 + } + + - NAME: shuffle_points + SHUFFLE_ENABLED: { + 'train': True, + 'test': False + } + +MODEL: + NAME: PointRCNN + + BACKBONE_3D: + NAME: PointNet2MSG + SA_CONFIG: + NPOINTS: [4096, 1024, 256, 64] + RADIUS: [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]] + NSAMPLE: [[16, 32], [16, 32], [16, 32], [16, 32]] + MLPS: [[[16, 16, 32], [32, 32, 64]], + [[64, 64, 128], [64, 96, 128]], + [[128, 196, 256], [128, 196, 256]], + [[256, 256, 512], [256, 384, 512]]] + FP_MLPS: [[128, 128], [256, 256], [512, 512], [512, 512]] + + POINT_HEAD: + NAME: PointHeadBox + CLS_FC: [256, 256] + REG_FC: [256, 256] + CLASS_AGNOSTIC: False + USE_POINT_FEATURES_BEFORE_FUSION: False + TARGET_CONFIG: + GT_EXTRA_WIDTH: [0.2, 0.2, 0.2] + BOX_CODER: PointResidualCoder + BOX_CODER_CONFIG: { + 'use_mean_size': True, + 'mean_size': [ + [3.9, 1.6, 1.56], + [0.8, 0.6, 1.73], + [1.76, 0.6, 1.73] + ] + } + + LOSS_CONFIG: + LOSS_REG: WeightedSmoothL1Loss + LOSS_WEIGHTS: { + 'point_cls_weight': 1.0, + 'point_box_weight': 1.0, + 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + } + + ROI_HEAD: + NAME: PointRCNNHead + CLASS_AGNOSTIC: True + + ROI_POINT_POOL: + POOL_EXTRA_WIDTH: [0.0, 0.0, 0.0] + NUM_SAMPLED_POINTS: 512 + DEPTH_NORMALIZER: 70.0 + + XYZ_UP_LAYER: [128, 128] + CLS_FC: [256, 256] + REG_FC: [256, 256] + DP_RATIO: 0.0 + USE_BN: False + + SA_CONFIG: + NPOINTS: [128, 32, -1] + RADIUS: [0.2, 0.4, 100] + NSAMPLE: [16, 16, 16] + MLPS: [[128, 128, 128], + [128, 128, 256], + [256, 256, 512]] + + NMS_CONFIG: + TRAIN: + NMS_TYPE: nms_gpu + MULTI_CLASSES_NMS: False + NMS_PRE_MAXSIZE: 9000 + NMS_POST_MAXSIZE: 512 + NMS_THRESH: 0.8 + TEST: + NMS_TYPE: nms_gpu + MULTI_CLASSES_NMS: False + NMS_PRE_MAXSIZE: 9000 + NMS_POST_MAXSIZE: 100 + NMS_THRESH: 0.85 + + TARGET_CONFIG: + BOX_CODER: ResidualCoder + ROI_PER_IMAGE: 128 + FG_RATIO: 0.5 + + SAMPLE_ROI_BY_EACH_CLASS: True + CLS_SCORE_TYPE: cls + + CLS_FG_THRESH: 0.6 + CLS_BG_THRESH: 0.45 + CLS_BG_THRESH_LO: 0.1 + HARD_BG_RATIO: 0.8 + + REG_FG_THRESH: 0.55 + + LOSS_CONFIG: + CLS_LOSS: BinaryCrossEntropy + REG_LOSS: smooth-l1 + CORNER_LOSS_REGULARIZATION: True + LOSS_WEIGHTS: { + 'rcnn_cls_weight': 1.0, + 'rcnn_reg_weight': 1.0, + 'rcnn_corner_weight': 1.0, + 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + } + + POST_PROCESSING: + RECALL_THRESH_LIST: [0.3, 0.5, 0.7] + SCORE_THRESH: 0.1 + OUTPUT_RAW_SCORE: False + + EVAL_METRIC: kitti + + NMS_CONFIG: + MULTI_CLASSES_NMS: False + NMS_TYPE: nms_gpu + NMS_THRESH: 0.1 + NMS_PRE_MAXSIZE: 4096 + NMS_POST_MAXSIZE: 500 + + +OPTIMIZATION: + BATCH_SIZE_PER_GPU: 2 + NUM_EPOCHS: 80 + + OPTIMIZER: adam_onecycle + LR: 0.01 + WEIGHT_DECAY: 0.01 + MOMENTUM: 0.9 + + MOMS: [0.95, 0.85] + PCT_START: 0.4 + DIV_FACTOR: 10 + DECAY_STEP_LIST: [35, 45] + LR_DECAY: 0.1 + LR_CLIP: 0.0000001 + + LR_WARMUP: False + WARMUP_EPOCH: 1 + + GRAD_NORM_CLIP: 10 
diff --git a/tools/cfgs/custom_models/pv_rcnn.yaml b/tools/cfgs/custom_models/pv_rcnn.yaml new file mode 100644 index 000000000..012d12ac9 --- /dev/null +++ b/tools/cfgs/custom_models/pv_rcnn.yaml @@ -0,0 +1,249 @@ +CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist'] + +DATA_CONFIG: + _BASE_CONFIG_: ../dataset_configs/custom_dataset.yaml + DATA_AUGMENTOR: + DISABLE_AUG_LIST: ['placeholder'] + AUG_CONFIG_LIST: + - NAME: gt_sampling + USE_ROAD_PLANE: False + DB_INFO_PATH: + - custom_dbinfos_train.pkl + PREPARE: { + filter_by_min_points: ['Car:5', 'Pedestrian:5', 'Cyclist:5'], + filter_by_difficulty: [-1], + } + + SAMPLE_GROUPS: ['Car:15','Pedestrian:10', 'Cyclist:10'] + NUM_POINT_FEATURES: 4 + DATABASE_WITH_FAKELIDAR: False + REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0] + LIMIT_WHOLE_SCENE: False + + - NAME: random_world_flip + ALONG_AXIS_LIST: ['x'] + + - NAME: random_world_rotation + WORLD_ROT_ANGLE: [-0.78539816, 0.78539816] + + - NAME: random_world_scaling + WORLD_SCALE_RANGE: [0.95, 1.05] + +MODEL: + NAME: PVRCNN + + VFE: + NAME: MeanVFE + + BACKBONE_3D: + NAME: VoxelBackBone8x + + MAP_TO_BEV: + NAME: HeightCompression + NUM_BEV_FEATURES: 256 + + BACKBONE_2D: + NAME: BaseBEVBackbone + + LAYER_NUMS: [5, 5] + LAYER_STRIDES: [1, 2] + NUM_FILTERS: [128, 256] + UPSAMPLE_STRIDES: [1, 2] + NUM_UPSAMPLE_FILTERS: [256, 256] + + DENSE_HEAD: + NAME: AnchorHeadSingle + CLASS_AGNOSTIC: False + + USE_DIRECTION_CLASSIFIER: True + DIR_OFFSET: 0.78539 + DIR_LIMIT_OFFSET: 0.0 + NUM_DIR_BINS: 2 + + ANCHOR_GENERATOR_CONFIG: [ + { + 'class_name': 'Car', + 'anchor_sizes': [[3.9, 1.6, 1.56]], + 'anchor_rotations': [0, 1.57], + 'anchor_bottom_heights': [-0.1], + 'align_center': False, + 'feature_map_stride': 8, + 'matched_threshold': 0.6, + 'unmatched_threshold': 0.45 + }, + { + 'class_name': 'Pedestrian', + 'anchor_sizes': [[0.05, 0.03, 0.1]], + 'anchor_rotations': [0, 1.57], + 'anchor_bottom_heights': [-0.03], + 'align_center': False, + 'feature_map_stride': 8, + 'matched_threshold': 0.5, + 'unmatched_threshold': 0.35 + }, + { + 'class_name': 'Cyclist', + 'anchor_sizes': [[0.1, 0.03, 0.1]], + 'anchor_rotations': [0, 1.57], + 'anchor_bottom_heights': [-0.03], + 'align_center': False, + 'feature_map_stride': 8, + 'matched_threshold': 0.5, + 'unmatched_threshold': 0.35 + } + ] + + TARGET_ASSIGNER_CONFIG: + NAME: AxisAlignedTargetAssigner + POS_FRACTION: -1.0 + SAMPLE_SIZE: 512 + NORM_BY_NUM_EXAMPLES: False + MATCH_HEIGHT: False + BOX_CODER: ResidualCoder + + LOSS_CONFIG: + LOSS_WEIGHTS: { + 'cls_weight': 1.0, + 'loc_weight': 2.0, + 'dir_weight': 0.2, + 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + } + + PFE: + NAME: VoxelSetAbstraction + POINT_SOURCE: raw_points + NUM_KEYPOINTS: 2048 + NUM_OUTPUT_FEATURES: 128 + SAMPLE_METHOD: FPS + + FEATURES_SOURCE: ['bev', 'x_conv1', 'x_conv2', 'x_conv3', 'x_conv4', 'raw_points'] + SA_LAYER: + raw_points: + MLPS: [[16, 16], [16, 16]] + POOL_RADIUS: [0.4, 0.8] + NSAMPLE: [16, 16] + x_conv1: + DOWNSAMPLE_FACTOR: 1 + MLPS: [[16, 16], [16, 16]] + POOL_RADIUS: [0.4, 0.8] + NSAMPLE: [16, 16] + x_conv2: + DOWNSAMPLE_FACTOR: 2 + MLPS: [[32, 32], [32, 32]] + POOL_RADIUS: [0.8, 1.2] + NSAMPLE: [16, 32] + x_conv3: + DOWNSAMPLE_FACTOR: 4 + MLPS: [[64, 64], [64, 64]] + POOL_RADIUS: [1.2, 2.4] + NSAMPLE: [16, 32] + x_conv4: + DOWNSAMPLE_FACTOR: 8 + MLPS: [[64, 64], [64, 64]] + POOL_RADIUS: [2.4, 4.8] + NSAMPLE: [16, 32] + + POINT_HEAD: + NAME: PointHeadSimple + CLS_FC: [256, 256] + CLASS_AGNOSTIC: True + USE_POINT_FEATURES_BEFORE_FUSION: True + TARGET_CONFIG: + GT_EXTRA_WIDTH: [0.2, 0.2, 0.2] + 
LOSS_CONFIG: + LOSS_REG: smooth-l1 + LOSS_WEIGHTS: { + 'point_cls_weight': 1.0, + } + + ROI_HEAD: + NAME: PVRCNNHead + CLASS_AGNOSTIC: True + + SHARED_FC: [256, 256] + CLS_FC: [256, 256] + REG_FC: [256, 256] + DP_RATIO: 0.3 + + NMS_CONFIG: + TRAIN: + NMS_TYPE: nms_gpu + MULTI_CLASSES_NMS: False + NMS_PRE_MAXSIZE: 9000 + NMS_POST_MAXSIZE: 512 + NMS_THRESH: 0.8 + TEST: + NMS_TYPE: nms_gpu + MULTI_CLASSES_NMS: False + NMS_PRE_MAXSIZE: 1024 + NMS_POST_MAXSIZE: 100 + NMS_THRESH: 0.7 + + ROI_GRID_POOL: + GRID_SIZE: 6 + MLPS: [[64, 64], [64, 64]] + POOL_RADIUS: [0.8, 1.6] + NSAMPLE: [16, 16] + POOL_METHOD: max_pool + + TARGET_CONFIG: + BOX_CODER: ResidualCoder + ROI_PER_IMAGE: 128 + FG_RATIO: 0.5 + + SAMPLE_ROI_BY_EACH_CLASS: True + CLS_SCORE_TYPE: roi_iou + + CLS_FG_THRESH: 0.75 + CLS_BG_THRESH: 0.25 + CLS_BG_THRESH_LO: 0.1 + HARD_BG_RATIO: 0.8 + + REG_FG_THRESH: 0.55 + + LOSS_CONFIG: + CLS_LOSS: BinaryCrossEntropy + REG_LOSS: smooth-l1 + CORNER_LOSS_REGULARIZATION: True + LOSS_WEIGHTS: { + 'rcnn_cls_weight': 1.0, + 'rcnn_reg_weight': 1.0, + 'rcnn_corner_weight': 1.0, + 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + } + + POST_PROCESSING: + RECALL_THRESH_LIST: [0.3, 0.5, 0.7] + SCORE_THRESH: 0.1 + OUTPUT_RAW_SCORE: False + + EVAL_METRIC: kitti + + NMS_CONFIG: + MULTI_CLASSES_NMS: False + NMS_TYPE: nms_gpu + NMS_THRESH: 0.1 + NMS_PRE_MAXSIZE: 4096 + NMS_POST_MAXSIZE: 500 + + +OPTIMIZATION: + BATCH_SIZE_PER_GPU: 2 + NUM_EPOCHS: 80 + + OPTIMIZER: adam_onecycle + LR: 0.01 + WEIGHT_DECAY: 0.01 + MOMENTUM: 0.9 + + MOMS: [0.95, 0.85] + PCT_START: 0.4 + DIV_FACTOR: 10 + DECAY_STEP_LIST: [35, 45] + LR_DECAY: 0.1 + LR_CLIP: 0.0000001 + + LR_WARMUP: False + WARMUP_EPOCH: 1 + + GRAD_NORM_CLIP: 10 diff --git a/tools/cfgs/dataset_configs/custom_dataset.yaml b/tools/cfgs/dataset_configs/custom_dataset.yaml new file mode 100644 index 000000000..040dacf82 --- /dev/null +++ b/tools/cfgs/dataset_configs/custom_dataset.yaml @@ -0,0 +1,71 @@ +DATASET: 'CustomDataset' +DATA_PATH: '../data/custom' + +# If this config file is modified then pcdet/models/detectors/detector3d_template.py: +# Detector3DTemplate::build_networks:model_info_dict needs to be modified. 
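+# Sanity check for the range/voxel pairing below (with VOXEL_SIZE
+# [0.05, 0.05, 0.1] from DATA_PROCESSOR): x: (70.4 + 70.4) / 0.05 = 2816 = 16 * 176,
+# y: (40 + 40) / 0.05 = 1600 = 16 * 100, z: (1 + 3) / 0.1 = 40.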
+POINT_CLOUD_RANGE: [-70.4, -40, -3, 70.4, 40, 1] # x=[-70.4, 70.4], y=[-40, 40], z=[-3, 1]
+
+DATA_SPLIT: {
+    'train': train,
+    'test': val
+}
+
+INFO_PATH: {
+    'train': [custom_infos_train.pkl],
+    'test': [custom_infos_val.pkl],
+}
+
+GET_ITEM_LIST: ["points"]
+FOV_POINTS_ONLY: True
+
+POINT_FEATURE_ENCODING: {
+    encoding_type: absolute_coordinates_encoding,
+    used_feature_list: ['x', 'y', 'z', 'intensity'],
+    src_feature_list: ['x', 'y', 'z', 'intensity'],
+}
+
+# Same as the DATA_AUGMENTOR section in pv_rcnn.yaml
+DATA_AUGMENTOR:
+    DISABLE_AUG_LIST: ['placeholder']
+    AUG_CONFIG_LIST:
+        - NAME: gt_sampling
+          USE_ROAD_PLANE: False
+          DB_INFO_PATH:
+              - custom_dbinfos_train.pkl
+          PREPARE: {
+             filter_by_min_points: ['Car:5', 'Pedestrian:5', 'Cyclist:5'],
+             filter_by_difficulty: [-1],
+          }
+
+          SAMPLE_GROUPS: ['Car:20','Pedestrian:15', 'Cyclist:15']
+          NUM_POINT_FEATURES: 4
+          DATABASE_WITH_FAKELIDAR: False
+          REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
+          LIMIT_WHOLE_SCENE: True
+
+        - NAME: random_world_flip
+          ALONG_AXIS_LIST: ['x']
+
+        - NAME: random_world_rotation
+          WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
+
+        - NAME: random_world_scaling
+          WORLD_SCALE_RANGE: [0.95, 1.05]
+
+DATA_PROCESSOR:
+    - NAME: mask_points_and_boxes_outside_range
+      REMOVE_OUTSIDE_BOXES: True
+
+    - NAME: shuffle_points
+      SHUFFLE_ENABLED: {
+        'train': True,
+        'test': False
+      }
+
+    - NAME: transform_points_to_voxels
+      VOXEL_SIZE: [0.05, 0.05, 0.1]
+      MAX_POINTS_PER_VOXEL: 5
+      MAX_NUMBER_OF_VOXELS: {
+        'train': 16000,
+        'test': 40000
+      }
\ No newline at end of file
From cb1ecefaa38a763ef7d3e0de22fb36c89f21509b Mon Sep 17 00:00:00 2001
From: jihanyang
Date: Mon, 22 Aug 2022 15:44:17 +0800
Subject: [PATCH 3/5] Modify the custom dataset support

---
 docs/CUSTOM_DATASET_TUTORIAL.md              | 110 ++++++
 docs/GETTING_STARTED.md                      |   2 +-
 pcdet/datasets/custom/README.md              |  37 --
 pcdet/datasets/custom/custom_dataset.py      | 367 +++++++-----------
 pcdet/datasets/custom/custom_utils.py        |   0
 pcdet/datasets/dataset.py                    |  44 ++-
 pcdet/datasets/lyft/lyft_dataset.py          |  44 ---
 pcdet/datasets/nuscenes/nuscenes_dataset.py  |  45 ---
 pcdet/datasets/waymo/waymo_dataset.py        |  47 ---
 tools/cfgs/custom_models/pointrcnn.yaml      | 161 --------
 tools/cfgs/custom_models/pv_rcnn.yaml        |  58 +--
 .../cfgs/dataset_configs/custom_dataset.yaml |  27 +-
 12 files changed, 321 insertions(+), 621 deletions(-)
 create mode 100644 docs/CUSTOM_DATASET_TUTORIAL.md
 delete mode 100644 pcdet/datasets/custom/README.md
 delete mode 100644 pcdet/datasets/custom/custom_utils.py
 delete mode 100644 tools/cfgs/custom_models/pointrcnn.yaml

diff --git a/docs/CUSTOM_DATASET_TUTORIAL.md b/docs/CUSTOM_DATASET_TUTORIAL.md
new file mode 100644
index 000000000..0bf0f0c91
--- /dev/null
+++ b/docs/CUSTOM_DATASET_TUTORIAL.md
@@ -0,0 +1,110 @@
+# Custom Dataset Tutorial
+For the custom dataset template, we only consider the basic scenario: raw point clouds and
+their corresponding annotations. Point clouds are supposed to be stored in `.npy` format.
+
+## Label format
+We only consider the most basic information in the label template: category and bounding box.
+Annotations are stored in `.txt` files. Each line represents a box in a given scene as below:
+```
+[x y z dx dy dz heading_angle category_id]
+1.50 1.46 0.10 5.12 1.85 4.13 1.56 0
+5.54 0.57 0.41 1.08 0.74 1.95 1.57 1
+```
+The box should be in the unified 3D box definition (see [README](../README.md)).
+The correspondence between `category_id` and `category_name` needs to be pre-defined, as in the sketch below.
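+
+A minimal parsing sketch (illustrative only; the `CATEGORY_NAMES` mapping and the file name
+`000000.txt` are placeholders for your own definitions):
+```python
+import numpy as np
+
+# Assumed id -> name mapping; must match the ids used in your label files.
+CATEGORY_NAMES = ['Vehicle', 'Pedestrian', 'Cyclist']
+
+def parse_label_file(label_path):
+    # Each row: x y z dx dy dz heading_angle category_id
+    rows = np.loadtxt(label_path, dtype=np.float32).reshape(-1, 8)
+    gt_boxes = rows[:, :7]  # (N, 7) boxes in the unified lidar definition
+    names = [CATEGORY_NAMES[int(i)] for i in rows[:, 7]]
+    return gt_boxes, names
+
+gt_boxes, names = parse_label_file('000000.txt')
+```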
+
+
+## Files structure
+Files should be placed in the following folder structure:
+```
+OpenPCDet
+├── data
+│   ├── custom
+│   │   │── ImageSets
+│   │   │   │── train.txt
+│   │   │   │── val.txt
+│   │   │── points
+│   │   │   │── 000000.npy
+│   │   │   │── 999999.npy
+│   │   │── labels
+│   │   │   │── 000000.txt
+│   │   │   │── 999999.txt
+├── pcdet
+├── tools
+```
+Dataset splits need to be pre-defined and placed in `ImageSets`.
+
+## Hyper-parameter Configurations
+
+### Point cloud features
+Modify the following configurations in `custom_dataset.yaml` to
+suit your own point clouds.
+```yaml
+POINT_FEATURE_ENCODING: {
+    encoding_type: absolute_coordinates_encoding,
+    used_feature_list: ['x', 'y', 'z', 'intensity'],
+    src_feature_list: ['x', 'y', 'z', 'intensity'],
+}
+...
+# In gt_sampling data augmentation
+NUM_POINT_FEATURES: 4
+
+```
+
+#### Point cloud range and voxel sizes
+For voxel-based detectors such as SECOND, PV-RCNN and CenterPoint, the point cloud range and voxel size should satisfy:
+1. The point cloud range along the z-axis divided by the voxel size should equal 40.
+2. The point cloud range along the x&y-axis divided by the voxel size should be a multiple of 16.
+
+Notice that the second rule also applies to pillar-based detectors such as PointPillar and CenterPoint-Pillar.
+
+### Category names and anchor sizes
+Category names and anchor sizes need to be adapted to the custom dataset.
+```yaml
+CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist']
+...
+MAP_CLASS_TO_KITTI: {
+    'Vehicle': 'Car',
+    'Pedestrian': 'Pedestrian',
+    'Cyclist': 'Cyclist',
+}
+...
+'anchor_sizes': [[3.9, 1.6, 1.56]],
+...
+# In gt_sampling data augmentation
+PREPARE: {
+   filter_by_min_points: ['Vehicle:5', 'Pedestrian:5', 'Cyclist:5'],
+   filter_by_difficulty: [-1],
+}
+SAMPLE_GROUPS: ['Vehicle:20','Pedestrian:15', 'Cyclist:15']
+...
+```
+In addition, please also modify the default category names for creating infos in `custom_dataset.py`:
+```python
+create_custom_infos(
+    dataset_cfg=dataset_cfg,
+    class_names=['Vehicle', 'Pedestrian', 'Cyclist'],
+    data_path=ROOT_DIR / 'data' / 'custom',
+    save_path=ROOT_DIR / 'data' / 'custom',
+)
+```
+
+
+## Create data info
+Generate the data infos by running the following command:
+```shell
+python -m pcdet.datasets.custom.custom_dataset create_custom_infos tools/cfgs/dataset_configs/custom_dataset.yaml
+```
+
+
+## Evaluation
+Here, we only provide an implementation for KITTI-style evaluation.
+The category mapping between the custom dataset and KITTI needs to be defined
+in `custom_dataset.yaml`:
+```yaml
+MAP_CLASS_TO_KITTI: {
+    'Vehicle': 'Car',
+    'Pedestrian': 'Pedestrian',
+    'Cyclist': 'Cyclist',
+}
+```
diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md
index d1e160965..e34fee953 100644
--- a/docs/GETTING_STARTED.md
+++ b/docs/GETTING_STARTED.md
@@ -5,7 +5,7 @@ and the model configs are located within [tools/cfgs](../tools/cfgs) for differe
 ## Dataset Preparation
-Currently we provide the dataloader of KITTI dataset and NuScenes dataset, and the supporting of more datasets are on the way. 
+Currently we provide dataloaders for the KITTI, NuScenes, Waymo, Lyft and Pandaset datasets. If you want to use a custom dataset, please refer to our [custom dataset template](CUSTOM_DATASET_TUTORIAL.md). 
### KITTI Dataset * Please download the official [KITTI 3D object detection](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) dataset and organize the downloaded files as follows (the road planes could be downloaded from [[road plane]](https://drive.google.com/file/d/1d5mq0RXRnvHPVeKx6Q612z0YRO1t2wAp/view?usp=sharing), which are optional for data augmentation in the training): diff --git a/pcdet/datasets/custom/README.md b/pcdet/datasets/custom/README.md deleted file mode 100644 index 398fa4c83..000000000 --- a/pcdet/datasets/custom/README.md +++ /dev/null @@ -1,37 +0,0 @@ -For Custom Dataset using -## Custom Dataset -For pure point cloud dataset, which means you don't have images generated when got point cloud data from a self-defined scene. Label those raw data and make sure label files to be kitti-like: -``` -Car 0 0 0 0 0 0 0 1.50 1.46 3.70 -5.12 1.85 4.13 1.56 -Pedestrian 0 0 0 0 0 0 0 1.54 0.57 0.41 -7.92 1.94 15.95 1.57 -DontCare 0 0 0 0 0 0 0 -1 -1 -1 -1000 -1000 -1000 -10 -``` -Some items (which is shown from the first zero to the seventh zero above) are not necessary because they are meaningless if no cameras. And the `image` folder, `calib` folder are both needless, which will be much more convenient for not using just official dataset. The point cloud dataset should be `.bin` format. - -Place the custom dataset: -``` -OpenPCDet -├── data -│ ├── custom -│ │ │── ImageSets -│ │ │── training -│ │ │ ├──velodyne & label_2 -│ │ │── testing -│ │ │ ├──velodyne -├── pcdet -├── tools -``` -## Calibration -Calibration rules for cameras are not need. But you need to define how to transform from KITTI coordinates to lidar coordinates. The lidar coordinates are the custom coordinates. The raw data are in lidar coordinates and the labels are in KITTI coordinates. This self-defined transform is written in `custom_dataset->get_calib (188)` which is used to get gt_boxes from labels. -## Other configurations -Possible other parameters or names that need to be check to adapt the custom scene. -- config files - ``` - CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist'] # pv_rcnn.yaml - ... -'anchor_sizes': [[3.9, 1.6, 1.56]], # pv_rcnn.yaml -... -POINT_CLOUD_RANGE: [-70.4, -40, -3, 70.4, 40, 1] # custom_dataset.yaml -... - ``` -The train, test and pred are all the same as others. \ No newline at end of file diff --git a/pcdet/datasets/custom/custom_dataset.py b/pcdet/datasets/custom/custom_dataset.py index 1c6ff705a..9552074aa 100644 --- a/pcdet/datasets/custom/custom_dataset.py +++ b/pcdet/datasets/custom/custom_dataset.py @@ -3,15 +3,14 @@ import os import numpy as np -from skimage import io -from . 
import custom_utils from ...ops.roiaware_pool3d import roiaware_pool3d_utils -from ...utils import box_utils, common_utils, object3d_custom +from ...utils import box_utils, common_utils from ..dataset import DatasetTemplate + class CustomDataset(DatasetTemplate): - def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None, ext='.bin'): + def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None): """ Args: root_path: @@ -24,19 +23,16 @@ def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logg dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger ) self.split = self.dataset_cfg.DATA_SPLIT[self.mode] - self.root_split_path = os.path.join(self.root_path, ('training' if self.split != 'test' else 'testing')) - split_dir = os.path.join(self.root_path, 'ImageSets',(self.split + '.txt')) + split_dir = os.path.join(self.root_path, 'ImageSets', (self.split + '.txt')) self.sample_id_list = [x.strip() for x in open(split_dir).readlines()] if os.path.exists(split_dir) else None self.custom_infos = [] - self.include_custom_data(self.mode) - self.ext = ext - + self.include_data(self.mode) + self.map_class_to_kitti = self.dataset_cfg.MAP_CLASS_TO_KITTI - def include_custom_data(self, mode): - if self.logger is not None: - self.logger.info('Loading Custom dataset.') + def include_data(self, mode): + self.logger.info('Loading Custom dataset.') custom_infos = [] for info_path in self.dataset_cfg.INFO_PATH[mode]: @@ -46,104 +42,124 @@ def include_custom_data(self, mode): with open(info_path, 'rb') as f: infos = pickle.load(f) custom_infos.extend(infos) - + self.custom_infos.extend(custom_infos) + self.logger.info('Total samples for CUSTOM dataset: %d' % (len(custom_infos))) + + def get_label(self, idx): + label_file = self.root_path / 'labels' / ('%s.txt' % idx) + assert label_file.exists() + with open(label_file, 'r') as f: + lines = f.readlines() + + # [N, 8]: (x y z dx dy dz heading_angle category_id) + gt_boxes = [line.strip().split(' ') for line in lines] + return np.array(gt_boxes, dtype=np.float32) + + def get_lidar(self, idx): + lidar_file = self.root_path / 'points' / ('%s.npy' % idx) + assert lidar_file.exists() + point_features = np.load(lidar_file) + return point_features + + def set_split(self, split): + super().__init__( + dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training, + root_path=self.root_path, logger=self.logger + ) + self.split = split + + split_dir = self.root_path / 'ImageSets' / (self.split + '.txt') + self.sample_id_list = [x.strip() for x in open(split_dir).readlines()] if split_dir.exists() else None + + def __len__(self): + if self._merge_all_iters_to_one_epoch: + return len(self.sample_id_list) * self.total_epochs + + return len(self.custom_infos) + + def __getitem__(self, index): + if self._merge_all_iters_to_one_epoch: + index = index % len(self.custom_infos) + + info = copy.deepcopy(self.custom_infos[index]) + sample_idx = info['point_cloud']['lidar_idx'] + points = self.get_lidar(sample_idx) + input_dict = { + 'frame_id': self.sample_id_list[index], + 'points': points + } + + if 'annos' in info: + annos = info['annos'] + annos = common_utils.drop_info_with_name(annos, name='DontCare') + gt_names = annos['name'] + gt_boxes_lidar = annos['gt_boxes_lidar'] + input_dict.update({ + 'gt_names': gt_names, + 'gt_boxes': gt_boxes_lidar + }) + + data_dict = self.prepare_data(data_dict=input_dict) + + return data_dict + + 
def evaluation(self, det_annos, class_names, **kwargs): + if 'annos' not in self.custom_infos[0].keys(): + return 'No ground-truth boxes for evaluation', {} + + def kitti_eval(eval_det_annos, eval_gt_annos, map_name_to_kitti): + from ..kitti.kitti_object_eval_python import eval as kitti_eval + from ..kitti import kitti_utils + + kitti_utils.transform_annotations_to_kitti_format(eval_det_annos, map_name_to_kitti=map_name_to_kitti) + kitti_utils.transform_annotations_to_kitti_format( + eval_gt_annos, map_name_to_kitti=map_name_to_kitti, + info_with_fakelidar=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False) + ) + kitti_class_names = [map_name_to_kitti[x] for x in class_names] + ap_result_str, ap_dict = kitti_eval.get_official_eval_result( + gt_annos=eval_gt_annos, dt_annos=eval_det_annos, current_classes=kitti_class_names + ) + return ap_result_str, ap_dict + + eval_det_annos = copy.deepcopy(det_annos) + eval_gt_annos = [copy.deepcopy(info['annos']) for info in self.custom_infos] + + if kwargs['eval_metric'] == 'kitti': + ap_result_str, ap_dict = kitti_eval(eval_det_annos, eval_gt_annos, self.map_class_to_kitti) + else: + raise NotImplementedError - if self.logger is not None: - self.logger.info('Total samples for CUSTOM dataset: %d' % (len(custom_infos))) - + return ap_result_str, ap_dict - def get_infos(self, num_workers=16, has_label=True, count_inside_pts=True, sample_id_list=None): + def get_infos(self, class_names, num_workers=4, has_label=True, sample_id_list=None, num_features=4): import concurrent.futures as futures - # Process single scene + class_names = np.array(class_names) + def process_single_scene(sample_idx): print('%s sample_idx: %s' % (self.split, sample_idx)) info = {} - pc_info = {'num_features': 4, 'lidar_idx': sample_idx} + pc_info = {'num_features': num_features, 'lidar_idx': sample_idx} info['point_cloud'] = pc_info - # no images, calibs are need to transform the labels - - type_to_id = {'Car': 1, 'Pedestrian': 2, 'Cyclist': 3} if has_label: - obj_list = self.get_label(sample_idx) annotations = {} - annotations['name'] = np.array([obj.cls_type for obj in obj_list]) # 1-dimension - annotations['dimensions'] = np.array([[obj.l, obj.h, obj.w] for obj in obj_list]) - annotations['location'] = np.concatenate([obj.loc.reshape(1,3) for obj in obj_list]) - annotations['rotation_y'] = np.array([obj.ry for obj in obj_list]) # 1-dimension - - num_objects = len([obj.cls_type for obj in obj_list if obj.cls_type != 'DontCare']) - num_gt = len(annotations['name']) - index = list(range(num_objects)) + [-1] * (num_gt - num_objects) - annotations['index'] = np.array(index, dtype=np.int32) - - loc = annotations['location'][:num_objects] - dims = annotations['dimensions'][:num_objects] - rots = annotations['rotation_y'][:num_objects] - loc_lidar = self.get_calib(loc) - l, h, w = dims[:, 0:1], dims[:, 1:2], dims[:, 2:3] - gt_boxes_lidar = np.concatenate([loc_lidar, l, w, h, (np.pi / 2 - rots[..., np.newaxis])], axis=1) # 2-dimension array - annotations['gt_boxes_lidar'] = gt_boxes_lidar - + gt_boxes_lidar = self.get_label(sample_idx) + annotations['name'] = class_names[gt_boxes_lidar[:, -1].astype(np.int64)] + annotations['gt_boxes_lidar'] = gt_boxes_lidar[:, :7] info['annos'] = annotations - + return info - + sample_id_list = sample_id_list if sample_id_list is not None else self.sample_id_list # create a thread pool to improve the velocity with futures.ThreadPoolExecutor(num_workers) as executor: infos = executor.map(process_single_scene, sample_id_list) - return list(infos) - - - def 
get_calib(self, loc): - """ - This calibration is different from the kitti dataset. - You should check or redefine it according to your condition. - """ - loc_lidar = np.concatenate([np.array((float(loc_obj[2]), float(-loc_obj[0]), float(loc_obj[1]-2.3)), dtype=np.float32).reshape(1,3) for loc_obj in loc]) - return loc_lidar - - - def get_label(self, idx): - - label_file = self.root_split_path / 'label_2' / ('%s.txt' % idx) - assert label_file.exists() - return object3d_custom.get_objects_from_label(label_file) - - def get_lidar(self, idx, getitem): - """ - Loads point clouds for a sample - Args: - index (int): Index of the point cloud file to get. - Returns: - np.array(N, 4): point cloud. - """ - # get lidar statistics - if getitem == True: - lidar_file = self.root_split_path + '/velodyne/' + ('%s.bin' % idx) - else: - lidar_file = self.root_split_path / 'velodyne' / ('%s.bin' % idx) - return np.fromfile(str(lidar_file), dtype=np.float32).reshape(-1, 4) - - - def set_split(self, split): - super().__init__( - dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training, root_path=self.root_path, logger=self.logger - ) - self.split = split - self.root_split_path = self.root_path / ('training' if self.split != 'test' else 'testing') - - split_dir = self.root_path / 'ImageSets' / (self.split + '.txt') - self.sample_id_list = [x.strip() for x in open(split_dir).readlines()] if split_dir.exists() else None - - - # Create gt database for data augmentation def create_groundtruth_database(self, info_path=None, used_classes=None, split='train'): import torch @@ -156,12 +172,11 @@ def create_groundtruth_database(self, info_path=None, used_classes=None, split=' with open(info_path, 'rb') as f: infos = pickle.load(f) - # For each .bin file for k in range(len(infos)): print('gt_database sample: %d/%d' % (k + 1, len(infos))) info = infos[k] sample_idx = info['point_cloud']['lidar_idx'] - points = self.get_lidar(sample_idx, False) + points = self.get_lidar(sample_idx) annos = info['annos'] names = annos['name'] gt_boxes = annos['gt_boxes_lidar'] @@ -197,169 +212,69 @@ def create_groundtruth_database(self, info_path=None, used_classes=None, split=' pickle.dump(all_db_infos, f) @staticmethod - def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None): - """ - Args: - batch_dict: - frame_id: - pred_dicts: list of pred_dicts - pred_boxes: (N,7), Tensor - pred_scores: (N), Tensor - pred_lables: (N), Tensor - class_names: - output_path: - - Returns: - - """ - def get_template_prediction(num_smaples): - ret_dict = { - 'name': np.zeros(num_smaples), 'alpha' : np.zeros(num_smaples), - 'dimensions': np.zeros([num_smaples, 3]), 'location': np.zeros([num_smaples, 3]), - 'rotation_y': np.zero(num_smaples), 'score': np.zeros(num_smaples), - 'boxes_lidar': np.zeros([num_smaples, 7]) - } - return ret_dict - - def generate_single_sample_dict(batch_index, box_dict): - pred_scores = box_dict['pred_scores'].cpu().numpy() - pred_boxes = box_dict['pred_boxes'].cpu().numpy() - pred_labels = box_dict['pred_labels'].cpu().numpy() - - # Define an empty template dict to store the prediction information, 'pred_scores.shape[0]' means 'num_samples' - pred_dict = get_template_prediction(pred_scores.shape[0]) - # If num_samples equals zero then return the empty dict - if pred_scores.shape[0] == 0: - return pred_dict - - # No calibration files - - pred_boxes_camera = box_utils.boxes3d_lidar_to_kitti_camera[pred_boxes] - - pred_dict['name'] = np.array(class_names)[pred_labels - 1] - 
pred_dict['alpha'] = -np.arctan2(-pred_boxes[:, 1], pred_boxes[:, 0]) + pred_boxes_camera[:, 6] - pred_dict['dimensions'] = pred_boxes_camera[:, 3:6] - pred_dict['location'] = pred_boxes_camera[:, 0:3] - pred_dict['rotation_y'] = pred_boxes_camera[:, 6] - pred_dict['score'] = pred_scores - pred_dict['boxes_lidar'] = pred_boxes - - return pred_dict - - annos = [] - for index, box_dict in enumerate(pred_dicts): - frame_id = batch_dict['frame_id'][index] - - single_pred_dict = generate_single_sample_dict(index, box_dict) - single_pred_dict['frame_id'] = frame_id - annos.append(single_pred_dict) - - # Output pred results to Output-path in .txt file - if output_path is not None: - cur_det_file = output_path / ('%s.txt' % frame_id) - with open(cur_det_file, 'w') as f: - bbox = single_pred_dict['bbox'] - loc = single_pred_dict['location'] - dims = single_pred_dict['dimensions'] # lhw -> hwl: lidar -> camera - - for idx in range(len(bbox)): - print('%s -1 -1 %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f' - % (single_pred_dict['name'][idx], single_pred_dict['alpha'][idx], - bbox[idx][0], bbox[idx][1], bbox[idx][2], bbox[idx][3], - dims[idx][1], dims[idx][2], dims[idx][0], loc[idx][0], - loc[idx][1], loc[idx][2], single_pred_dict['rotation_y'][idx], - single_pred_dict['score'][idx]), file=f) - return annos - - - def __len__(self): - if self._merge_all_iters_to_one_epoch: - return len(self.sample_id_list) * self.total_epochs - - return len(self.custom_infos) - - - def __getitem__(self, index): - """ - Function: - Read 'velodyne' folder as pointclouds - Read 'label_2' folder as labels - Return type 'dict' - """ - if self._merge_all_iters_to_one_epoch: - index = index % len(self.custom_infos) - - info = copy.deepcopy(self.custom_infos[index]) - - sample_idx = info['point_cloud']['lidar_idx'] - get_item_list = self.dataset_cfg.get('GET_ITEM_LIST', ['points']) - - input_dict = { - 'frame_id': self.sample_id_list[index], - } - - """ - Here infos was generated by get_infos - """ - if 'annos' in info: - annos = info['annos'] - annos = common_utils.drop_info_with_name(annos, name='DontCare') - loc, dims, rots = annos['location'], annos['dimensions'], annos['rotation_y'] - gt_names = annos['name'] - gt_boxes_lidar = annos['gt_boxes_lidar'] - - if 'points' in get_item_list: - points = self.get_lidar(sample_idx, True) - input_dict['points'] = points - input_dict.update({ - 'gt_names': gt_names, - 'gt_boxes': gt_boxes_lidar - }) - - data_dict = self.prepare_data(data_dict=input_dict) - return data_dict + def create_label_file_with_name_and_box(class_names, gt_names, gt_boxes, save_label_path): + with open(save_label_path, 'w') as f: + for idx in range(gt_boxes.shape[0]): + boxes = gt_boxes[idx] + name = gt_names[idx] + if name not in class_names: + continue + category_id = class_names.index(name) + line = "{x} {y} {z} {l} {w} {h} {angle} {category_id}\n".format( + x=boxes[0], y=boxes[1], z=(boxes[2]), l=boxes[3], + w=boxes[4], h=boxes[5], angle=boxes[6], category_id=category_id + ) + f.write(line) def create_custom_infos(dataset_cfg, class_names, data_path, save_path, workers=4): - dataset = CustomDataset(dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path, training=False) + dataset = CustomDataset( + dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path, + training=False, logger=common_utils.create_logger() + ) train_split, val_split = 'train', 'val' + num_features = len(dataset_cfg.POINT_FEATURE_ENCODING.src_feature_list) - # No evaluation train_filename = 
save_path / ('custom_infos_%s.pkl' % train_split)
-    val_filenmae = save_path / ('custom_infos%s.pkl' % val_split)
-    trainval_filename = save_path / 'custom_infos_trainval.pkl'
-    test_filename = save_path / 'custom_infos_test.pkl'
+    val_filename = save_path / ('custom_infos_%s.pkl' % val_split)
 
     print('------------------------Start to generate data infos------------------------')
 
     dataset.set_split(train_split)
-    custom_infos_train = dataset.get_infos(num_workers=workers, has_label=True, count_inside_pts=True)
+    custom_infos_train = dataset.get_infos(
+        class_names, num_workers=workers, has_label=True, num_features=num_features
+    )
     with open(train_filename, 'wb') as f:
         pickle.dump(custom_infos_train, f)
     print('Custom info train file is save to %s' % train_filename)
 
-    dataset.set_split('test')
-    custom_infos_test = dataset.get_infos(num_workers=workers, has_label=False, count_inside_pts=False)
-    with open(test_filename, 'wb') as f:
-        pickle.dump(custom_infos_test, f)
-    print('Custom info test file is saved to %s' % test_filename)
+    dataset.set_split(val_split)
+    custom_infos_val = dataset.get_infos(
+        class_names, num_workers=workers, has_label=True, num_features=num_features
+    )
+    with open(val_filename, 'wb') as f:
+        pickle.dump(custom_infos_val, f)
+    print('Custom info val file is saved to %s' % val_filename)
 
     print('------------------------Start create groundtruth database for data augmentation------------------------')
     dataset.set_split(train_split)
     dataset.create_groundtruth_database(train_filename, split=train_split)
     print('------------------------Data preparation done------------------------')
 
-if __name__=='__main__':
+
+if __name__ == '__main__':
     import sys
+
     if sys.argv.__len__() > 1 and sys.argv[1] == 'create_custom_infos':
         import yaml
         from pathlib import Path
         from easydict import EasyDict
+
         dataset_cfg = EasyDict(yaml.safe_load(open(sys.argv[2])))
         ROOT_DIR = (Path(__file__).resolve().parent / '../../../').resolve()
         create_custom_infos(
             dataset_cfg=dataset_cfg,
-            class_names=['Car', 'Pedestrian', 'Cyclist'],
+            class_names=['Vehicle', 'Pedestrian', 'Cyclist'],
             data_path=ROOT_DIR / 'data' / 'custom',
-            save_path=ROOT_DIR / 'data' / 'custom'
+            save_path=ROOT_DIR / 'data' / 'custom',
         )
diff --git a/pcdet/datasets/custom/custom_utils.py b/pcdet/datasets/custom/custom_utils.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/pcdet/datasets/dataset.py b/pcdet/datasets/dataset.py
index cf6d9fe42..2be1990b2 100644
--- a/pcdet/datasets/dataset.py
+++ b/pcdet/datasets/dataset.py
@@ -9,6 +9,7 @@ from .processor.data_processor import DataProcessor
 from .processor.point_feature_encoder import PointFeatureEncoder
 
+
 class DatasetTemplate(torch_data.Dataset):
     def __init__(self, dataset_cfg=None, class_names=None, training=True, root_path=None, logger=None):
         super().__init__()
@@ -59,21 +60,52 @@ def __setstate__(self, d):
 
     @staticmethod
     def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None):
         """
-        To support a custom dataset, implement this function to receive the predicted results from the model, and then
-        transform the unified normative coordinate to your required coordinate, and optionally save them to disk. 
diff --git a/pcdet/datasets/custom/custom_utils.py b/pcdet/datasets/custom/custom_utils.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/pcdet/datasets/dataset.py b/pcdet/datasets/dataset.py
index cf6d9fe42..2be1990b2 100644
--- a/pcdet/datasets/dataset.py
+++ b/pcdet/datasets/dataset.py
@@ -9,6 +9,7 @@ from .processor.data_processor import DataProcessor
 from .processor.point_feature_encoder import PointFeatureEncoder
 
+
 class DatasetTemplate(torch_data.Dataset):
     def __init__(self, dataset_cfg=None, class_names=None, training=True, root_path=None, logger=None):
         super().__init__()
@@ -59,21 +60,52 @@ def __setstate__(self, d):
 
     @staticmethod
     def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None):
         """
-        To support a custom dataset, implement this function to receive the predicted results from the model, and then
-        transform the unified normative coordinate to your required coordinate, and optionally save them to disk.
-
         Args:
-            batch_dict: dict of original data from the dataloader
-            pred_dicts: dict of predicted results from the model
+            batch_dict:
+                frame_id:
+            pred_dicts: list of pred_dicts
                 pred_boxes: (N, 7), Tensor
                 pred_scores: (N), Tensor
                 pred_labels: (N), Tensor
             class_names:
-            output_path: if it is not None, save the results to this path
+            output_path:
+
         Returns:
 
         """
+        def get_template_prediction(num_samples):
+            ret_dict = {
+                'name': np.zeros(num_samples), 'score': np.zeros(num_samples),
+                'boxes_lidar': np.zeros([num_samples, 7]), 'pred_labels': np.zeros(num_samples)
+            }
+            return ret_dict
+
+        def generate_single_sample_dict(box_dict):
+            pred_scores = box_dict['pred_scores'].cpu().numpy()
+            pred_boxes = box_dict['pred_boxes'].cpu().numpy()
+            pred_labels = box_dict['pred_labels'].cpu().numpy()
+            pred_dict = get_template_prediction(pred_scores.shape[0])
+            if pred_scores.shape[0] == 0:
+                return pred_dict
+
+            pred_dict['name'] = np.array(class_names)[pred_labels - 1]
+            pred_dict['score'] = pred_scores
+            pred_dict['boxes_lidar'] = pred_boxes
+            pred_dict['pred_labels'] = pred_labels
+
+            return pred_dict
+
+        annos = []
+        for index, box_dict in enumerate(pred_dicts):
+            single_pred_dict = generate_single_sample_dict(box_dict)
+            single_pred_dict['frame_id'] = batch_dict['frame_id'][index]
+            if 'metadata' in batch_dict:
+                single_pred_dict['metadata'] = batch_dict['metadata'][index]
+            annos.append(single_pred_dict)
+
+        return annos
+
     def merge_all_iters_to_one_epoch(self, merge=True, epochs=None):
         if merge:
             self._merge_all_iters_to_one_epoch = True
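With `generate_prediction_dicts` now implemented once on `DatasetTemplate`, the near-identical per-dataset copies below (Lyft, nuScenes, Waymo) become redundant and are deleted. A dataset that still needs its results dumped to disk can wrap the template method rather than reimplement it; a sketch (hypothetical subclass and output format, not part of the patch):

```python
# Hypothetical subclass -- not part of the patch; output_path is assumed to be a pathlib.Path.
class MyDataset(DatasetTemplate):
    def generate_prediction_dicts(self, batch_dict, pred_dicts, class_names, output_path=None):
        # Reuse the template implementation, then optionally write one .txt per frame.
        annos = super().generate_prediction_dicts(batch_dict, pred_dicts, class_names)
        if output_path is not None:
            for anno in annos:
                with open(output_path / ('%s.txt' % anno['frame_id']), 'w') as f:
                    for name, box, score in zip(anno['name'], anno['boxes_lidar'], anno['score']):
                        f.write('%s %s %.4f\n' % (name, ' '.join('%.4f' % v for v in box), score))
        return annos
```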
diff --git a/pcdet/datasets/lyft/lyft_dataset.py b/pcdet/datasets/lyft/lyft_dataset.py
index a042eb56f..4fd197acd 100644
--- a/pcdet/datasets/lyft/lyft_dataset.py
+++ b/pcdet/datasets/lyft/lyft_dataset.py
@@ -106,50 +106,6 @@ def __getitem__(self, index):
 
         return data_dict
 
-    def generate_prediction_dicts(self, batch_dict, pred_dicts, class_names, output_path=None):
-        """
-        Args:
-            batch_dict:
-                frame_id:
-            pred_dicts: list of pred_dicts
-                pred_boxes: (N, 7), Tensor
-                pred_scores: (N), Tensor
-                pred_labels: (N), Tensor
-            class_names:
-            output_path:
-        Returns:
-        """
-        def get_template_prediction(num_samples):
-            ret_dict = {
-                'name': np.zeros(num_samples), 'score': np.zeros(num_samples),
-                'boxes_lidar': np.zeros([num_samples, 7]), 'pred_labels': np.zeros(num_samples)
-            }
-            return ret_dict
-
-        def generate_single_sample_dict(box_dict):
-            pred_scores = box_dict['pred_scores'].cpu().numpy()
-            pred_boxes = box_dict['pred_boxes'].cpu().numpy()
-            pred_labels = box_dict['pred_labels'].cpu().numpy()
-            pred_dict = get_template_prediction(pred_scores.shape[0])
-            if pred_scores.shape[0] == 0:
-                return pred_dict
-
-            pred_dict['name'] = np.array(class_names)[pred_labels - 1]
-            pred_dict['score'] = pred_scores
-            pred_dict['boxes_lidar'] = pred_boxes
-            pred_dict['pred_labels'] = pred_labels
-
-            return pred_dict
-
-        annos = []
-        for index, box_dict in enumerate(pred_dicts):
-            single_pred_dict = generate_single_sample_dict(box_dict)
-            single_pred_dict['frame_id'] = batch_dict['frame_id'][index]
-            single_pred_dict['metadata'] = batch_dict['metadata'][index]
-            annos.append(single_pred_dict)
-
-        return annos
-
     def kitti_eval(self, eval_det_annos, eval_gt_annos, class_names):
         from ..kitti.kitti_object_eval_python import eval as kitti_eval
         from ..kitti import kitti_utils
diff --git a/pcdet/datasets/nuscenes/nuscenes_dataset.py b/pcdet/datasets/nuscenes/nuscenes_dataset.py
index 855c2605d..15506e00f 100644
--- a/pcdet/datasets/nuscenes/nuscenes_dataset.py
+++ b/pcdet/datasets/nuscenes/nuscenes_dataset.py
@@ -150,51 +150,6 @@ def __getitem__(self, index):
 
         return data_dict
 
-    @staticmethod
-    def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None):
-        """
-        Args:
-            batch_dict:
-                frame_id:
-            pred_dicts: list of pred_dicts
-                pred_boxes: (N, 7), Tensor
-                pred_scores: (N), Tensor
-                pred_labels: (N), Tensor
-            class_names:
-            output_path:
-        Returns:
-        """
-        def get_template_prediction(num_samples):
-            ret_dict = {
-                'name': np.zeros(num_samples), 'score': np.zeros(num_samples),
-                'boxes_lidar': np.zeros([num_samples, 7]), 'pred_labels': np.zeros(num_samples)
-            }
-            return ret_dict
-
-        def generate_single_sample_dict(box_dict):
-            pred_scores = box_dict['pred_scores'].cpu().numpy()
-            pred_boxes = box_dict['pred_boxes'].cpu().numpy()
-            pred_labels = box_dict['pred_labels'].cpu().numpy()
-            pred_dict = get_template_prediction(pred_scores.shape[0])
-            if pred_scores.shape[0] == 0:
-                return pred_dict
-
-            pred_dict['name'] = np.array(class_names)[pred_labels - 1]
-            pred_dict['score'] = pred_scores
-            pred_dict['boxes_lidar'] = pred_boxes
-            pred_dict['pred_labels'] = pred_labels
-
-            return pred_dict
-
-        annos = []
-        for index, box_dict in enumerate(pred_dicts):
-            single_pred_dict = generate_single_sample_dict(box_dict)
-            single_pred_dict['frame_id'] = batch_dict['frame_id'][index]
-            single_pred_dict['metadata'] = batch_dict['metadata'][index]
-            annos.append(single_pred_dict)
-
-        return annos
-
     def evaluation(self, det_annos, class_names, **kwargs):
         import json
         from nuscenes.nuscenes import NuScenes
diff --git a/pcdet/datasets/waymo/waymo_dataset.py b/pcdet/datasets/waymo/waymo_dataset.py
index 355fcbca9..3dcdc7d5c 100644
--- a/pcdet/datasets/waymo/waymo_dataset.py
+++ b/pcdet/datasets/waymo/waymo_dataset.py
@@ -218,53 +218,6 @@ def __getitem__(self, index):
         data_dict.pop('num_points_in_gt', None)
         return data_dict
 
-    @staticmethod
-    def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None):
-        """
-        Args:
-            batch_dict:
-                frame_id:
-            pred_dicts: list of pred_dicts
-                pred_boxes: (N, 7), Tensor
-                pred_scores: (N), Tensor
-                pred_labels: (N), Tensor
-            class_names:
-            output_path:
-
-        Returns:
-
-        """
-
-        def get_template_prediction(num_samples):
-            ret_dict = {
-                'name': np.zeros(num_samples), 'score': np.zeros(num_samples),
-                'boxes_lidar': np.zeros([num_samples, 7])
-            }
-            return ret_dict
-
-        def generate_single_sample_dict(box_dict):
-            pred_scores = box_dict['pred_scores'].cpu().numpy()
-            pred_boxes = box_dict['pred_boxes'].cpu().numpy()
-            pred_labels = box_dict['pred_labels'].cpu().numpy()
-            pred_dict = get_template_prediction(pred_scores.shape[0])
-            if pred_scores.shape[0] == 0:
-                return pred_dict
-
-            pred_dict['name'] = np.array(class_names)[pred_labels - 1]
-            pred_dict['score'] = pred_scores
-            pred_dict['boxes_lidar'] = pred_boxes
-
-            return pred_dict
-
-        annos = []
-        for index, box_dict in enumerate(pred_dicts):
-            single_pred_dict = generate_single_sample_dict(box_dict)
-            single_pred_dict['frame_id'] = batch_dict['frame_id'][index]
-            single_pred_dict['metadata'] = batch_dict['metadata'][index]
-            annos.append(single_pred_dict)
-
-        return annos
-
     def evaluation(self, det_annos, class_names, **kwargs):
         if 'annos' not in self.infos[0].keys():
             return 'No ground-truth boxes for evaluation', {}
diff --git a/tools/cfgs/custom_models/pointrcnn.yaml b/tools/cfgs/custom_models/pointrcnn.yaml
deleted file mode 100644
index 2df123581..000000000
--- a/tools/cfgs/custom_models/pointrcnn.yaml
+++ /dev/null
@@ -1,161 +0,0 @@
-CLASS_NAMES: ['Car']
-# CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist']
-
-DATA_CONFIG:
-    _BASE_CONFIG_: ../dataset_configs/custom_dataset.yaml
-
-    DATA_PROCESSOR:
-        - NAME: mask_points_and_boxes_outside_range
-          REMOVE_OUTSIDE_BOXES: True
-
-        - NAME: sample_points
-          NUM_POINTS: {
-            'train': 16384,
-            'test': 16384
-          }
-
-        - NAME: shuffle_points
-          SHUFFLE_ENABLED: {
-            'train': True,
-            'test': False
-          }
-
-MODEL:
-    NAME: PointRCNN
-
-    BACKBONE_3D:
-        NAME: PointNet2MSG
-        SA_CONFIG:
-            NPOINTS: [4096, 1024, 256, 64]
-            RADIUS: [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]]
-            NSAMPLE: [[16, 32], [16, 32], [16, 32], [16, 32]]
-            MLPS: [[[16, 16, 32], [32, 32, 64]],
-                   [[64, 64, 128], [64, 96, 128]],
-                   [[128, 196, 256], [128, 196, 256]],
-                   [[256, 256, 512], [256, 384, 512]]]
-        FP_MLPS: [[128, 128], [256, 256], [512, 512], [512, 512]]
-
-    POINT_HEAD:
-        NAME: PointHeadBox
-        CLS_FC: [256, 256]
-        REG_FC: [256, 256]
-        CLASS_AGNOSTIC: False
-        USE_POINT_FEATURES_BEFORE_FUSION: False
-        TARGET_CONFIG:
-            GT_EXTRA_WIDTH: [0.2, 0.2, 0.2]
-            BOX_CODER: PointResidualCoder
-            BOX_CODER_CONFIG: {
-                'use_mean_size': True,
-                'mean_size': [
-                    [3.9, 1.6, 1.56],
-                    [0.8, 0.6, 1.73],
-                    [1.76, 0.6, 1.73]
-                ]
-            }
-
-        LOSS_CONFIG:
-            LOSS_REG: WeightedSmoothL1Loss
-            LOSS_WEIGHTS: {
-                'point_cls_weight': 1.0,
-                'point_box_weight': 1.0,
-                'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
-            }
-
-    ROI_HEAD:
-        NAME: PointRCNNHead
-        CLASS_AGNOSTIC: True
-
-        ROI_POINT_POOL:
-            POOL_EXTRA_WIDTH: [0.0, 0.0, 0.0]
-            NUM_SAMPLED_POINTS: 512
-            DEPTH_NORMALIZER: 70.0
-
-        XYZ_UP_LAYER: [128, 128]
-        CLS_FC: [256, 256]
-        REG_FC: [256, 256]
-        DP_RATIO: 0.0
-        USE_BN: False
-
-        SA_CONFIG:
-            NPOINTS: [128, 32, -1]
-            RADIUS: [0.2, 0.4, 100]
-            NSAMPLE: [16, 16, 16]
-            MLPS: [[128, 128, 128],
-                   [128, 128, 256],
-                   [256, 256, 512]]
-
-        NMS_CONFIG:
-            TRAIN:
-                NMS_TYPE: nms_gpu
-                MULTI_CLASSES_NMS: False
-                NMS_PRE_MAXSIZE: 9000
-                NMS_POST_MAXSIZE: 512
-                NMS_THRESH: 0.8
-            TEST:
-                NMS_TYPE: nms_gpu
-                MULTI_CLASSES_NMS: False
-                NMS_PRE_MAXSIZE: 9000
-                NMS_POST_MAXSIZE: 100
-                NMS_THRESH: 0.85
-
-        TARGET_CONFIG:
-            BOX_CODER: ResidualCoder
-            ROI_PER_IMAGE: 128
-            FG_RATIO: 0.5
-
-            SAMPLE_ROI_BY_EACH_CLASS: True
-            CLS_SCORE_TYPE: cls
-
-            CLS_FG_THRESH: 0.6
-            CLS_BG_THRESH: 0.45
-            CLS_BG_THRESH_LO: 0.1
-            HARD_BG_RATIO: 0.8
-
-            REG_FG_THRESH: 0.55
-
-        LOSS_CONFIG:
-            CLS_LOSS: BinaryCrossEntropy
-            REG_LOSS: smooth-l1
-            CORNER_LOSS_REGULARIZATION: True
-            LOSS_WEIGHTS: {
-                'rcnn_cls_weight': 1.0,
-                'rcnn_reg_weight': 1.0,
-                'rcnn_corner_weight': 1.0,
-                'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
-            }
-
-    POST_PROCESSING:
-        RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
-        SCORE_THRESH: 0.1
-        OUTPUT_RAW_SCORE: False
-
-        EVAL_METRIC: kitti
-
-        NMS_CONFIG:
-            MULTI_CLASSES_NMS: False
-            NMS_TYPE: nms_gpu
-            NMS_THRESH: 0.1
-            NMS_PRE_MAXSIZE: 4096
-            NMS_POST_MAXSIZE: 500
-
-
-OPTIMIZATION:
-    BATCH_SIZE_PER_GPU: 2
-    NUM_EPOCHS: 80
-
-    OPTIMIZER: adam_onecycle
-    LR: 0.01
-    WEIGHT_DECAY: 0.01
-    MOMENTUM: 0.9
-
-    MOMS: [0.95, 0.85]
-    PCT_START: 0.4
-    DIV_FACTOR: 10
-    DECAY_STEP_LIST: [35, 45]
-    LR_DECAY: 0.1
-    LR_CLIP: 0.0000001
-
-    LR_WARMUP: False
-    WARMUP_EPOCH: 1
-
-    GRAD_NORM_CLIP: 10
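The pv_rcnn.yaml changes below replace the earlier placeholder anchor sizes (e.g. `[[0.05, 0.03, 0.1]]` for pedestrians) with realistic class-mean dimensions. If your classes differ, one way to seed `anchor_sizes` is to average the box dimensions in your own labels; a sketch, assuming the tutorial's `data/custom/labels/*.txt` layout and the `x y z dx dy dz heading category` field order:

```python
# Sketch (assumes the tutorial's label layout; adjust paths/field order to your data).
from collections import defaultdict
from pathlib import Path
import numpy as np

dims_per_class = defaultdict(list)
for label_file in Path('data/custom/labels').glob('*.txt'):
    for line in label_file.read_text().splitlines():
        fields = line.split()
        # fields: x y z dx dy dz heading category -> dimensions are indices 3:6
        dims_per_class[fields[-1]].append([float(v) for v in fields[3:6]])

for name, dims in dims_per_class.items():
    print(name, np.mean(dims, axis=0).round(2))  # candidate anchor_sizes entry
```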
diff --git a/tools/cfgs/custom_models/pv_rcnn.yaml b/tools/cfgs/custom_models/pv_rcnn.yaml
index 012d12ac9..99afcb367 100644
--- a/tools/cfgs/custom_models/pv_rcnn.yaml
+++ b/tools/cfgs/custom_models/pv_rcnn.yaml
@@ -1,33 +1,7 @@
-CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist']
+CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist']
 
 DATA_CONFIG:
     _BASE_CONFIG_: ../dataset_configs/custom_dataset.yaml
-    DATA_AUGMENTOR:
-        DISABLE_AUG_LIST: ['placeholder']
-        AUG_CONFIG_LIST:
-            - NAME: gt_sampling
-              USE_ROAD_PLANE: False
-              DB_INFO_PATH:
-                  - custom_dbinfos_train.pkl
-              PREPARE: {
-                 filter_by_min_points: ['Car:5', 'Pedestrian:5', 'Cyclist:5'],
-                 filter_by_difficulty: [-1],
-              }
-
-              SAMPLE_GROUPS: ['Car:15','Pedestrian:10', 'Cyclist:10']
-              NUM_POINT_FEATURES: 4
-              DATABASE_WITH_FAKELIDAR: False
-              REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
-              LIMIT_WHOLE_SCENE: False
-
-            - NAME: random_world_flip
-              ALONG_AXIS_LIST: ['x']
-
-            - NAME: random_world_rotation
-              WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
-
-            - NAME: random_world_scaling
-              WORLD_SCALE_RANGE: [0.95, 1.05]
 
 MODEL:
     NAME: PVRCNN
@@ -62,20 +36,20 @@ MODEL:
 
         ANCHOR_GENERATOR_CONFIG: [
             {
-                'class_name': 'Car',
+                'class_name': 'Vehicle',
                 'anchor_sizes': [[3.9, 1.6, 1.56]],
                 'anchor_rotations': [0, 1.57],
-                'anchor_bottom_heights': [-0.1],
+                'anchor_bottom_heights': [0],
                 'align_center': False,
                 'feature_map_stride': 8,
-                'matched_threshold': 0.6,
-                'unmatched_threshold': 0.45
+                'matched_threshold': 0.55,
+                'unmatched_threshold': 0.4
             },
             {
                 'class_name': 'Pedestrian',
-                'anchor_sizes': [[0.05, 0.03, 0.1]],
+                'anchor_sizes': [[0.8, 0.6, 1.73]],
                 'anchor_rotations': [0, 1.57],
-                'anchor_bottom_heights': [-0.03],
+                'anchor_bottom_heights': [0],
                 'align_center': False,
                 'feature_map_stride': 8,
                 'matched_threshold': 0.5,
@@ -83,9 +57,9 @@ MODEL:
             },
             {
                 'class_name': 'Cyclist',
-                'anchor_sizes': [[0.1, 0.03, 0.1]],
+                'anchor_sizes': [[1.76, 0.6, 1.73]],
                 'anchor_rotations': [0, 1.57],
-                'anchor_bottom_heights': [-0.03],
+                'anchor_bottom_heights': [0],
                 'align_center': False,
                 'feature_map_stride': 8,
                 'matched_threshold': 0.5,
@@ -112,11 +86,11 @@ MODEL:
     PFE:
         NAME: VoxelSetAbstraction
         POINT_SOURCE: raw_points
-        NUM_KEYPOINTS: 2048
+        NUM_KEYPOINTS: 4096
        NUM_OUTPUT_FEATURES: 128
         SAMPLE_METHOD: FPS
 
-        FEATURES_SOURCE: ['bev', 'x_conv1', 'x_conv2', 'x_conv3', 'x_conv4', 'raw_points']
+        FEATURES_SOURCE: ['bev', 'x_conv3', 'x_conv4', 'raw_points']
         SA_LAYER:
             raw_points:
                 MLPS: [[16, 16], [16, 16]]
@@ -175,9 +149,13 @@ MODEL:
             TEST:
                 NMS_TYPE: nms_gpu
                 MULTI_CLASSES_NMS: False
-                NMS_PRE_MAXSIZE: 1024
-                NMS_POST_MAXSIZE: 100
-                NMS_THRESH: 0.7
+#                NMS_PRE_MAXSIZE: 1024
+#                NMS_POST_MAXSIZE: 100
+#                NMS_THRESH: 0.7
+                NMS_PRE_MAXSIZE: 4096
+                NMS_POST_MAXSIZE: 300
+                NMS_THRESH: 0.85
+
 
     ROI_GRID_POOL:
         GRID_SIZE: 6
diff --git a/tools/cfgs/dataset_configs/custom_dataset.yaml b/tools/cfgs/dataset_configs/custom_dataset.yaml
index 040dacf82..f8fedc9f3 100644
--- a/tools/cfgs/dataset_configs/custom_dataset.yaml
+++ b/tools/cfgs/dataset_configs/custom_dataset.yaml
@@ -1,9 +1,13 @@
 DATASET: 'CustomDataset'
 DATA_PATH: '../data/custom'
 
-# If this config file is modified then pcdet/models/detectors/detector3d_template.py:
-# Detector3DTemplate::build_networks:model_info_dict needs to be modified.
-POINT_CLOUD_RANGE: [-70.4, -40, -3, 70.4, 40, 1]  # x=[-70.4, 70.4], y=[-40,40], z=[-3,1]
+POINT_CLOUD_RANGE: [-75.2, -75.2, -2, 75.2, 75.2, 4]
+
+MAP_CLASS_TO_KITTI: {
+    'Vehicle': 'Car',
+    'Pedestrian': 'Pedestrian',
+    'Cyclist': 'Cyclist',
+}
 
 DATA_SPLIT: {
     'train': train,
@@ -15,16 +19,12 @@ INFO_PATH: {
     'test': [custom_infos_val.pkl],
 }
 
-GET_ITEM_LIST: ["points"]
-FOV_POINTS_ONLY: True
-
 POINT_FEATURE_ENCODING: {
     encoding_type: absolute_coordinates_encoding,
     used_feature_list: ['x', 'y', 'z', 'intensity'],
    src_feature_list: ['x', 'y', 'z', 'intensity'],
 }
 
-# Same to pv_rcnn[DATA_AUGMENTOR]
 DATA_AUGMENTOR:
     DISABLE_AUG_LIST: ['placeholder']
     AUG_CONFIG_LIST:
@@ -33,18 +33,17 @@ DATA_AUGMENTOR:
           DB_INFO_PATH:
               - custom_dbinfos_train.pkl
           PREPARE: {
-             filter_by_min_points: ['Car:5', 'Pedestrian:5', 'Cyclist:5'],
-             filter_by_difficulty: [-1],
+             filter_by_min_points: ['Vehicle:5', 'Pedestrian:5', 'Cyclist:5'],
          }
 
-          SAMPLE_GROUPS: ['Car:20','Pedestrian:15', 'Cyclist:15']
+          SAMPLE_GROUPS: ['Vehicle:20', 'Pedestrian:15', 'Cyclist:15']
           NUM_POINT_FEATURES: 4
           DATABASE_WITH_FAKELIDAR: False
           REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
           LIMIT_WHOLE_SCENE: True
 
         - NAME: random_world_flip
-          ALONG_AXIS_LIST: ['x']
+          ALONG_AXIS_LIST: ['x', 'y']
 
         - NAME: random_world_rotation
           WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
@@ -63,9 +62,9 @@ DATA_PROCESSOR:
     }
 
     - NAME: transform_points_to_voxels
-      VOXEL_SIZE: [0.05, 0.05, 0.1]
+      VOXEL_SIZE: [0.1, 0.1, 0.15]
       MAX_POINTS_PER_VOXEL: 5
       MAX_NUMBER_OF_VOXELS: {
-        'train': 16000,
-        'test': 40000
+        'train': 150000,
+        'test': 150000
       }
\ No newline at end of file
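One detail worth checking when adapting the ranges above: OpenPCDet's sparse voxel backbones downsample the BEV grid by a fixed stride, so the grid implied by `POINT_CLOUD_RANGE` and `VOXEL_SIZE` should divide evenly. A quick sanity check for the values in this config (my arithmetic, not part of the patch):

```python
# Sanity check for POINT_CLOUD_RANGE [-75.2, -75.2, -2, 75.2, 75.2, 4] and
# VOXEL_SIZE [0.1, 0.1, 0.15]: the x/y grid should be a multiple of the stride (8 here).
import numpy as np

pc_range = np.array([-75.2, -75.2, -2.0, 75.2, 75.2, 4.0])
voxel_size = np.array([0.1, 0.1, 0.15])

grid = np.round((pc_range[3:6] - pc_range[0:3]) / voxel_size).astype(int)
print(grid)  # -> [1504 1504   40]
assert grid[0] % 8 == 0 and grid[1] % 8 == 0
```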
From 3cf5763fdd69d3e40884103c1a86ec05215ea9b8 Mon Sep 17 00:00:00 2001
From: jihanyang
Date: Mon, 22 Aug 2022 15:46:54 +0800
Subject: [PATCH 4/5] add another custom dataset model config

---
 tools/cfgs/custom_models/second.yaml | 121 +++++++++++++++++++++++++++
 1 file changed, 121 insertions(+)
 create mode 100644 tools/cfgs/custom_models/second.yaml

diff --git a/tools/cfgs/custom_models/second.yaml b/tools/cfgs/custom_models/second.yaml
new file mode 100644
index 000000000..e7652fb5a
--- /dev/null
+++ b/tools/cfgs/custom_models/second.yaml
@@ -0,0 +1,121 @@
+CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist']
+
+DATA_CONFIG:
+    _BASE_CONFIG_: cfgs/dataset_configs/custom_dataset.yaml
+
+
+MODEL:
+    NAME: SECONDNet
+
+    VFE:
+        NAME: MeanVFE
+
+    BACKBONE_3D:
+        NAME: VoxelBackBone8x
+
+    MAP_TO_BEV:
+        NAME: HeightCompression
+        NUM_BEV_FEATURES: 256
+
+    BACKBONE_2D:
+        NAME: BaseBEVBackbone
+
+        LAYER_NUMS: [5, 5]
+        LAYER_STRIDES: [1, 2]
+        NUM_FILTERS: [128, 256]
+        UPSAMPLE_STRIDES: [1, 2]
+        NUM_UPSAMPLE_FILTERS: [256, 256]
+
+    DENSE_HEAD:
+        NAME: AnchorHeadSingle
+        CLASS_AGNOSTIC: False
+
+        USE_DIRECTION_CLASSIFIER: True
+        DIR_OFFSET: 0.78539
+        DIR_LIMIT_OFFSET: 0.0
+        NUM_DIR_BINS: 2
+
+        ANCHOR_GENERATOR_CONFIG: [
+            {
+                'class_name': 'Vehicle',
+                'anchor_sizes': [[3.9, 1.6, 1.56]],
+                'anchor_rotations': [0, 1.57],
+                'anchor_bottom_heights': [0],
+                'align_center': False,
+                'feature_map_stride': 8,
+                'matched_threshold': 0.55,
+                'unmatched_threshold': 0.4
+            },
+            {
+                'class_name': 'Pedestrian',
+                'anchor_sizes': [[0.8, 0.6, 1.73]],
+                'anchor_rotations': [0, 1.57],
+                'anchor_bottom_heights': [0],
+                'align_center': False,
+                'feature_map_stride': 8,
+                'matched_threshold': 0.5,
+                'unmatched_threshold': 0.35
+            },
+            {
+                'class_name': 'Cyclist',
+                'anchor_sizes': [[1.76, 0.6, 1.73]],
+                'anchor_rotations': [0, 1.57],
+                'anchor_bottom_heights': [0],
+                'align_center': False,
+                'feature_map_stride': 8,
+                'matched_threshold': 0.5,
+                'unmatched_threshold': 0.35
+            }
+        ]
+
+        TARGET_ASSIGNER_CONFIG:
+            NAME: AxisAlignedTargetAssigner
+            POS_FRACTION: -1.0
+            SAMPLE_SIZE: 512
+            NORM_BY_NUM_EXAMPLES: False
+            MATCH_HEIGHT: False
+            BOX_CODER: ResidualCoder
+
+        LOSS_CONFIG:
+            LOSS_WEIGHTS: {
+                'cls_weight': 1.0,
+                'loc_weight': 2.0,
+                'dir_weight': 0.2,
+                'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+            }
+
+    POST_PROCESSING:
+        RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
+        SCORE_THRESH: 0.1
+        OUTPUT_RAW_SCORE: False
+
+        EVAL_METRIC: kitti
+
+        NMS_CONFIG:
+            MULTI_CLASSES_NMS: False
+            NMS_TYPE: nms_gpu
+            NMS_THRESH: 0.85
+            NMS_PRE_MAXSIZE: 4096
+            NMS_POST_MAXSIZE: 500
+
+
+OPTIMIZATION:
+    BATCH_SIZE_PER_GPU: 4
+    NUM_EPOCHS: 80
+
+    OPTIMIZER: adam_onecycle
+    LR: 0.003
+    WEIGHT_DECAY: 0.01
+    MOMENTUM: 0.9
+
+    MOMS: [0.95, 0.85]
+    PCT_START: 0.4
+    DIV_FACTOR: 10
+    DECAY_STEP_LIST: [35, 45]
+    LR_DECAY: 0.1
+    LR_CLIP: 0.0000001
+
+    LR_WARMUP: False
+    WARMUP_EPOCH: 1
+
+    GRAD_NORM_CLIP: 10
\ No newline at end of file
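With the info files generated and this config in place, training follows the standard OpenPCDet flow. A sketch of the expected invocation (single GPU, run from the `tools/` directory; the equivalent shell form is in the comment):

```python
# Sketch: equivalent to running `python train.py --cfg_file cfgs/custom_models/second.yaml`
# from the tools/ directory of the repository.
import subprocess

subprocess.run(
    ['python', 'train.py', '--cfg_file', 'cfgs/custom_models/second.yaml'],
    cwd='tools', check=True,
)
```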
From e8bcd6897d780dc4fdb9da894dad07155529b702 Mon Sep 17 00:00:00 2001
From: jihanyang
Date: Mon, 22 Aug 2022 19:25:05 +0800
Subject: [PATCH 5/5] Use category name in label format; Update README

---
 README.md                               |  2 ++
 docs/CUSTOM_DATASET_TUTORIAL.md         |  8 +++-----
 pcdet/datasets/custom/custom_dataset.py | 21 ++++++++++++---------
 3 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index dee155a24..4f1d5489f 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,8 @@ It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/18
 
 ## Changelog
 
+[2022-07-05] Added support for [custom dataset tutorial and template](docs/CUSTOM_DATASET_TUTORIAL.md)
+
 [2022-07-05] Added support for the 3D object detection backbone network [`Focals Conv`](https://openaccess.thecvf.com/content/CVPR2022/papers/Chen_Focal_Sparse_Convolutional_Networks_for_3D_Object_Detection_CVPR_2022_paper.pdf).
 
 [2022-02-12] Added support for using docker. Please refer to the guidance in [./docker](./docker).
 
diff --git a/docs/CUSTOM_DATASET_TUTORIAL.md b/docs/CUSTOM_DATASET_TUTORIAL.md
index 0bf0f0c91..02cba9290 100644
--- a/docs/CUSTOM_DATASET_TUTORIAL.md
+++ b/docs/CUSTOM_DATASET_TUTORIAL.md
@@ -6,13 +6,11 @@ their corresponding annotations. Point clouds are supposed to be stored in `.npy
 We only consider the most basic information -- category and bounding box in the label template.
 Annotations are stored in the `.txt`. Each line represents a box in a given scene as below:
 ```
-[x y z dx dy dz heading_angle category_id]
-1.50 1.46 0.10 5.12 1.85 4.13 1.56 0
-5.54 0.57 0.41 1.08 0.74 1.95 1.57 1
+[x y z dx dy dz heading_angle category_name]
+1.50 1.46 0.10 5.12 1.85 4.13 1.56 Vehicle
+5.54 0.57 0.41 1.08 0.74 1.95 1.57 Pedestrian
 ```
 The box should be in the unified 3D box definition (see [README](../README.md))
-The correspondence between `category_id` and `category_name` need to be pre-defined.
-
 ## Files structure
 Files should be placed as the following folder structure:
 
diff --git a/pcdet/datasets/custom/custom_dataset.py b/pcdet/datasets/custom/custom_dataset.py
index 9552074aa..f6283fb0c 100644
--- a/pcdet/datasets/custom/custom_dataset.py
+++ b/pcdet/datasets/custom/custom_dataset.py
@@ -53,8 +53,14 @@ def get_label(self, idx):
             lines = f.readlines()
 
         # [N, 8]: (x y z dx dy dz heading_angle category_id)
-        gt_boxes = [line.strip().split(' ') for line in lines]
-        return np.array(gt_boxes, dtype=np.float32)
+        gt_boxes = []
+        gt_names = []
+        for line in lines:
+            line_list = line.strip().split(' ')
+            gt_boxes.append(line_list[:-1])
+            gt_names.append(line_list[-1])
+
+        return np.array(gt_boxes, dtype=np.float32), np.array(gt_names)
 
     def get_lidar(self, idx):
         lidar_file = self.root_path / 'points' / ('%s.npy' % idx)
@@ -136,8 +142,6 @@ def kitti_eval(eval_det_annos, eval_gt_annos, map_name_to_kitti):
     def get_infos(self, class_names, num_workers=4, has_label=True, sample_id_list=None, num_features=4):
         import concurrent.futures as futures
 
-        class_names = np.array(class_names)
-
         def process_single_scene(sample_idx):
             print('%s sample_idx: %s' % (self.split, sample_idx))
             info = {}
@@ -146,8 +150,8 @@ def process_single_scene(sample_idx):
 
             if has_label:
                 annotations = {}
-                gt_boxes_lidar = self.get_label(sample_idx)
-                annotations['name'] = class_names[gt_boxes_lidar[:, -1].astype(np.int64)]
+                gt_boxes_lidar, name = self.get_label(sample_idx)
+                annotations['name'] = name
                 annotations['gt_boxes_lidar'] = gt_boxes_lidar[:, :7]
 
                 info['annos'] = annotations
@@ -219,10 +223,9 @@ def create_label_file_with_name_and_box(class_names, gt_names, gt_boxes, save_label_path):
             name = gt_names[idx]
             if name not in class_names:
                 continue
-            category_id = class_names.index(name)
-            line = "{x} {y} {z} {l} {w} {h} {angle} {category_id}\n".format(
+            line = "{x} {y} {z} {l} {w} {h} {angle} {name}\n".format(
                 x=boxes[0], y=boxes[1], z=(boxes[2]), l=boxes[3],
-                w=boxes[4], h=boxes[5], angle=boxes[6], category_id=category_id
+                w=boxes[4], h=boxes[5], angle=boxes[6], name=name
            )
            f.write(line)
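Taken together, PATCH 5/5 makes the label format self-describing: `get_label` now returns the boxes and the raw category names, and the writer emits names instead of indices, so no separate id-to-name mapping needs to be maintained. A minimal round-trip sketch of the final format (values made up):

```python
# Round-trip of the final, name-based label line: x y z dx dy dz heading category_name.
import numpy as np

line = '1.50 1.46 0.10 5.12 1.85 4.13 1.56 Vehicle'
fields = line.strip().split(' ')
box = np.array(fields[:-1], dtype=np.float32)  # (7,): x y z dx dy dz heading
name = fields[-1]                              # 'Vehicle'
assert box.shape == (7,) and name == 'Vehicle'
```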