From a5627bfb4bee1db939e2820e9c401dbd52ca8c13 Mon Sep 17 00:00:00 2001 From: Xiang Xu Date: Tue, 28 Feb 2023 20:13:22 +0800 Subject: [PATCH] [Feature] Support `LaserMix` augmentation (#2302) * add lasermix * add prob * update description * update --- mmdet3d/datasets/seg3d_dataset.py | 2 +- mmdet3d/datasets/transforms/__init__.py | 6 +- mmdet3d/datasets/transforms/transforms_3d.py | 145 ++++++++++++++++++ .../test_transforms/test_transforms_3d.py | 127 ++++++++++++++- 4 files changed, 275 insertions(+), 5 deletions(-) diff --git a/mmdet3d/datasets/seg3d_dataset.py b/mmdet3d/datasets/seg3d_dataset.py index 42025dee49..e2bb74b91a 100644 --- a/mmdet3d/datasets/seg3d_dataset.py +++ b/mmdet3d/datasets/seg3d_dataset.py @@ -295,7 +295,7 @@ def prepare_data(self, idx: int) -> dict: if not self.test_mode: data_info = self.get_data_info(idx) # Pass the dataset to the pipeline during training to support mixed - # data augmentation, such as polarmix. + # data augmentation, such as polarmix and lasermix. data_info['dataset'] = self return self.pipeline(data_info) else: diff --git a/mmdet3d/datasets/transforms/__init__.py b/mmdet3d/datasets/transforms/__init__.py index c8969f8b60..09fbed659b 100644 --- a/mmdet3d/datasets/transforms/__init__.py +++ b/mmdet3d/datasets/transforms/__init__.py @@ -11,8 +11,8 @@ from .transforms_3d import (AffineResize, BackgroundPointsFilter, GlobalAlignment, GlobalRotScaleTrans, IndoorPatchPointSample, IndoorPointSample, - MultiViewWrapper, ObjectNameFilter, ObjectNoise, - ObjectRangeFilter, ObjectSample, + LaserMix, MultiViewWrapper, ObjectNameFilter, + ObjectNoise, ObjectRangeFilter, ObjectSample, PhotoMetricDistortion3D, PointSample, PointShuffle, PointsRangeFilter, PolarMix, RandomDropPointsColor, RandomFlip3D, RandomJitterPoints, RandomResize3D, @@ -30,5 +30,5 @@ 'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize', 'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D', 'MultiViewWrapper', 'PhotoMetricDistortion3D', 'MonoDet3DInferencerLoader', - 'LidarDet3DInferencerLoader', 'PolarMix' + 'LidarDet3DInferencerLoader', 'PolarMix', 'LaserMix' ] diff --git a/mmdet3d/datasets/transforms/transforms_3d.py b/mmdet3d/datasets/transforms/transforms_3d.py index dbdbf2a45c..ed5741b9df 100644 --- a/mmdet3d/datasets/transforms/transforms_3d.py +++ b/mmdet3d/datasets/transforms/transforms_3d.py @@ -2521,3 +2521,148 @@ def __repr__(self) -> str: repr_str += f'pre_transform={self.pre_transform}, ' repr_str += f'prob={self.prob})' return repr_str + + +@TRANSFORMS.register_module() +class LaserMix(BaseTransform): + """LaserMix data augmentation. + + The lasermix transform steps are as follows: + + 1. Another random point cloud is picked by dataset. + 2. Divide the point cloud into several regions according to pitch + angles and combine the areas crossly. + + Required Keys: + + - points (:obj:`BasePoints`) + - pts_semantic_mask (np.int64) + - dataset (:obj:`BaseDataset`) + + Modified Keys: + + - points (:obj:`BasePoints`) + - pts_semantic_mask (np.int64) + + Args: + num_areas (List[int]): A list of area numbers will be divided into. + pitch_angles (Sequence[float]): Pitch angles used to divide areas. + pre_transform (Sequence[dict], optional): Sequence of transform object + or config dict to be composed. Defaults to None. + prob (float): The transformation probability. Defaults to 1.0. + """ + + def __init__(self, + num_areas: List[int], + pitch_angles: Sequence[float], + pre_transform: Optional[Sequence[dict]] = None, + prob: float = 1.0) -> None: + assert is_list_of(num_areas, int), \ + 'num_areas should be a list of int.' + self.num_areas = num_areas + + assert len(pitch_angles) == 2, \ + 'The length of pitch_angles should be 2, ' \ + f'but got {len(pitch_angles)}.' + assert pitch_angles[1] > pitch_angles[0], \ + 'pitch_angles[1] should be larger than pitch_angles[0].' + self.pitch_angles = pitch_angles + + self.prob = prob + if pre_transform is None: + self.pre_transform = None + else: + self.pre_transform = Compose(pre_transform) + + def laser_mix_transform(self, input_dict: dict, mix_results: dict) -> dict: + """LaserMix transform function. + + Args: + input_dict (dict): Result dict from loading pipeline. + mix_results (dict): Mixed dict picked from dataset. + + Returns: + dict: output dict after transformation. + """ + mix_points = mix_results['points'] + mix_pts_semantic_mask = mix_results['pts_semantic_mask'] + + points = input_dict['points'] + pts_semantic_mask = input_dict['pts_semantic_mask'] + + rho = torch.sqrt(points.coord[:, 0]**2 + points.coord[:, 1]**2) + pitch = torch.atan2(points.coord[:, 2], rho) + pitch = torch.clip(pitch, self.pitch_angles[0] + 1e-5, + self.pitch_angles[1] - 1e-5) + + mix_rho = torch.sqrt(mix_points.coord[:, 0]**2 + + mix_points.coord[:, 1]**2) + mix_pitch = torch.atan2(mix_points.coord[:, 2], mix_rho) + mix_pitch = torch.clip(mix_pitch, self.pitch_angles[0] + 1e-5, + self.pitch_angles[1] - 1e-5) + + num_areas = np.random.choice(self.num_areas, size=1)[0] + angle_list = np.linspace(self.pitch_angles[1], self.pitch_angles[0], + num_areas + 1) + out_points = [] + out_pts_semantic_mask = [] + for i in range(num_areas): + # convert angle to radian + start_angle = angle_list[i + 1] / 180 * np.pi + end_angle = angle_list[i] / 180 * np.pi + if i % 2 == 0: # pick from original point cloud + idx = (pitch > start_angle) & (pitch <= end_angle) + out_points.append(points[idx]) + out_pts_semantic_mask.append(pts_semantic_mask[idx.numpy()]) + else: # pickle from mixed point cloud + idx = (mix_pitch > start_angle) & (mix_pitch <= end_angle) + out_points.append(mix_points[idx]) + out_pts_semantic_mask.append( + mix_pts_semantic_mask[idx.numpy()]) + out_points = points.cat(out_points) + out_pts_semantic_mask = np.concatenate(out_pts_semantic_mask, axis=0) + input_dict['points'] = out_points + input_dict['pts_semantic_mask'] = out_pts_semantic_mask + return input_dict + + def transform(self, input_dict: dict) -> dict: + """LaserMix transform function. + + Args: + input_dict (dict): Result dict from loading pipeline. + + Returns: + dict: output dict after transformation. + """ + if np.random.rand() > self.prob: + return input_dict + + assert 'dataset' in input_dict, \ + '`dataset` is needed to pass through LaserMix, while not found.' + dataset = input_dict['dataset'] + + # get index of other point cloud + index = np.random.randint(0, len(dataset)) + + mix_results = dataset.get_data_info(index) + + if self.pre_transform is not None: + # pre_transform may also require dataset + mix_results.update({'dataset': dataset}) + # before lasermix need to go through + # the necessary pre_transform + mix_results = self.pre_transform(mix_results) + mix_results.pop('dataset') + + input_dict = self.laser_mix_transform(input_dict, mix_results) + + return input_dict + + def __repr__(self) -> str: + """str: Return a string that describes the module.""" + repr_str = self.__class__.__name__ + repr_str += f'(num_areas={self.num_areas}, ' + repr_str += f'pitch_angles={self.pitch_angles}, ' + repr_str += f'pre_transform={self.pre_transform}, ' + repr_str += f'prob={self.prob})' + return repr_str diff --git a/tests/test_datasets/test_transforms/test_transforms_3d.py b/tests/test_datasets/test_transforms/test_transforms_3d.py index 036d8bbfca..3d6fd6eac2 100644 --- a/tests/test_datasets/test_transforms/test_transforms_3d.py +++ b/tests/test_datasets/test_transforms/test_transforms_3d.py @@ -8,7 +8,7 @@ from mmdet3d.datasets import (GlobalAlignment, RandomFlip3D, SemanticKITTIDataset) -from mmdet3d.datasets.transforms import GlobalRotScaleTrans, PolarMix +from mmdet3d.datasets.transforms import GlobalRotScaleTrans, LaserMix, PolarMix from mmdet3d.structures import LiDARPoints from mmdet3d.testing import create_data_info_after_loading from mmdet3d.utils import register_all_modules @@ -222,3 +222,128 @@ def test_transform(self): results = transform.transform(copy.deepcopy(self.results)) self.assertTrue(results['points'].shape[0] == results['pts_semantic_mask'].shape[0]) + + +class TestLaserMix(unittest.TestCase): + + def setUp(self): + self.pre_transform = [ + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=4, + use_dim=4), + dict( + type='LoadAnnotations3D', + with_bbox_3d=False, + with_label_3d=False, + with_mask_3d=False, + with_seg_3d=True, + seg_3d_dtype='np.int32'), + dict(type='PointSegClassMapping'), + ] + classes = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus', + 'person', 'bicyclist', 'motorcyclist', 'road', 'parking', + 'sidewalk', 'other-ground', 'building', 'fence', + 'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign') + palette = [ + [174, 199, 232], + [152, 223, 138], + [31, 119, 180], + [255, 187, 120], + [188, 189, 34], + [140, 86, 75], + [255, 152, 150], + [214, 39, 40], + [197, 176, 213], + [148, 103, 189], + [196, 156, 148], + [23, 190, 207], + [247, 182, 210], + [219, 219, 141], + [255, 127, 14], + [158, 218, 229], + [44, 160, 44], + [112, 128, 144], + [227, 119, 194], + [82, 84, 163], + ] + seg_label_mapping = { + 0: 0, # "unlabeled" + 1: 0, # "outlier" mapped to "unlabeled" --------------mapped + 10: 1, # "car" + 11: 2, # "bicycle" + 13: 5, # "bus" mapped to "other-vehicle" --------------mapped + 15: 3, # "motorcycle" + 16: 5, # "on-rails" mapped to "other-vehicle" ---------mapped + 18: 4, # "truck" + 20: 5, # "other-vehicle" + 30: 6, # "person" + 31: 7, # "bicyclist" + 32: 8, # "motorcyclist" + 40: 9, # "road" + 44: 10, # "parking" + 48: 11, # "sidewalk" + 49: 12, # "other-ground" + 50: 13, # "building" + 51: 14, # "fence" + 52: 0, # "other-structure" mapped to "unlabeled" ------mapped + 60: 9, # "lane-marking" to "road" ---------------------mapped + 70: 15, # "vegetation" + 71: 16, # "trunk" + 72: 17, # "terrain" + 80: 18, # "pole" + 81: 19, # "traffic-sign" + 99: 0, # "other-object" to "unlabeled" ----------------mapped + 252: 1, # "moving-car" to "car" ------------------------mapped + 253: 7, # "moving-bicyclist" to "bicyclist" ------------mapped + 254: 6, # "moving-person" to "person" ------------------mapped + 255: 8, # "moving-motorcyclist" to "motorcyclist" ------mapped + 256: 5, # "moving-on-rails" mapped to "other-vehic------mapped + 257: 5, # "moving-bus" mapped to "other-vehicle" -------mapped + 258: 4, # "moving-truck" to "truck" --------------------mapped + 259: 5 # "moving-other"-vehicle to "other-vehicle"-----mapped + } + max_label = 259 + self.dataset = SemanticKITTIDataset( + './tests/data/semantickitti/', + 'semantickitti_infos.pkl', + metainfo=dict( + classes=classes, + palette=palette, + seg_label_mapping=seg_label_mapping, + max_label=max_label), + data_prefix=dict( + pts='sequences/00/velodyne', + pts_semantic_mask='sequences/00/labels'), + pipeline=[], + modality=dict(use_lidar=True, use_camera=False)) + points = np.random.random((100, 4)) + self.results = { + 'points': LiDARPoints(points, points_dim=4), + 'pts_semantic_mask': np.random.randint(0, 20, (100, )), + 'dataset': self.dataset + } + + def test_transform(self): + # test assertion for invalid num_areas + with self.assertRaises(AssertionError): + transform = LaserMix(num_areas=3, pitch_angles=[-20, 0]) + + with self.assertRaises(AssertionError): + transform = LaserMix(num_areas=[3.0, 4.0], pitch_angles=[-20, 0]) + + # test assertion for invalid pitch_angles + with self.assertRaises(AssertionError): + transform = LaserMix(num_areas=[3, 4], pitch_angles=[-20]) + + with self.assertRaises(AssertionError): + transform = LaserMix(num_areas=[3, 4], pitch_angles=[0, -20]) + + transform = LaserMix( + num_areas=[3, 4, 5, 6], + pitch_angles=[-20, 0], + pre_transform=self.pre_transform) + results = transform.transform(copy.deepcopy(self.results)) + self.assertTrue(results['points'].shape[0] == + results['pts_semantic_mask'].shape[0])