Support NuScenes evaluation
sshaoshuai committed Jul 6, 2020
1 parent c7f6de3 commit 950c671
Showing 6 changed files with 345 additions and 8 deletions.
97 changes: 97 additions & 0 deletions pcdet/datasets/nuscenes/nuscenes_dataset.py
@@ -77,6 +77,7 @@ def __getitem__(self, index):
input_dict = {
'points': points,
'frame_id': Path(info['lidar_path']).stem,
'metadata': {'token': info['token']}
}

if 'gt_boxes' in info:
@@ -92,6 +93,102 @@ def __getitem__(self, index):

return data_dict

@staticmethod
def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None):
"""
Args:
batch_dict:
frame_id:
pred_dicts: list of pred_dicts
pred_boxes: (N, 7), Tensor
pred_scores: (N), Tensor
pred_labels: (N), Tensor
class_names: list of class names, used to map 1-based pred_labels to names
output_path: optional; not used by this dataset
Returns:
    annos: list of dict, one prediction dict per sample in the batch
"""
def get_template_prediction(num_samples):
ret_dict = {
'name': np.zeros(num_samples), 'score': np.zeros(num_samples),
'boxes_lidar': np.zeros([num_samples, 7])
}
return ret_dict

def generate_single_sample_dict(box_dict):
pred_scores = box_dict['pred_scores'].cpu().numpy()
pred_boxes = box_dict['pred_boxes'].cpu().numpy()
pred_labels = box_dict['pred_labels'].cpu().numpy()
pred_dict = get_template_prediction(pred_scores.shape[0])
if pred_scores.shape[0] == 0:
return pred_dict

pred_dict['name'] = np.array(class_names)[pred_labels - 1]
pred_dict['score'] = pred_scores
pred_dict['boxes_lidar'] = pred_boxes
pred_dict['pred_labels'] = pred_labels

return pred_dict

annos = []
for index, box_dict in enumerate(pred_dicts):
single_pred_dict = generate_single_sample_dict(box_dict)
single_pred_dict['frame_id'] = batch_dict['frame_id'][index]
single_pred_dict['metadata'] = batch_dict['metadata'][index]
annos.append(single_pred_dict)

return annos
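
A minimal sketch of one entry of the returned annos list; shapes and values are illustrative, not taken from the commit, and the label/name pairing assumes the default nuScenes class ordering:

    # Illustrative prediction dict for a single sample (fabricated values):
    example_anno = {
        'name': np.array(['car', 'pedestrian']),       # class_names[pred_labels - 1]
        'score': np.array([0.91, 0.73]),               # detection confidences
        'boxes_lidar': np.zeros([2, 7]),               # (x, y, z, dx, dy, dz, heading) in the LiDAR frame
        'pred_labels': np.array([1, 9]),               # 1-based indices into class_names
        'frame_id': '<lidar filename stem>',           # Path(info['lidar_path']).stem
        'metadata': {'token': '<sample_token>'},       # nuScenes sample token added in __getitem__
    }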

def evaluation(self, det_annos, class_names, **kwargs):
import json
from nuscenes.nuscenes import NuScenes
from . import nuscenes_utils
nusc = NuScenes(version=self.dataset_cfg.VERSION, dataroot=str(self.root_path), verbose=True)
nusc_annos = nuscenes_utils.transform_det_annos_to_nusc_annos(det_annos, nusc)
nusc_annos['meta'] = {
'use_camera': False,
'use_lidar': True,
'use_radar': False,
'use_map': False,
'use_external': False,
}

output_path = Path(kwargs['output_path'])
output_path.mkdir(exist_ok=True, parents=True)
res_path = str(output_path / 'results_nusc.json')
with open(res_path, 'w') as f:
json.dump(nusc_annos, f)

self.logger.info(f'The predictions of NuScenes have been saved to {res_path}')

if self.dataset_cfg.VERSION == 'v1.0-test':
return 'No ground-truth annotations for evaluation', {}

from nuscenes.eval.detection.config import config_factory
from nuscenes.eval.detection.evaluate import NuScenesEval

eval_version = 'cvpr_2019'
eval_set_map = {
'v1.0-mini': 'mini_val',
'v1.0-trainval': 'val',
'v1.0-test': 'test'
}

nusc_eval = NuScenesEval(
nusc,
config=config_factory(eval_version),
result_path=res_path,
eval_set=eval_set_map[self.dataset_cfg.VERSION],
output_dir=str(output_path),
verbose=True,
)
metrics_summary = nusc_eval.main(plot_examples=2, render_curves=False)

with open(output_path / 'metrics_summary.json', 'r') as f:
metrics = json.load(f)

result_str, result_dict = nuscenes_utils.format_nuscene_results(metrics, self.class_names, version=eval_version)
return result_str, result_dict
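
How this entry point is driven is not part of this commit; a hedged sketch of the call, assuming the standard OpenPCDet test loop and an illustrative output directory:

    # det_annos: concatenation of generate_prediction_dicts() outputs over the whole test set.
    result_str, result_dict = dataset.evaluation(
        det_annos,
        class_names=dataset.class_names,
        output_path=Path('../output/nusc_eval'),   # results_nusc.json and metrics_summary.json are written here
    )
    print(result_str)   # per-class TP errors and APs, followed by mAP and NDS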

def create_groundtruth_database(self, used_classes=None, max_sweeps=10):
import torch

238 changes: 237 additions & 1 deletion pcdet/datasets/nuscenes/nuscenes_utils.py
@@ -1,14 +1,16 @@
"""
The NuScenes data pre-processing is borrowed from
The NuScenes data pre-processing and evaluation are modified from
https://github.com/traveller59/second.pytorch and https://github.com/poodarchu/Det3D
"""

from pathlib import Path
import tqdm
import numpy as np
import operator
from functools import reduce
from nuscenes.utils.geometry_utils import transform_matrix
from pyquaternion import Quaternion
from nuscenes.utils.data_classes import Box


map_name_from_general_to_detection = {
@@ -38,6 +40,120 @@
}


cls_attr_dist = {
'barrier': {
'cycle.with_rider': 0,
'cycle.without_rider': 0,
'pedestrian.moving': 0,
'pedestrian.sitting_lying_down': 0,
'pedestrian.standing': 0,
'vehicle.moving': 0,
'vehicle.parked': 0,
'vehicle.stopped': 0,
},
'bicycle': {
'cycle.with_rider': 2791,
'cycle.without_rider': 8946,
'pedestrian.moving': 0,
'pedestrian.sitting_lying_down': 0,
'pedestrian.standing': 0,
'vehicle.moving': 0,
'vehicle.parked': 0,
'vehicle.stopped': 0,
},
'bus': {
'cycle.with_rider': 0,
'cycle.without_rider': 0,
'pedestrian.moving': 0,
'pedestrian.sitting_lying_down': 0,
'pedestrian.standing': 0,
'vehicle.moving': 9092,
'vehicle.parked': 3294,
'vehicle.stopped': 3881,
},
'car': {
'cycle.with_rider': 0,
'cycle.without_rider': 0,
'pedestrian.moving': 0,
'pedestrian.sitting_lying_down': 0,
'pedestrian.standing': 0,
'vehicle.moving': 114304,
'vehicle.parked': 330133,
'vehicle.stopped': 46898,
},
'construction_vehicle': {
'cycle.with_rider': 0,
'cycle.without_rider': 0,
'pedestrian.moving': 0,
'pedestrian.sitting_lying_down': 0,
'pedestrian.standing': 0,
'vehicle.moving': 882,
'vehicle.parked': 11549,
'vehicle.stopped': 2102,
},
'ignore': {
'cycle.with_rider': 307,
'cycle.without_rider': 73,
'pedestrian.moving': 0,
'pedestrian.sitting_lying_down': 0,
'pedestrian.standing': 0,
'vehicle.moving': 165,
'vehicle.parked': 400,
'vehicle.stopped': 102,
},
'motorcycle': {
'cycle.with_rider': 4233,
'cycle.without_rider': 8326,
'pedestrian.moving': 0,
'pedestrian.sitting_lying_down': 0,
'pedestrian.standing': 0,
'vehicle.moving': 0,
'vehicle.parked': 0,
'vehicle.stopped': 0,
},
'pedestrian': {
'cycle.with_rider': 0,
'cycle.without_rider': 0,
'pedestrian.moving': 157444,
'pedestrian.sitting_lying_down': 13939,
'pedestrian.standing': 46530,
'vehicle.moving': 0,
'vehicle.parked': 0,
'vehicle.stopped': 0,
},
'traffic_cone': {
'cycle.with_rider': 0,
'cycle.without_rider': 0,
'pedestrian.moving': 0,
'pedestrian.sitting_lying_down': 0,
'pedestrian.standing': 0,
'vehicle.moving': 0,
'vehicle.parked': 0,
'vehicle.stopped': 0,
},
'trailer': {
'cycle.with_rider': 0,
'cycle.without_rider': 0,
'pedestrian.moving': 0,
'pedestrian.sitting_lying_down': 0,
'pedestrian.standing': 0,
'vehicle.moving': 3421,
'vehicle.parked': 19224,
'vehicle.stopped': 1895,
},
'truck': {
'cycle.with_rider': 0,
'cycle.without_rider': 0,
'pedestrian.moving': 0,
'pedestrian.sitting_lying_down': 0,
'pedestrian.standing': 0,
'vehicle.moving': 21339,
'vehicle.parked': 55626,
'vehicle.stopped': 11097,
},
}
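
These per-class attribute counts are only used as a fallback in transform_det_annos_to_nusc_annos below: when neither the predicted velocity nor the class pins down an attribute, the most frequent attribute for that class is chosen. A worked example using the counts above:

    import operator
    # 'vehicle.parked' has the largest count for 'car' (330133), so it becomes the fallback attribute.
    fallback = max(cls_attr_dist['car'].items(), key=operator.itemgetter(1))[0]
    assert fallback == 'vehicle.parked'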


def get_available_scenes(nusc):
available_scenes = []
print('total scene num:', len(nusc.scene))
@@ -259,3 +375,123 @@ def fill_trainval_infos(data_path, nusc, train_scenes, val_scenes, test=False, m

progress_bar.close()
return train_nusc_infos, val_nusc_infos


def boxes_lidar_to_nusenes(det_info):
boxes3d = det_info['boxes_lidar']
scores = det_info['score']
labels = det_info['pred_labels']

box_list = []
for k in range(boxes3d.shape[0]):
quat = Quaternion(axis=[0, 0, 1], radians=boxes3d[k, 6])
velocity = (*boxes3d[k, 7:9], 0.0) if boxes3d.shape[1] == 9 else (0.0, 0.0, 0.0)
box = Box(
boxes3d[k, :3],
boxes3d[k, [4, 3, 5]], # wlh
quat, label=labels[k], score=scores[k], velocity=velocity,
)
box_list.append(box)
return box_list
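
The index reorder boxes3d[k, [4, 3, 5]] maps OpenPCDet's (dx, dy, dz) size convention onto the (w, l, h) order expected by the nuScenes devkit's Box. A minimal sketch with one fabricated 7-DoF box:

    # Fabricated (x, y, z, dx, dy, dz, heading) box; with no velocity columns it defaults to (0, 0, 0).
    det_info = {
        'boxes_lidar': np.array([[10.0, 2.0, -1.0, 4.5, 1.9, 1.6, 0.3]]),
        'score': np.array([0.9]),
        'pred_labels': np.array([1]),
    }
    boxes = boxes_lidar_to_nusenes(det_info)
    # boxes[0].wlh -> array([1.9, 4.5, 1.6]); boxes[0].orientation is a yaw-only quaternion about +z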


def lidar_nusc_box_to_global(nusc, boxes, sample_token):
s_record = nusc.get('sample', sample_token)
sample_data_token = s_record['data']['LIDAR_TOP']

sd_record = nusc.get('sample_data', sample_data_token)
cs_record = nusc.get('calibrated_sensor', sd_record['calibrated_sensor_token'])
sensor_record = nusc.get('sensor', cs_record['sensor_token'])
pose_record = nusc.get('ego_pose', sd_record['ego_pose_token'])

data_path = nusc.get_sample_data_path(sample_data_token)
box_list = []
for box in boxes:
# Move box to ego vehicle coord system
box.rotate(Quaternion(cs_record['rotation']))
box.translate(np.array(cs_record['translation']))
# Move box to global coord system
box.rotate(Quaternion(pose_record['rotation']))
box.translate(np.array(pose_record['translation']))
box_list.append(box)
return box_list
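
The two rotate/translate pairs are the usual sensor-to-ego and ego-to-global rigid transforms. An equivalent composed 4x4 matrix, as a sketch that reuses cs_record / pose_record from the function above and the transform_matrix helper already imported in this file:

    lidar_to_ego = transform_matrix(cs_record['translation'], Quaternion(cs_record['rotation']))
    ego_to_global = transform_matrix(pose_record['translation'], Quaternion(pose_record['rotation']))
    lidar_to_global = ego_to_global @ lidar_to_ego   # 4x4; apply to homogeneous box centres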


def transform_det_annos_to_nusc_annos(det_annos, nusc):
nusc_annos = {
'results': {},
'meta': None,
}

for det in det_annos:
annos = []
box_list = boxes_lidar_to_nusenes(det)
box_list = lidar_nusc_box_to_global(
nusc=nusc, boxes=box_list, sample_token=det['metadata']['token']
)

for k, box in enumerate(box_list):
name = det['name'][k]
if np.sqrt(box.velocity[0] ** 2 + box.velocity[1] ** 2) > 0.2:
if name in ['car', 'construction_vehicle', 'bus', 'truck', 'trailer']:
attr = 'vehicle.moving'
elif name in ['bicycle', 'motorcycle']:
attr = 'cycle.with_rider'
else:
attr = None
else:
if name in ['pedestrian']:
attr = 'pedestrian.standing'
elif name in ['bus']:
attr = 'vehicle.stopped'
else:
attr = None
attr = attr if attr is not None else max(
cls_attr_dist[name].items(), key=operator.itemgetter(1))[0]
nusc_anno = {
'sample_token': det['metadata']['token'],
'translation': box.center.tolist(),
'size': box.wlh.tolist(),
'rotation': box.orientation.elements.tolist(),
'velocity': box.velocity[:2].tolist(),
'detection_name': name,
'detection_score': box.score,
'attribute_name': attr
}
annos.append(nusc_anno)

nusc_annos['results'].update({det["metadata"]["token"]: annos})

return nusc_annos
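
The returned structure follows the nuScenes detection submission format; once the 'meta' block is filled in (as evaluation() does), it can be dumped straight to results_nusc.json. A sketch of one entry with illustrative values:

    # One fabricated box under nusc_annos['results']['<sample_token>']:
    example_entry = {
        'sample_token': '<sample_token>',
        'translation': [601.3, 1647.2, 1.8],   # box centre in the global frame, metres
        'size': [1.9, 4.5, 1.6],               # (w, l, h)
        'rotation': [0.99, 0.0, 0.0, 0.12],    # quaternion elements (w, x, y, z)
        'velocity': [0.0, 0.0],                # (vx, vy); zeros when the model predicts no velocity
        'detection_name': 'car',
        'detection_score': 0.9,
        'attribute_name': 'vehicle.parked',
    }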


def format_nuscene_results(metrics, class_names, version='default'):
result = '----------------Nuscene %s results-----------------\n' % version
for name in class_names:
threshs = ', '.join(list(metrics['label_aps'][name].keys()))
ap_list = list(metrics['label_aps'][name].values())

err_name = ', '.join([x.split('_')[0] for x in list(metrics['label_tp_errors'][name].keys())])
error_list = list(metrics['label_tp_errors'][name].values())

result += f'***{name} error@{err_name} | AP@{threshs}\n'
result += ', '.join(['%.2f' % x for x in error_list]) + ' | '
result += ', '.join(['%.2f' % (x * 100) for x in ap_list])
result += f" | mean AP: {metrics['mean_dist_aps'][name]}"
result += '\n'

result += '--------------average performance-------------\n'
details = {}
for key, val in metrics['tp_errors'].items():
result += '%s:\t %.4f\n' % (key, val)
details[key] = val

result += 'mAP:\t %.4f\n' % metrics['mean_ap']
result += 'NDS:\t %.4f\n' % metrics['nd_score']

details.update({
'mAP': metrics['mean_ap'],
'NDS': metrics['nd_score'],
})

return result, details
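
A hedged usage sketch, assuming metrics_summary.json was produced by NuScenesEval as in evaluation() above (paths and class list are illustrative):

    import json
    with open('output/nusc_eval/metrics_summary.json', 'r') as f:
        metrics = json.load(f)
    result_str, result_dict = format_nuscene_results(metrics, ['car', 'pedestrian'], version='cvpr_2019')
    print(result_str)   # per-class 'error@trans, scale, ... | AP@0.5, 1.0, 2.0, 4.0' rows,
                        # then the averaged TP errors, mAP and NDS
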
7 changes: 6 additions & 1 deletion tools/cfgs/dataset_configs/nuscenes_dataset.yaml
@@ -5,9 +5,14 @@ VERSION: 'v1.0-trainval'
MAX_SWEEPS: 10
PRED_VELOCITY: False

DATA_SPLIT: {
'train': train,
'test': val
}

INFO_PATH: {
'train': [nuscenes_infos_10sweeps_train.pkl],
'test': [nuscenes_infos_10sweeps_train.pkl],
'test': [nuscenes_infos_10sweeps_val.pkl],
}

POINT_CLOUD_RANGE: [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
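
The net effect of this config change is that testing now reads the val split and its info file instead of the train ones. A minimal sketch for inspecting that (assumes PyYAML rather than the project's own config loader):

    import yaml
    with open('tools/cfgs/dataset_configs/nuscenes_dataset.yaml') as f:
        cfg = yaml.safe_load(f)
    print(cfg['DATA_SPLIT']['test'])   # -> 'val'
    print(cfg['INFO_PATH']['test'])    # -> ['nuscenes_infos_10sweeps_val.pkl']
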
2 changes: 1 addition & 1 deletion tools/cfgs/nuscenes_models/second.yaml
@@ -165,7 +165,7 @@ MODEL:
NMS_CONFIG:
MULTI_CLASSES_NMS: False
NMS_TYPE: nms_gpu
NMS_THRESH: 0.01
NMS_THRESH: 0.2
NMS_PRE_MAXSIZE: 4096
NMS_POST_MAXSIZE: 500

