Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 51843ef

Browse files
committed Oct 19, 2023
commit
1 parent 81cab83 commit 51843ef

File tree

7 files changed

+505
-204
lines changed

7 files changed

+505
-204
lines changed
 

‎pcdet/datasets/kitti/kitti_dataset.py

+41-115
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33

44
import numpy as np
55
from skimage import io
6-
from tqdm import tqdm
76

8-
from pcdet.datasets.kitti import kitti_utils
9-
from pcdet.ops.roiaware_pool3d import roiaware_pool3d_utils
10-
from pcdet.utils import box_utils, calibration_kitti, common_utils, object3d_kitti
11-
from pcdet.datasets.dataset import DatasetTemplate
7+
from . import kitti_utils
8+
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
9+
from ...utils import box_utils, calibration_kitti, common_utils, object3d_kitti
10+
from ..dataset import DatasetTemplate
1211

1312

1413
class KittiDataset(DatasetTemplate):
@@ -60,7 +59,6 @@ def set_split(self, split):
6059

6160
split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
6261
self.sample_id_list = [x.strip() for x in open(split_dir).readlines()] if split_dir.exists() else None
63-
print(f'Number of {self.split} samples: {len(self.sample_id_list)}')
6462

6563
def get_lidar(self, idx):
6664
lidar_file = self.root_split_path / 'velodyne' / ('%s.bin' % idx)
@@ -152,8 +150,8 @@ def get_fov_flag(pts_rect, img_shape, calib):
152150
def get_infos(self, num_workers=4, has_label=True, count_inside_pts=True, sample_id_list=None):
153151
import concurrent.futures as futures
154152

155-
def process_single_scene(sample_idx, is_empty=False):
156-
# print('%s sample_idx: %s' % (self.split, sample_idx))
153+
def process_single_scene(sample_idx):
154+
print('%s sample_idx: %s' % (self.split, sample_idx))
157155
info = {}
158156
pc_info = {'num_features': 4, 'lidar_idx': sample_idx}
159157
info['point_cloud'] = pc_info
@@ -173,25 +171,14 @@ def process_single_scene(sample_idx, is_empty=False):
173171

174172
if has_label:
175173
obj_list = self.get_label(sample_idx)
176-
177-
if len(obj_list) <= 0:
178-
is_empty = True
179-
return info, is_empty
180-
181174
annotations = {}
182175
annotations['name'] = np.array([obj.cls_type for obj in obj_list])
183176
annotations['truncated'] = np.array([obj.truncation for obj in obj_list])
184177
annotations['occluded'] = np.array([obj.occlusion for obj in obj_list])
185178
annotations['alpha'] = np.array([obj.alpha for obj in obj_list])
186-
if len(obj_list) > 0:
187-
annotations['bbox'] = np.concatenate([obj.box2d.reshape(1, 4) for obj in obj_list], axis=0)
188-
else:
189-
annotations['bbox'] = np.array([])
179+
annotations['bbox'] = np.concatenate([obj.box2d.reshape(1, 4) for obj in obj_list], axis=0)
190180
annotations['dimensions'] = np.array([[obj.l, obj.h, obj.w] for obj in obj_list]) # lhw(camera) format
191-
if len(obj_list) > 0:
192-
annotations['location'] = np.concatenate([obj.loc.reshape(1, 3) for obj in obj_list], axis=0)
193-
else:
194-
annotations['location'] = np.array([])
181+
annotations['location'] = np.concatenate([obj.loc.reshape(1, 3) for obj in obj_list], axis=0)
195182
annotations['rotation_y'] = np.array([obj.ry for obj in obj_list])
196183
annotations['score'] = np.array([obj.score for obj in obj_list])
197184
annotations['difficulty'] = np.array([obj.level for obj in obj_list], np.int32)
@@ -204,18 +191,15 @@ def process_single_scene(sample_idx, is_empty=False):
204191
loc = annotations['location'][:num_objects]
205192
dims = annotations['dimensions'][:num_objects]
206193
rots = annotations['rotation_y'][:num_objects]
207-
if len(obj_list) > 0:
208-
loc_lidar = calib.rect_to_lidar(loc)
209-
l, h, w = dims[:, 0:1], dims[:, 1:2], dims[:, 2:3]
210-
loc_lidar[:, 2] += h[:, 0] / 2
211-
gt_boxes_lidar = np.concatenate([loc_lidar, l, w, h, -(np.pi / 2 + rots[..., np.newaxis])], axis=1)
212-
annotations['gt_boxes_lidar'] = gt_boxes_lidar
213-
else:
214-
annotations['gt_boxes_lidar'] = np.array([[], [], [], [], [], [], []]).reshape(-1, 7)
194+
loc_lidar = calib.rect_to_lidar(loc)
195+
l, h, w = dims[:, 0:1], dims[:, 1:2], dims[:, 2:3]
196+
loc_lidar[:, 2] += h[:, 0] / 2
197+
gt_boxes_lidar = np.concatenate([loc_lidar, l, w, h, -(np.pi / 2 + rots[..., np.newaxis])], axis=1)
198+
annotations['gt_boxes_lidar'] = gt_boxes_lidar
215199

