Commit c5dfdd7

Add multi-modal support for Nuscenes dataset
1 parent 4dc1849 commit c5dfdd7

File tree

9 files changed: +330 -11 lines changed

pcdet/datasets/augmentor/data_augmentor.py

Lines changed: 36 additions & 0 deletions
@@ -1,6 +1,7 @@
 from functools import partial

 import numpy as np
+from PIL import Image

 from ...utils import common_utils
 from . import augmentor_utils, database_sampler
@@ -23,6 +24,18 @@ def __init__(self, root_path, augmentor_configs, class_names, logger=None):
             cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
             self.data_augmentor_queue.append(cur_augmentor)

+    def disableAugmentation(self, augmentor_configs):
+        self.data_augmentor_queue = []
+        aug_config_list = augmentor_configs if isinstance(augmentor_configs, list) \
+            else augmentor_configs.AUG_CONFIG_LIST
+
+        for cur_cfg in aug_config_list:
+            if not isinstance(augmentor_configs, list):
+                if cur_cfg.NAME in augmentor_configs.DISABLE_AUG_LIST:
+                    continue
+            cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
+            self.data_augmentor_queue.append(cur_augmentor)
+
     def gt_sampling(self, config=None):
         db_sampler = database_sampler.DataBaseSampler(
             root_path=self.root_path,
@@ -139,6 +152,7 @@ def random_world_translation(self, data_dict=None, config=None):

         data_dict['gt_boxes'] = gt_boxes
         data_dict['points'] = points
+        data_dict['noise_translate'] = noise_translate
         return data_dict

     def random_local_translation(self, data_dict=None, config=None):
@@ -251,6 +265,28 @@ def random_local_pyramid_aug(self, data_dict=None, config=None):
         data_dict['points'] = points
         return data_dict

+    def imgaug(self, data_dict=None, config=None):
+        if data_dict is None:
+            return partial(self.imgaug, config=config)
+        imgs = data_dict["camera_imgs"]
+        img_process_infos = data_dict['img_process_infos']
+        new_imgs = []
+        for img, img_process_info in zip(imgs, img_process_infos):
+            flip = False
+            if config.RAND_FLIP and np.random.choice([0, 1]):
+                flip = True
+            rotate = np.random.uniform(*config.ROT_LIM)
+            # aug images
+            if flip:
+                img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
+            img = img.rotate(rotate)
+            img_process_info[2] = flip
+            img_process_info[3] = rotate
+            new_imgs.append(img)
+
+        data_dict["camera_imgs"] = new_imgs
+        return data_dict
+
     def forward(self, data_dict):
         """
         Args:
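A minimal standalone sketch of what imgaug records, assuming illustrative config values (RAND_FLIP: True, ROT_LIM: [-5.4, 5.4] are assumptions, not taken from this commit). Slots 2 and 3 of each img_process_info entry store the sampled flip and rotation so downstream camera-to-lidar alignment can undo them:

import numpy as np
from PIL import Image

RAND_FLIP = True          # assumed config value
ROT_LIM = (-5.4, 5.4)     # assumed rotation limits, in degrees

img = Image.new('RGB', (704, 256))                      # stand-in camera image
img_process_info = [0.48, (0, 0, 704, 256), False, 0]   # [resize, crop, flip, rotate]

flip = bool(RAND_FLIP and np.random.choice([0, 1]))
rotate = np.random.uniform(*ROT_LIM)
if flip:
    img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
img = img.rotate(rotate)            # PIL rotates counter-clockwise, in degrees
img_process_info[2] = flip          # recorded so the aug can be inverted later
img_process_info[3] = rotate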

pcdet/datasets/dataset.py

Lines changed: 28 additions & 0 deletions
@@ -2,6 +2,7 @@
 from pathlib import Path

 import numpy as np
+import torch
 import torch.utils.data as torch_data

 from ..utils import common_utils
@@ -130,6 +131,30 @@ def __getitem__(self, index):
         """
         raise NotImplementedError

+    def set_lidar_aug_matrix(self, data_dict):
+        """
+        Get the lidar augmentation matrix (4 x 4), which is used to recover the original point coordinates.
+        """
+        lidar_aug_matrix = np.eye(4)
+        if 'flip_y' in data_dict.keys():
+            flip_x = data_dict['flip_x']
+            flip_y = data_dict['flip_y']
+            if flip_x:
+                lidar_aug_matrix[:3, :3] = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) @ lidar_aug_matrix[:3, :3]
+            if flip_y:
+                lidar_aug_matrix[:3, :3] = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) @ lidar_aug_matrix[:3, :3]
+        if 'noise_rot' in data_dict.keys():
+            noise_rot = data_dict['noise_rot']
+            lidar_aug_matrix[:3, :3] = common_utils.angle2matrix(torch.tensor(noise_rot)) @ lidar_aug_matrix[:3, :3]
+        if 'noise_scale' in data_dict.keys():
+            noise_scale = data_dict['noise_scale']
+            lidar_aug_matrix[:3, :3] *= noise_scale
+        if 'noise_translate' in data_dict.keys():
+            noise_translate = data_dict['noise_translate']
+            lidar_aug_matrix[:3, 3:4] = noise_translate.T
+        data_dict['lidar_aug_matrix'] = lidar_aug_matrix
+        return data_dict
+
     def prepare_data(self, data_dict):
         """
         Args:
@@ -165,6 +190,7 @@ def prepare_data(self, data_dict):
             )
         if 'calib' in data_dict:
             data_dict['calib'] = calib
+        data_dict = self.set_lidar_aug_matrix(data_dict)
         if data_dict.get('gt_boxes', None) is not None:
             selected = common_utils.keep_arrays_by_name(data_dict['gt_names'], self.class_names)
             data_dict['gt_boxes'] = data_dict['gt_boxes'][selected]
@@ -287,6 +313,8 @@ def collate_batch(batch_list, _unused=False):
                                               constant_values=pad_value)
                        points.append(points_pad)
                    ret[key] = np.stack(points, axis=0)
+                elif key in ['camera_imgs']:
+                    ret[key] = torch.stack([torch.stack(imgs, dim=0) for imgs in val], dim=0)
                else:
                    ret[key] = np.stack(val, axis=0)
            except:
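As a sketch (not part of the commit), the recorded lidar_aug_matrix can be inverted to map augmented points back to the original lidar frame, which is the stated purpose of set_lidar_aug_matrix:

import numpy as np

M = np.eye(4)                         # lidar_aug_matrix as assembled above
M[:3, :3] *= 1.05                     # e.g. a recorded noise_scale of 1.05
M[:3, 3] = [0.2, -0.1, 0.0]           # e.g. a recorded noise_translate

pts_aug = np.random.randn(8, 3)       # augmented points, shape (N, 3)
pts_h = np.concatenate([pts_aug, np.ones((8, 1))], axis=1)
pts_orig = (np.linalg.inv(M) @ pts_h.T).T[:, :3]   # back to original coordinates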

pcdet/datasets/nuscenes/nuscenes_dataset.py

Lines changed: 101 additions & 1 deletion
@@ -8,6 +8,8 @@
 from ...ops.roiaware_pool3d import roiaware_pool3d_utils
 from ...utils import common_utils
 from ..dataset import DatasetTemplate
+from pyquaternion import Quaternion
+from PIL import Image


 class NuScenesDataset(DatasetTemplate):
@@ -17,6 +19,13 @@ def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logg
             dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger
         )
         self.infos = []
+        self.camera_config = self.dataset_cfg.get('CAMERA_CONFIG', None)
+        if self.camera_config is not None:
+            self.use_camera = self.camera_config.get('USE_CAMERA', True)
+            self.camera_image_config = self.camera_config.IMAGE
+        else:
+            self.use_camera = False
+
         self.include_nuscenes_data(self.mode)
         if self.training and self.dataset_cfg.get('BALANCED_RESAMPLING', False):
             self.infos = self.balanced_infos_resampling(self.infos)
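For reference, a sketch of the CAMERA_CONFIG block these lines read. The keys follow the code above; the values are illustrative assumptions (mirroring common BEVFusion-style settings), not part of this commit. In OpenPCDet this would normally live in the dataset YAML; shown here as the EasyDict the config loader produces:

from easydict import EasyDict

CAMERA_CONFIG = EasyDict({
    'USE_CAMERA': True,
    'IMAGE': {
        'FINAL_DIM': [256, 704],            # (fH, fW) after resize + crop
        'RESIZE_LIM_TRAIN': [0.38, 0.55],   # random resize range at train time
        'RESIZE_LIM_TEST': [0.48, 0.48],    # fixed resize at eval time
    },
})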
@@ -108,6 +117,41 @@ def get_lidar_with_sweeps(self, index, max_sweeps=1):
         points = np.concatenate((points, times), axis=1)
         return points

+    def crop_image(self, input_dict):
+        W, H = input_dict["ori_shape"]
+        imgs = input_dict["camera_imgs"]
+        img_process_infos = []
+        crop_images = []
+        for img in imgs:
+            if self.training:
+                fH, fW = self.camera_image_config.FINAL_DIM
+                resize_lim = self.camera_image_config.RESIZE_LIM_TRAIN
+                resize = np.random.uniform(*resize_lim)
+                resize_dims = (int(W * resize), int(H * resize))
+                newW, newH = resize_dims
+                crop_h = newH - fH
+                crop_w = int(np.random.uniform(0, max(0, newW - fW)))
+                crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
+            else:
+                fH, fW = self.camera_image_config.FINAL_DIM
+                resize_lim = self.camera_image_config.RESIZE_LIM_TEST
+                resize = np.mean(resize_lim)
+                resize_dims = (int(W * resize), int(H * resize))
+                newW, newH = resize_dims
+                crop_h = newH - fH
+                crop_w = int(max(0, newW - fW) / 2)
+                crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
+
+            # resize and crop image
+            img = img.resize(resize_dims)
+            img = img.crop(crop)
+            crop_images.append(img)
+            img_process_infos.append([resize, crop, False, 0])
+
+        input_dict['img_process_infos'] = img_process_infos
+        input_dict['camera_imgs'] = crop_images
+        return input_dict
+
     def __len__(self):
         if self._merge_all_iters_to_one_epoch:
             return len(self.infos) * self.total_epochs
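Worked example of the eval-time branch of crop_image, using the assumed FINAL_DIM and RESIZE_LIM_TEST values above and the native nuScenes camera resolution of 1600 x 900. The crop keeps the bottom fH rows and is centered horizontally:

import numpy as np

W, H = 1600, 900                      # nuScenes image size, (W, H) = img.size
fH, fW = 256, 704                     # assumed FINAL_DIM
resize = float(np.mean([0.48, 0.48]))              # 0.48
newW, newH = int(W * resize), int(H * resize)      # (768, 432)
crop_h = newH - fH                                 # 176 -> drop the top 176 rows
crop_w = int(max(0, newW - fW) / 2)                # 32  -> center horizontally
crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)  # (32, 176, 736, 432)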
@@ -137,6 +181,60 @@ def __getitem__(self, index):
                 'gt_names': info['gt_names'] if mask is None else info['gt_names'][mask],
                 'gt_boxes': info['gt_boxes'] if mask is None else info['gt_boxes'][mask]
             })
+        if self.use_camera:
+            input_dict["image_paths"] = []
+            input_dict["lidar2camera"] = []
+            input_dict["lidar2image"] = []
+            input_dict["camera2ego"] = []
+            input_dict["camera_intrinsics"] = []
+            input_dict["camera2lidar"] = []
+
+            for _, camera_info in info["cams"].items():
+                input_dict["image_paths"].append(camera_info["data_path"])
+
+                # lidar to camera transform
+                lidar2camera_r = np.linalg.inv(camera_info["sensor2lidar_rotation"])
+                lidar2camera_t = (
+                    camera_info["sensor2lidar_translation"] @ lidar2camera_r.T
+                )
+                lidar2camera_rt = np.eye(4).astype(np.float32)
+                lidar2camera_rt[:3, :3] = lidar2camera_r.T
+                lidar2camera_rt[3, :3] = -lidar2camera_t
+                input_dict["lidar2camera"].append(lidar2camera_rt.T)
+
+                # camera intrinsics
+                camera_intrinsics = np.eye(4).astype(np.float32)
+                camera_intrinsics[:3, :3] = camera_info["camera_intrinsics"]
+                input_dict["camera_intrinsics"].append(camera_intrinsics)
+
+                # lidar to image transform
+                lidar2image = camera_intrinsics @ lidar2camera_rt.T
+                input_dict["lidar2image"].append(lidar2image)
+
+                # camera to ego transform
+                camera2ego = np.eye(4).astype(np.float32)
+                camera2ego[:3, :3] = Quaternion(
+                    camera_info["sensor2ego_rotation"]
+                ).rotation_matrix
+                camera2ego[:3, 3] = camera_info["sensor2ego_translation"]
+                input_dict["camera2ego"].append(camera2ego)
+
+                # camera to lidar transform
+                camera2lidar = np.eye(4).astype(np.float32)
+                camera2lidar[:3, :3] = camera_info["sensor2lidar_rotation"]
+                camera2lidar[:3, 3] = camera_info["sensor2lidar_translation"]
+                input_dict["camera2lidar"].append(camera2lidar)
+            # read image
+            filename = input_dict["image_paths"]
+            images = []
+            for name in filename:
+                images.append(Image.open(str(self.root_path / name)))
+
+            input_dict["camera_imgs"] = images
+            input_dict["ori_shape"] = images[0].size
+
+            # resize and crop image
+            input_dict = self.crop_image(input_dict)

         data_dict = self.prepare_data(data_dict=input_dict)
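Sketch (not in the commit) of how the lidar2image matrix built above is meant to be used: in column-vector convention it takes homogeneous lidar coordinates to image coordinates scaled by depth:

import numpy as np

lidar2image = np.eye(4, dtype=np.float32)    # stand-in for input_dict["lidar2image"][k]
p_lidar = np.array([10.0, 2.0, -1.0, 1.0])   # homogeneous point in the lidar frame
p_img = lidar2image @ p_lidar                # -> (u * d, v * d, d, 1)
depth = p_img[2]
if depth > 0:                                # only points in front of the camera project
    u, v = p_img[0] / depth, p_img[1] / depth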

@@ -251,7 +349,7 @@ def create_groundtruth_database(self, used_classes=None, max_sweeps=10):
         pickle.dump(all_db_infos, f)


-def create_nuscenes_info(version, data_path, save_path, max_sweeps=10):
+def create_nuscenes_info(version, data_path, save_path, max_sweeps=10, with_cam=False):
     from nuscenes.nuscenes import NuScenes
     from nuscenes.utils import splits
     from . import nuscenes_utils
@@ -308,6 +406,7 @@ def create_nuscenes_info(version, data_path, save_path, max_sweeps=10):
     parser.add_argument('--cfg_file', type=str, default=None, help='specify the config of dataset')
     parser.add_argument('--func', type=str, default='create_nuscenes_infos', help='')
     parser.add_argument('--version', type=str, default='v1.0-trainval', help='')
+    parser.add_argument('--with_cam', action='store_true', default=False, help='use camera or not')
     args = parser.parse_args()

     if args.func == 'create_nuscenes_infos':
@@ -319,6 +418,7 @@ def create_nuscenes_info(version, data_path, save_path, max_sweeps=10):
             data_path=ROOT_DIR / 'data' / 'nuscenes',
             save_path=ROOT_DIR / 'data' / 'nuscenes',
             max_sweeps=dataset_cfg.MAX_SWEEPS,
+            with_cam=args.with_cam
         )

         nuscenes_dataset = NuScenesDataset(
nuscenes_dataset = NuScenesDataset(

pcdet/datasets/nuscenes/nuscenes_utils.py

Lines changed: 89 additions & 1 deletion
@@ -247,9 +247,69 @@ def quaternion_yaw(q: Quaternion) -> float:
     yaw = np.arctan2(v[1], v[0])

     return yaw
+
+
+def obtain_sensor2top(
+    nusc, sensor_token, l2e_t, l2e_r_mat, e2g_t, e2g_r_mat, sensor_type="lidar"
+):
+    """Obtain the info with the RT matrix from a general sensor to the Top LiDAR.
+
+    Args:
+        nusc (class): Dataset class in the nuScenes dataset.
+        sensor_token (str): Sample data token corresponding to the
+            specific sensor type.
+        l2e_t (np.ndarray): Translation from lidar to ego in shape (1, 3).
+        l2e_r_mat (np.ndarray): Rotation matrix from lidar to ego
+            in shape (3, 3).
+        e2g_t (np.ndarray): Translation from ego to global in shape (1, 3).
+        e2g_r_mat (np.ndarray): Rotation matrix from ego to global
+            in shape (3, 3).
+        sensor_type (str): Sensor to calibrate. Default: 'lidar'.
+
+    Returns:
+        sweep (dict): Sweep information after transformation.
+    """
+    sd_rec = nusc.get("sample_data", sensor_token)
+    cs_record = nusc.get("calibrated_sensor", sd_rec["calibrated_sensor_token"])
+    pose_record = nusc.get("ego_pose", sd_rec["ego_pose_token"])
+    data_path = str(nusc.get_sample_data_path(sd_rec["token"]))
+    # if os.getcwd() in data_path:  # path from lyftdataset is absolute path
+    #     data_path = data_path.split(f"{os.getcwd()}/")[-1]  # relative path
+    sweep = {
+        "data_path": data_path,
+        "type": sensor_type,
+        "sample_data_token": sd_rec["token"],
+        "sensor2ego_translation": cs_record["translation"],
+        "sensor2ego_rotation": cs_record["rotation"],
+        "ego2global_translation": pose_record["translation"],
+        "ego2global_rotation": pose_record["rotation"],
+        "timestamp": sd_rec["timestamp"],
+    }
+    l2e_r_s = sweep["sensor2ego_rotation"]
+    l2e_t_s = sweep["sensor2ego_translation"]
+    e2g_r_s = sweep["ego2global_rotation"]
+    e2g_t_s = sweep["ego2global_translation"]
+
+    # obtain the RT from sensor to Top LiDAR
+    # sweep->ego->global->ego'->lidar
+    l2e_r_s_mat = Quaternion(l2e_r_s).rotation_matrix
+    e2g_r_s_mat = Quaternion(e2g_r_s).rotation_matrix
+    R = (l2e_r_s_mat.T @ e2g_r_s_mat.T) @ (
+        np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T
+    )
+    T = (l2e_t_s @ e2g_r_s_mat.T + e2g_t_s) @ (
+        np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T
+    )
+    T -= (
+        e2g_t @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T)
+        + l2e_t @ np.linalg.inv(l2e_r_mat).T
+    ).squeeze(0)
+    sweep["sensor2lidar_rotation"] = R.T  # points @ R.T + T
+    sweep["sensor2lidar_translation"] = T
+    return sweep
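The closed-form R and T above are the row-vector expansion of composing four homogeneous transforms, sensor -> ego' -> global -> ego -> lidar. A sketch (with random poses standing in for the nuScenes records, and (3,) translations so no squeeze is needed) that checks the equivalence:

import numpy as np
from pyquaternion import Quaternion

def make_tf(R, t):
    M = np.eye(4)
    M[:3, :3], M[:3, 3] = R, t
    return M

l2e_r_mat = Quaternion(axis=[0, 0, 1], angle=0.3).rotation_matrix
e2g_r_mat = Quaternion(axis=[0, 0, 1], angle=-1.1).rotation_matrix
l2e_r_s_mat = Quaternion(axis=[0, 1, 0], angle=0.2).rotation_matrix
e2g_r_s_mat = Quaternion(axis=[1, 0, 0], angle=0.5).rotation_matrix
l2e_t, e2g_t = np.array([1.0, 0.0, 2.0]), np.array([10.0, 5.0, 0.0])
l2e_t_s, e2g_t_s = np.array([1.5, 0.1, 1.8]), np.array([11.0, 4.0, 0.0])

inv = np.linalg.inv
sensor2lidar = (inv(make_tf(l2e_r_mat, l2e_t)) @ inv(make_tf(e2g_r_mat, e2g_t))
                @ make_tf(e2g_r_s_mat, e2g_t_s) @ make_tf(l2e_r_s_mat, l2e_t_s))

R = (l2e_r_s_mat.T @ e2g_r_s_mat.T) @ (inv(e2g_r_mat).T @ inv(l2e_r_mat).T)
T = (l2e_t_s @ e2g_r_s_mat.T + e2g_t_s) @ (inv(e2g_r_mat).T @ inv(l2e_r_mat).T)
T -= e2g_t @ (inv(e2g_r_mat).T @ inv(l2e_r_mat).T) + l2e_t @ inv(l2e_r_mat).T

assert np.allclose(sensor2lidar[:3, :3], R.T)   # same rotation
assert np.allclose(sensor2lidar[:3, 3], T)      # same translation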


-def fill_trainval_infos(data_path, nusc, train_scenes, val_scenes, test=False, max_sweeps=10):
+def fill_trainval_infos(data_path, nusc, train_scenes, val_scenes, test=False, max_sweeps=10, with_cam=False):
     train_nusc_infos = []
     val_nusc_infos = []
     progress_bar = tqdm.tqdm(total=len(nusc.sample), desc='create_info', dynamic_ncols=True)
@@ -291,6 +351,34 @@ def fill_trainval_infos(data_path, nusc, train_scenes, val_scenes, test=False, m
             'car_from_global': car_from_global,
             'timestamp': ref_time,
         }
+        if with_cam:
+            info['cams'] = dict()
+            l2e_r = ref_cs_rec["rotation"]
+            l2e_t = ref_cs_rec["translation"],  # trailing comma -> (1, 3) row vector, matching the squeeze(0) in obtain_sensor2top
+            e2g_r = ref_pose_rec["rotation"]
+            e2g_t = ref_pose_rec["translation"]
+            l2e_r_mat = Quaternion(l2e_r).rotation_matrix
+            e2g_r_mat = Quaternion(e2g_r).rotation_matrix
+
+            # obtain 6 images' information per frame
+            camera_types = [
+                "CAM_FRONT",
+                "CAM_FRONT_RIGHT",
+                "CAM_FRONT_LEFT",
+                "CAM_BACK",
+                "CAM_BACK_LEFT",
+                "CAM_BACK_RIGHT",
+            ]
+            for cam in camera_types:
+                cam_token = sample["data"][cam]
+                cam_path, _, camera_intrinsics = nusc.get_sample_data(cam_token)
+                cam_info = obtain_sensor2top(
+                    nusc, cam_token, l2e_t, l2e_r_mat, e2g_t, e2g_r_mat, cam
+                )
+                cam_info['data_path'] = Path(cam_info['data_path']).relative_to(data_path).__str__()
+                cam_info.update(camera_intrinsics=camera_intrinsics)
+                info["cams"].update({cam: cam_info})

         sample_data_token = sample['data'][chan]
         curr_sd_rec = nusc.get('sample_data', sample_data_token)
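For orientation (a sketch, not output from this commit), each per-camera entry assembled above ends up with the following shape; the file name is a placeholder:

# info['cams']['CAM_FRONT'] = {
#     'data_path': 'samples/CAM_FRONT/<frame>.jpg',   # made relative to data_path
#     'type': 'CAM_FRONT',
#     'sample_data_token': ...,
#     'sensor2ego_translation': [...], 'sensor2ego_rotation': [...],
#     'ego2global_translation': [...], 'ego2global_rotation': [...],
#     'sensor2lidar_rotation': <3x3>, 'sensor2lidar_translation': <3,>,
#     'camera_intrinsics': <3x3>,
#     'timestamp': ...,
# }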
