From 27e4829327933c6a608d0ca22580354cd9598e9b Mon Sep 17 00:00:00 2001
From: Alex Lang
Date: Mon, 13 May 2019 15:38:16 +0800
Subject: [PATCH] Add deserialize to DetectionMetrics (#146)
* add deserialize to DetectionMetrics
* Formatting
---
.../nuscenes/eval/detection/data_classes.py | 38 ++++++++-
.../eval/detection/tests/test_data_classes.py | 85 ++++++++++++++-----
2 files changed, 101 insertions(+), 22 deletions(-)
diff --git a/python-sdk/nuscenes/eval/detection/data_classes.py b/python-sdk/nuscenes/eval/detection/data_classes.py
index f5054903b..760a82f51 100644
--- a/python-sdk/nuscenes/eval/detection/data_classes.py
+++ b/python-sdk/nuscenes/eval/detection/data_classes.py
@@ -38,6 +38,12 @@ def __init__(self,
self.class_names = self.class_range.keys()
+ def __eq__(self, other):
+ eq = True
+ for key in self.serialize().keys():
+ eq = eq and np.array_equal(getattr(self, key), getattr(other, key))
+ return eq
+
def serialize(self) -> dict:
""" Serialize instance into json-friendly format. """
return {
@@ -460,4 +466,34 @@ def serialize(self):
'tp_errors': self.tp_errors,
'tp_scores': self.tp_scores,
'nd_score': self.nd_score,
- 'eval_time': self.eval_time}
+ 'eval_time': self.eval_time,
+ 'cfg': self.cfg.serialize()}
+
+ @classmethod
+ def deserialize(cls, content):
+ """ Initialize from serialized dictionary. """
+
+ cfg = DetectionConfig.deserialize(content['cfg'])
+
+ metrics = cls(cfg=cfg)
+ metrics.add_runtime(content['eval_time'])
+
+ for detection_name, label_aps in content['label_aps'].items():
+ for dist_th, ap in label_aps.items():
+ metrics.add_label_ap(detection_name=detection_name, dist_th=float(dist_th), ap=float(ap))
+
+ for detection_name, label_tps in content['label_tp_errors'].items():
+ for metric_name, tp in label_tps.items():
+ metrics.add_label_tp(detection_name=detection_name, metric_name=metric_name, tp=float(tp))
+
+ return metrics
+
+ def __eq__(self, other):
+
+ eq = True
+ eq = eq and self._label_aps == other._label_aps
+ eq = eq and self._label_tp_errors == other._label_tp_errors
+ eq = eq and self.eval_time == other.eval_time
+ eq = eq and self.cfg == other.cfg
+
+ return eq
diff --git a/python-sdk/nuscenes/eval/detection/tests/test_data_classes.py b/python-sdk/nuscenes/eval/detection/tests/test_data_classes.py
index efbf1ba4b..f09a453c4 100644
--- a/python-sdk/nuscenes/eval/detection/tests/test_data_classes.py
+++ b/python-sdk/nuscenes/eval/detection/tests/test_data_classes.py
@@ -1,17 +1,40 @@
# nuScenes dev-kit.
-# Code written by Oscar Beijbom, 2019.
+# Code written by Oscar Beijbom and Alex Lang, 2019.
# Licensed under the Creative Commons [see licence.txt]
import json
import unittest
-from nuscenes.eval.detection.data_classes import MetricData, EvalBox, EvalBoxes, MetricDataList
+from nuscenes.eval.detection.constants import TP_METRICS
+from nuscenes.eval.detection.data_classes import MetricData, EvalBox, EvalBoxes, MetricDataList, DetectionConfig, \
+ DetectionMetrics
+
+
+class TestEvalBox(unittest.TestCase):
+
+ def test_serialization(self):
+ """ Test that instance serialization protocol works with json encoding. """
+ box = EvalBox()
+ recovered = EvalBox.deserialize(json.loads(json.dumps(box.serialize())))
+ self.assertEqual(box, recovered)
+
+
+class TestEvalBoxes(unittest.TestCase):
+
+ def test_serialization(self):
+ """ Test that instance serialization protocol works with json encoding. """
+ boxes = EvalBoxes()
+ for i in range(10):
+ boxes.add_boxes(str(i), [EvalBox(), EvalBox(), EvalBox()])
+
+ recovered = EvalBoxes.deserialize(json.loads(json.dumps(boxes.serialize())))
+ self.assertEqual(boxes, recovered)
class TestMetricData(unittest.TestCase):
def test_serialization(self):
- """ test that instance serialization protocol works with json encoding """
+ """ Test that instance serialization protocol works with json encoding. """
md = MetricData.random_md()
recovered = MetricData.deserialize(json.loads(json.dumps(md.serialize())))
self.assertEqual(md, recovered)
@@ -20,7 +43,7 @@ def test_serialization(self):
class TestMetricDataList(unittest.TestCase):
def test_serialization(self):
- """ test that instance serialization protocol works with json encoding """
+ """ Test that instance serialization protocol works with json encoding. """
mdl = MetricDataList()
for i in range(10):
mdl.set('name', 0.1, MetricData.random_md())
@@ -28,25 +51,45 @@ def test_serialization(self):
self.assertEqual(mdl, recovered)
-class TestEvalBox(unittest.TestCase):
+class TestDetectionMetrics(unittest.TestCase):
def test_serialization(self):
- """ test that instance serialization protocol works with json encoding """
- box = EvalBox()
- recovered = EvalBox.deserialize(json.loads(json.dumps(box.serialize())))
- self.assertEqual(box, recovered)
-
-
-class TestEvalBoxes(unittest.TestCase):
-
- def test_serialization(self):
- """ test that instance serialization protocol works with json encoding """
- boxes = EvalBoxes()
- for i in range(10):
- boxes.add_boxes(str(i), [EvalBox(), EvalBox(), EvalBox()])
-
- recovered = EvalBoxes.deserialize(json.loads(json.dumps(boxes.serialize())))
- self.assertEqual(boxes, recovered)
+ """ Test that instance serialization protocol works with json encoding. """
+
+ cfg = {
+ 'class_range': {
+ 'car': 1.0,
+ 'truck': 1.0,
+ 'bus': 1.0,
+ 'trailer': 1.0,
+ 'construction_vehicle': 1.0,
+ 'pedestrian': 1.0,
+ 'motorcycle': 1.0,
+ 'bicycle': 1.0,
+ 'traffic_cone': 1.0,
+ 'barrier': 1.0
+ },
+ 'dist_fcn': 'distance',
+ 'dist_ths': [0.0, 1.0],
+ 'dist_th_tp': 1.0,
+ 'min_recall': 0.0,
+ 'min_precision': 0.0,
+ 'max_boxes_per_sample': 1,
+ 'mean_ap_weight': 1.0
+ }
+ detect_config = DetectionConfig.deserialize(cfg)
+
+ metrics = DetectionMetrics(cfg=detect_config)
+
+ for i, name in enumerate(cfg['class_range'].keys()):
+ metrics.add_label_ap(name, 1.0, float(i))
+ for j, tp_name in enumerate(TP_METRICS):
+ metrics.add_label_tp(name, tp_name, float(j))
+
+ serialized = json.dumps(metrics.serialize())
+ deserialized = DetectionMetrics.deserialize(json.loads(serialized))
+
+ self.assertEqual(metrics, deserialized)
if __name__ == '__main__':