216200
info['annos'] = annotations
217201

218-
if count_inside_pts and len(obj_list) > 0:
202+
if count_inside_pts:
219203
points = self.get_lidar(sample_idx)
220204
calib = self.get_calib(sample_idx)
221205
pts_rect = calib.lidar_to_rect(points[:, 0:3])
@@ -229,53 +213,28 @@ def process_single_scene(sample_idx, is_empty=False):
229213
flag = box_utils.in_hull(pts_fov[:, 0:3], corners_lidar[k])
230214
num_points_in_gt[k] = flag.sum()
231215
annotations['num_points_in_gt'] = num_points_in_gt
232-
233-
return info, is_empty
216+
217+
return info
234218

235219
sample_id_list = sample_id_list if sample_id_list is not None else self.sample_id_list
236-
237-
infos = []
238-
empty = []
239-
non_empty = []
240-
for sample in tqdm(sample_id_list):
241-
info, is_empty = process_single_scene(sample)
242-
if is_empty:
243-
empty.append(str(sample))
244-
continue
245-
else:
246-
non_empty.append(str(sample))
247-
infos.append(info)
248-
249-
# Save empty IDs in a txt
250-
with open('/home/ipl-pc/cmkd/data/kitti/empty_train_ids.txt', 'wt') as f:
251-
for id in empty:
252-
f.write(id + '\n')
253-
f.close()
254-
255-
with open('/home/ipl-pc/cmkd/data/kitti/non_empty_train_ids.txt', 'wt') as f:
256-
for id in non_empty:
257-
f.write(id + '\n')
258-
f.close()
259-
260-
# with futures.ThreadPoolExecutor(num_workers) as executor:
261-
# infos = executor.map(process_single_scene, sample_id_list)
262-
220+
with futures.ThreadPoolExecutor(num_workers) as executor:
221+
infos = executor.map(process_single_scene, sample_id_list)
263222
return list(infos)
264223

265224
def create_groundtruth_database(self, info_path=None, used_classes=None, split='train'):
266225
import torch
267-
save_path = '/home/ipl-pc/cmkd/data/kitti'
268-
database_save_path = Path(save_path) / ('gt_database' if split == 'train' else ('gt_database_%s' % split))
269-
db_info_save_path = Path(save_path) / ('kitti_dbinfos_%s.pkl' % split)
226+
227+
database_save_path = Path(self.root_path) / ('gt_database' if split == 'train' else ('gt_database_%s' % split))
228+
db_info_save_path = Path(self.root_path) / ('kitti_dbinfos_%s.pkl' % split)
270229

271230
database_save_path.mkdir(parents=True, exist_ok=True)
272231
all_db_infos = {}
273232

274233
with open(info_path, 'rb') as f:
275234
infos = pickle.load(f)
276235

