Skip to content

Commit 81763e7

Browse files
authored
Fixed open-mmlab#1323 and some minor issues (open-mmlab#1325)
1 parent 4d8624f commit 81763e7

File tree

1 file changed

+22
-30
lines changed

1 file changed

+22
-30
lines changed

pcdet/datasets/argo2/argo2_dataset.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
import pandas as pd
1313

1414
from ..dataset import DatasetTemplate
15-
from .argo2_utils.so3 import yaw_to_quat
15+
from .argo2_utils.so3 import yaw_to_quat, quat_to_yaw
1616
from .argo2_utils.constants import LABEL_ATTR
1717

1818

1919
def process_single_segment(segment_path, split, info_list, ts2idx, output_dir, save_bin):
2020
test_mode = 'test' in split
2121
if not test_mode:
22-
segment_anno = read_feather(osp.join(segment_path, 'annotations.feather'))
22+
segment_anno = read_feather(Path(osp.join(segment_path, 'annotations.feather')))
2323
segname = segment_path.split('/')[-1]
2424

2525
frame_path_list = os.listdir(osp.join(segment_path, 'sensors/lidar/'))
@@ -70,17 +70,7 @@ def process_and_save_frame(frame_path, frame_anno, ts2idx, segname, output_dir,
7070
cuboid_params = torch.from_numpy(cuboid_params)
7171
yaw = quat_to_yaw(cuboid_params[:, -4:])
7272
xyz = cuboid_params[:, :3]
73-
wlh = cuboid_params[:, [4, 3, 5]]
74-
75-
yaw = -yaw - 0.5 * np.pi
76-
77-
while (yaw < -np.pi).any():
78-
yaw[yaw < -np.pi] += 2 * np.pi
79-
80-
while (yaw > np.pi).any():
81-
yaw[yaw > np.pi] -= 2 * np.pi
82-
83-
# bbox = torch.cat([xyz, wlh, yaw.unsqueeze(1)], dim=1).numpy()
73+
lwh = cuboid_params[:, [3, 4, 5]]
8474

8575
cat = frame_anno['category'].to_numpy().tolist()
8676
cat = [c.lower().capitalize() for c in cat]
@@ -93,7 +83,7 @@ def process_and_save_frame(frame_path, frame_anno, ts2idx, segname, output_dir,
9383
annos['truncated'] = np.zeros(num_obj, dtype=np.float64)
9484
annos['occluded'] = np.zeros(num_obj, dtype=np.int64)
9585
annos['alpha'] = -10 * np.ones(num_obj, dtype=np.float64)
96-
annos['dimensions'] = wlh.numpy().astype(np.float64)
86+
annos['dimensions'] = lwh.numpy().astype(np.float64)
9787
annos['location'] = xyz.numpy().astype(np.float64)
9888
annos['rotation_y'] = yaw.numpy().astype(np.float64)
9989
annos['index'] = np.arange(num_obj, dtype=np.int32)
@@ -111,7 +101,7 @@ def process_and_save_frame(frame_path, frame_anno, ts2idx, segname, output_dir,
111101

112102

113103
def save_point_cloud(frame_path, save_path):
114-
lidar = read_feather(frame_path)
104+
lidar = read_feather(Path(frame_path))
115105
lidar = lidar.loc[:, ['x', 'y', 'z', 'intensity']].to_numpy().astype(np.float32)
116106
lidar.tofile(save_path)
117107

@@ -375,9 +365,9 @@ def format_results(self,
375365
assert len(self.argo2_infos) == len(outputs)
376366
num_samples = len(outputs)
377367
print('\nGot {} samples'.format(num_samples))
378-
368+
379369
serialized_dts_list = []
380-
370+
381371
print('\nConvert predictions to Argoverse 2 format')
382372
for i in range(num_samples):
383373
out_i = outputs[i]
@@ -394,7 +384,7 @@ def format_results(self,
394384
serialized_dts["timestamp_ns"] = int(ts)
395385
serialized_dts["category"] = category
396386
serialized_dts_list.append(serialized_dts)
397-
387+
398388
dts = (
399389
pd.concat(serialized_dts_list)
400390
.set_index(["log_id", "timestamp_ns"])
@@ -411,19 +401,13 @@ def format_results(self,
411401

412402
dts = dts.set_index(["log_id", "timestamp_ns"]).sort_index()
413403

414-
return dts
415-
404+
return dts
405+
416406
def lidar_box_to_argo2(self, boxes):
417407
boxes = torch.Tensor(boxes)
418408
cnt_xyz = boxes[:, :3]
419-
lwh = boxes[:, [4, 3, 5]]
420-
yaw = boxes[:, 6] #- np.pi/2
421-
422-
yaw = -yaw - 0.5 * np.pi
423-
while (yaw < -np.pi).any():
424-
yaw[yaw < -np.pi] += 2 * np.pi
425-
while (yaw > np.pi).any():
426-
yaw[yaw > np.pi] -= 2 * np.pi
409+
lwh = boxes[:, [3, 4, 5]]
410+
yaw = boxes[:, 6]
427411

428412
quat = yaw_to_quat(yaw)
429413
argo_cuboid = torch.cat([cnt_xyz, lwh, quat], dim=1)
@@ -470,7 +454,7 @@ def evaluation(self,
470454
dts = self.format_results(results, class_names, pklfile_prefix, submission_prefix)
471455
argo2_root = self.root_path
472456
val_anno_path = osp.join(argo2_root, 'val_anno.feather')
473-
gts = read_feather(val_anno_path)
457+
gts = read_feather(Path(val_anno_path))
474458
gts = gts.set_index(["log_id", "timestamp_ns"]).sort_values("category")
475459

476460
valid_uuids_gts = gts.index.tolist()
@@ -508,6 +492,13 @@ def parse_config():
508492
args = parser.parse_args()
509493
return args
510494

495+
def main(seg_path_list, seg_split_list, info_list, ts2idx, output_dir, save_bin, token, num_process):
496+
for seg_i, seg_path in enumerate(seg_path_list):
497+
if seg_i % num_process != token:
498+
continue
499+
print(f'processing segment: {seg_i}/{len(seg_path_list)}')
500+
split = seg_split_list[seg_i]
501+
process_single_segment(seg_path, split, info_list, ts2idx, output_dir, save_bin)
511502

512503
if __name__ == '__main__':
513504
args = parse_config()
@@ -559,4 +550,5 @@ def parse_config():
559550
seg_anno_list.append(seg_anno)
560551

561552
gts = pd.concat(seg_anno_list).reset_index()
562-
gts.to_feather(val_seg_path_list)
553+
gts.to_feather(save_feather_path)
554+

0 commit comments

Comments
 (0)