12
12
import pandas as pd
13
13
14
14
from ..dataset import DatasetTemplate
15
- from .argo2_utils .so3 import yaw_to_quat
15
+ from .argo2_utils .so3 import yaw_to_quat , quat_to_yaw
16
16
from .argo2_utils .constants import LABEL_ATTR
17
17
18
18
19
19
def process_single_segment (segment_path , split , info_list , ts2idx , output_dir , save_bin ):
20
20
test_mode = 'test' in split
21
21
if not test_mode :
22
- segment_anno = read_feather (osp .join (segment_path , 'annotations.feather' ))
22
+ segment_anno = read_feather (Path ( osp .join (segment_path , 'annotations.feather' ) ))
23
23
segname = segment_path .split ('/' )[- 1 ]
24
24
25
25
frame_path_list = os .listdir (osp .join (segment_path , 'sensors/lidar/' ))
@@ -70,17 +70,7 @@ def process_and_save_frame(frame_path, frame_anno, ts2idx, segname, output_dir,
70
70
cuboid_params = torch .from_numpy (cuboid_params )
71
71
yaw = quat_to_yaw (cuboid_params [:, - 4 :])
72
72
xyz = cuboid_params [:, :3 ]
73
- wlh = cuboid_params [:, [4 , 3 , 5 ]]
74
-
75
- yaw = - yaw - 0.5 * np .pi
76
-
77
- while (yaw < - np .pi ).any ():
78
- yaw [yaw < - np .pi ] += 2 * np .pi
79
-
80
- while (yaw > np .pi ).any ():
81
- yaw [yaw > np .pi ] -= 2 * np .pi
82
-
83
- # bbox = torch.cat([xyz, wlh, yaw.unsqueeze(1)], dim=1).numpy()
73
+ lwh = cuboid_params [:, [3 , 4 , 5 ]]
84
74
85
75
cat = frame_anno ['category' ].to_numpy ().tolist ()
86
76
cat = [c .lower ().capitalize () for c in cat ]
@@ -93,7 +83,7 @@ def process_and_save_frame(frame_path, frame_anno, ts2idx, segname, output_dir,
93
83
annos ['truncated' ] = np .zeros (num_obj , dtype = np .float64 )
94
84
annos ['occluded' ] = np .zeros (num_obj , dtype = np .int64 )
95
85
annos ['alpha' ] = - 10 * np .ones (num_obj , dtype = np .float64 )
96
- annos ['dimensions' ] = wlh .numpy ().astype (np .float64 )
86
+ annos ['dimensions' ] = lwh .numpy ().astype (np .float64 )
97
87
annos ['location' ] = xyz .numpy ().astype (np .float64 )
98
88
annos ['rotation_y' ] = yaw .numpy ().astype (np .float64 )
99
89
annos ['index' ] = np .arange (num_obj , dtype = np .int32 )
@@ -111,7 +101,7 @@ def process_and_save_frame(frame_path, frame_anno, ts2idx, segname, output_dir,
111
101
112
102
113
103
def save_point_cloud (frame_path , save_path ):
114
- lidar = read_feather (frame_path )
104
+ lidar = read_feather (Path ( frame_path ) )
115
105
lidar = lidar .loc [:, ['x' , 'y' , 'z' , 'intensity' ]].to_numpy ().astype (np .float32 )
116
106
lidar .tofile (save_path )
117
107
@@ -375,9 +365,9 @@ def format_results(self,
375
365
assert len (self .argo2_infos ) == len (outputs )
376
366
num_samples = len (outputs )
377
367
print ('\n Got {} samples' .format (num_samples ))
378
-
368
+
379
369
serialized_dts_list = []
380
-
370
+
381
371
print ('\n Convert predictions to Argoverse 2 format' )
382
372
for i in range (num_samples ):
383
373
out_i = outputs [i ]
@@ -394,7 +384,7 @@ def format_results(self,
394
384
serialized_dts ["timestamp_ns" ] = int (ts )
395
385
serialized_dts ["category" ] = category
396
386
serialized_dts_list .append (serialized_dts )
397
-
387
+
398
388
dts = (
399
389
pd .concat (serialized_dts_list )
400
390
.set_index (["log_id" , "timestamp_ns" ])
@@ -411,19 +401,13 @@ def format_results(self,
411
401
412
402
dts = dts .set_index (["log_id" , "timestamp_ns" ]).sort_index ()
413
403
414
- return dts
415
-
404
+ return dts
405
+
416
406
def lidar_box_to_argo2 (self , boxes ):
417
407
boxes = torch .Tensor (boxes )
418
408
cnt_xyz = boxes [:, :3 ]
419
- lwh = boxes [:, [4 , 3 , 5 ]]
420
- yaw = boxes [:, 6 ] #- np.pi/2
421
-
422
- yaw = - yaw - 0.5 * np .pi
423
- while (yaw < - np .pi ).any ():
424
- yaw [yaw < - np .pi ] += 2 * np .pi
425
- while (yaw > np .pi ).any ():
426
- yaw [yaw > np .pi ] -= 2 * np .pi
409
+ lwh = boxes [:, [3 , 4 , 5 ]]
410
+ yaw = boxes [:, 6 ]
427
411
428
412
quat = yaw_to_quat (yaw )
429
413
argo_cuboid = torch .cat ([cnt_xyz , lwh , quat ], dim = 1 )
@@ -470,7 +454,7 @@ def evaluation(self,
470
454
dts = self .format_results (results , class_names , pklfile_prefix , submission_prefix )
471
455
argo2_root = self .root_path
472
456
val_anno_path = osp .join (argo2_root , 'val_anno.feather' )
473
- gts = read_feather (val_anno_path )
457
+ gts = read_feather (Path ( val_anno_path ) )
474
458
gts = gts .set_index (["log_id" , "timestamp_ns" ]).sort_values ("category" )
475
459
476
460
valid_uuids_gts = gts .index .tolist ()
@@ -508,6 +492,13 @@ def parse_config():
508
492
args = parser .parse_args ()
509
493
return args
510
494
495
+ def main (seg_path_list , seg_split_list , info_list , ts2idx , output_dir , save_bin , token , num_process ):
496
+ for seg_i , seg_path in enumerate (seg_path_list ):
497
+ if seg_i % num_process != token :
498
+ continue
499
+ print (f'processing segment: { seg_i } /{ len (seg_path_list )} ' )
500
+ split = seg_split_list [seg_i ]
501
+ process_single_segment (seg_path , split , info_list , ts2idx , output_dir , save_bin )
511
502
512
503
if __name__ == '__main__' :
513
504
args = parse_config ()
@@ -559,4 +550,5 @@ def parse_config():
559
550
seg_anno_list .append (seg_anno )
560
551
561
552
gts = pd .concat (seg_anno_list ).reset_index ()
562
- gts .to_feather (val_seg_path_list )
553
+ gts .to_feather (save_feather_path )
554
+
0 commit comments