277-
for k in tqdm(range(len(infos))):
278-
# print('gt_database sample: %d/%d' % (k + 1, len(infos)))
236+
for k in range(len(infos)):
237+
print('gt_database sample: %d/%d' % (k + 1, len(infos)))
279238
info = infos[k]
280239
sample_idx = info['point_cloud']['lidar_idx']
281240
points = self.get_lidar(sample_idx)
@@ -300,7 +259,7 @@ def create_groundtruth_database(self, info_path=None, used_classes=None, split='
300259
gt_points.tofile(f)
301260

302261
if (used_classes is None) or names[i] in used_classes:
303-
db_path = str(filepath.relative_to(save_path)) # gt_database/xxxxx.bin
262+
db_path = str(filepath.relative_to(self.root_path)) # gt_database/xxxxx.bin
304263
db_info = {'name': names[i], 'path': db_path, 'image_idx': sample_idx, 'gt_idx': i,
305264
'box3d_lidar': gt_boxes[i], 'num_points_in_gt': gt_points.shape[0],
306265
'difficulty': difficulty[i], 'bbox': bbox[i], 'score': annos['score'][i]}
@@ -472,50 +431,34 @@ def create_kitti_infos(dataset_cfg, class_names, data_path, save_path, workers=4
472431
dataset = KittiDataset(dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path, training=False)
473432
train_split, val_split = 'train', 'val'
474433

475-
train_filename = save_path / ('kitti_infos_%s_lpcg.pkl' % train_split)
476-
val_filename = save_path / ('kitti_infos_%s_lpcg.pkl' % val_split)
477-
# trainval_filename = save_path / 'kitti_infos_trainval_raw.pkl'
478-
# test_filename = save_path / 'kitti_infos_test_raw.pkl'
479-
480-
### Remove train ids with empty .txt
481-
# dataset.set_split(train_split) #['000000', '000001', ...]
482-
# non_empty_ids = []
483-
# for sample_idx in dataset.sample_id_list:
484-
# obj_list = dataset.get_label(sample_idx)
485-
# if len(obj_list) <= 0:
486-
# continue
487-
# non_empty_ids.append(sample_idx) #31794
488-
# # Save empty IDs in a txt
489-
# with open('/home/ipl-pc/cmkd/data/kitti/train_no_empty.txt', 'wt') as f:
490-
# for id in non_empty_ids:
491-
# f.write(id + '\n')
492-
# f.close()
493-
# print('Cleaned.')
494-
# exit()
495-
####
496-
434+
train_filename = save_path / ('kitti_infos_%s.pkl' % train_split)
435+
val_filename = save_path / ('kitti_infos_%s.pkl' % val_split)
436+
trainval_filename = save_path / 'kitti_infos_trainval.pkl'
437+
test_filename = save_path / 'kitti_infos_test.pkl'
438+
497439
print('---------------Start to generate data infos---------------')
440+
498441
dataset.set_split(train_split)
499442
kitti_infos_train = dataset.get_infos(num_workers=workers, has_label=True, count_inside_pts=True)
500443
with open(train_filename, 'wb') as f:
501444
pickle.dump(kitti_infos_train, f)
502-
print('Kitti info train_raw file is saved to %s' % train_filename)
445+
print('Kitti info train file is saved to %s' % train_filename)
503446

504447
dataset.set_split(val_split)
505448
kitti_infos_val = dataset.get_infos(num_workers=workers, has_label=True, count_inside_pts=True)
506449
with open(val_filename, 'wb') as f:
507450
pickle.dump(kitti_infos_val, f)
508-
print('Kitti info val_raw file is saved to %s' % val_filename)
451+
print('Kitti info val file is saved to %s' % val_filename)
509452

510-
# with open(trainval_filename, 'wb') as f:
511-
# pickle.dump(kitti_infos_train + kitti_infos_val, f)
512-
# print('Kitti info trainval_raw file is saved to %s' % trainval_filename)
453+
with open(trainval_filename, 'wb') as f:
454+
pickle.dump(kitti_infos_train + kitti_infos_val, f)
455+
print('Kitti info trainval file is saved to %s' % trainval_filename)
513456

514-
# dataset.set_split('test')
515-
# kitti_infos_test = dataset.get_infos(num_workers=workers, has_label=False, count_inside_pts=False)
516-
# with open(test_filename, 'wb') as f:
517-
# pickle.dump(kitti_infos_test, f)
518-
# print('Kitti info test_raw file is saved to %s' % test_filename)
457+
dataset.set_split('test')
458+
kitti_infos_test = dataset.get_infos(num_workers=workers, has_label=False, count_inside_pts=False)
459+
with open(test_filename, 'wb') as f:
460+
pickle.dump(kitti_infos_test, f)
461+
print('Kitti info test file is saved to %s' % test_filename)
519462

520463
print('---------------Start create groundtruth database for data augmentation---------------')
521464
dataset.set_split(train_split)
@@ -526,7 +469,6 @@ def create_kitti_infos(dataset_cfg, class_names, data_path, save_path, workers=4
526469

527470
if __name__ == '__main__':
528471
import sys
529-
530472
if sys.argv.__len__() > 1 and sys.argv[1] == 'create_kitti_infos':
531473
import yaml
532474
from pathlib import Path
@@ -536,22 +478,6 @@ def create_kitti_infos(dataset_cfg, class_names, data_path, save_path, workers=4
536478
create_kitti_infos(
537479
dataset_cfg=dataset_cfg,
538480
class_names=['Car', 'Pedestrian', 'Cyclist'],
539-
data_path=Path('/mnt/disk2/Data/KITTI/lpcg'),
481+
data_path=ROOT_DIR / 'data' / 'kitti',
540482
save_path=ROOT_DIR / 'data' / 'kitti'
541483
)
542-
543-
# DEBUG MODE
544-
# import yaml
545-
# from pathlib import Path
546-
# from easydict import EasyDict
547-
# dataset_cfg = EasyDict(yaml.safe_load(open(sys.argv[4])))
548-
# ROOT_DIR = (Path(__file__).resolve().parent / '../../../').resolve()
549-
# create_kitti_infos(
550-
# dataset_cfg=dataset_cfg,
551-
# class_names=['Car', 'Pedestrian', 'Cyclist'],
552-
# data_path=Path('/mnt/disk2/Data/KITTI/lpcg'),
553-
# save_path=ROOT_DIR / 'data' / 'kitti'
554-
# )
555-
556-
557-
# python -m pcdet.datasets.kitti.kitti_dataset create_kitti_infos tools/cfgs/dataset_configs/kitti_dataset_raw.yaml
There was a problem loading the remainder of the diff.

0 commit comments

Comments (0)
Please sign in to comment.