From 27e4829327933c6a608d0ca22580354cd9598e9b Mon Sep 17 00:00:00 2001 From: Alex Lang Date: Mon, 13 May 2019 15:38:16 +0800 Subject: [PATCH] Add deserialize to DetectionMetrics (#146) * add deserialize to DetectionMetrics * Formatting --- .../nuscenes/eval/detection/data_classes.py | 38 ++++++++- .../eval/detection/tests/test_data_classes.py | 85 ++++++++++++++----- 2 files changed, 101 insertions(+), 22 deletions(-) diff --git a/python-sdk/nuscenes/eval/detection/data_classes.py b/python-sdk/nuscenes/eval/detection/data_classes.py index f5054903b..760a82f51 100644 --- a/python-sdk/nuscenes/eval/detection/data_classes.py +++ b/python-sdk/nuscenes/eval/detection/data_classes.py @@ -38,6 +38,12 @@ def __init__(self, self.class_names = self.class_range.keys() + def __eq__(self, other): + eq = True + for key in self.serialize().keys(): + eq = eq and np.array_equal(getattr(self, key), getattr(other, key)) + return eq + def serialize(self) -> dict: """ Serialize instance into json-friendly format. """ return { @@ -460,4 +466,34 @@ def serialize(self): 'tp_errors': self.tp_errors, 'tp_scores': self.tp_scores, 'nd_score': self.nd_score, - 'eval_time': self.eval_time} + 'eval_time': self.eval_time, + 'cfg': self.cfg.serialize()} + + @classmethod + def deserialize(cls, content): + """ Initialize from serialized dictionary. """ + + cfg = DetectionConfig.deserialize(content['cfg']) + + metrics = cls(cfg=cfg) + metrics.add_runtime(content['eval_time']) + + for detection_name, label_aps in content['label_aps'].items(): + for dist_th, ap in label_aps.items(): + metrics.add_label_ap(detection_name=detection_name, dist_th=float(dist_th), ap=float(ap)) + + for detection_name, label_tps in content['label_tp_errors'].items(): + for metric_name, tp in label_tps.items(): + metrics.add_label_tp(detection_name=detection_name, metric_name=metric_name, tp=float(tp)) + + return metrics + + def __eq__(self, other): + + eq = True + eq = eq and self._label_aps == other._label_aps + eq = eq and self._label_tp_errors == other._label_tp_errors + eq = eq and self.eval_time == other.eval_time + eq = eq and self.cfg == other.cfg + + return eq diff --git a/python-sdk/nuscenes/eval/detection/tests/test_data_classes.py b/python-sdk/nuscenes/eval/detection/tests/test_data_classes.py index efbf1ba4b..f09a453c4 100644 --- a/python-sdk/nuscenes/eval/detection/tests/test_data_classes.py +++ b/python-sdk/nuscenes/eval/detection/tests/test_data_classes.py @@ -1,17 +1,40 @@ # nuScenes dev-kit. -# Code written by Oscar Beijbom, 2019. +# Code written by Oscar Beijbom and Alex Lang, 2019. # Licensed under the Creative Commons [see licence.txt] import json import unittest -from nuscenes.eval.detection.data_classes import MetricData, EvalBox, EvalBoxes, MetricDataList +from nuscenes.eval.detection.constants import TP_METRICS +from nuscenes.eval.detection.data_classes import MetricData, EvalBox, EvalBoxes, MetricDataList, DetectionConfig, \ + DetectionMetrics + + +class TestEvalBox(unittest.TestCase): + + def test_serialization(self): + """ Test that instance serialization protocol works with json encoding. """ + box = EvalBox() + recovered = EvalBox.deserialize(json.loads(json.dumps(box.serialize()))) + self.assertEqual(box, recovered) + + +class TestEvalBoxes(unittest.TestCase): + + def test_serialization(self): + """ Test that instance serialization protocol works with json encoding. """ + boxes = EvalBoxes() + for i in range(10): + boxes.add_boxes(str(i), [EvalBox(), EvalBox(), EvalBox()]) + + recovered = EvalBoxes.deserialize(json.loads(json.dumps(boxes.serialize()))) + self.assertEqual(boxes, recovered) class TestMetricData(unittest.TestCase): def test_serialization(self): - """ test that instance serialization protocol works with json encoding """ + """ Test that instance serialization protocol works with json encoding. """ md = MetricData.random_md() recovered = MetricData.deserialize(json.loads(json.dumps(md.serialize()))) self.assertEqual(md, recovered) @@ -20,7 +43,7 @@ def test_serialization(self): class TestMetricDataList(unittest.TestCase): def test_serialization(self): - """ test that instance serialization protocol works with json encoding """ + """ Test that instance serialization protocol works with json encoding. """ mdl = MetricDataList() for i in range(10): mdl.set('name', 0.1, MetricData.random_md()) @@ -28,25 +51,45 @@ def test_serialization(self): self.assertEqual(mdl, recovered) -class TestEvalBox(unittest.TestCase): +class TestDetectionMetrics(unittest.TestCase): def test_serialization(self): - """ test that instance serialization protocol works with json encoding """ - box = EvalBox() - recovered = EvalBox.deserialize(json.loads(json.dumps(box.serialize()))) - self.assertEqual(box, recovered) - - -class TestEvalBoxes(unittest.TestCase): - - def test_serialization(self): - """ test that instance serialization protocol works with json encoding """ - boxes = EvalBoxes() - for i in range(10): - boxes.add_boxes(str(i), [EvalBox(), EvalBox(), EvalBox()]) - - recovered = EvalBoxes.deserialize(json.loads(json.dumps(boxes.serialize()))) - self.assertEqual(boxes, recovered) + """ Test that instance serialization protocol works with json encoding. """ + + cfg = { + 'class_range': { + 'car': 1.0, + 'truck': 1.0, + 'bus': 1.0, + 'trailer': 1.0, + 'construction_vehicle': 1.0, + 'pedestrian': 1.0, + 'motorcycle': 1.0, + 'bicycle': 1.0, + 'traffic_cone': 1.0, + 'barrier': 1.0 + }, + 'dist_fcn': 'distance', + 'dist_ths': [0.0, 1.0], + 'dist_th_tp': 1.0, + 'min_recall': 0.0, + 'min_precision': 0.0, + 'max_boxes_per_sample': 1, + 'mean_ap_weight': 1.0 + } + detect_config = DetectionConfig.deserialize(cfg) + + metrics = DetectionMetrics(cfg=detect_config) + + for i, name in enumerate(cfg['class_range'].keys()): + metrics.add_label_ap(name, 1.0, float(i)) + for j, tp_name in enumerate(TP_METRICS): + metrics.add_label_tp(name, tp_name, float(j)) + + serialized = json.dumps(metrics.serialize()) + deserialized = DetectionMetrics.deserialize(json.loads(serialized)) + + self.assertEqual(metrics, deserialized) if __name__ == '__main__':