Skip to content

Commit

Permalink
Add deserialize to DetectionMetrics (nutonomy#146)
Browse files Browse the repository at this point in the history
* add deserialize to DetectionMetrics

* Formatting
  • Loading branch information
Alex-nutonomy authored May 13, 2019
1 parent 6761b7f commit 27e4829
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 22 deletions.
38 changes: 37 additions & 1 deletion python-sdk/nuscenes/eval/detection/data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def __init__(self,

self.class_names = self.class_range.keys()

def __eq__(self, other):
eq = True
for key in self.serialize().keys():
eq = eq and np.array_equal(getattr(self, key), getattr(other, key))
return eq

def serialize(self) -> dict:
""" Serialize instance into json-friendly format. """
return {
Expand Down Expand Up @@ -460,4 +466,34 @@ def serialize(self):
'tp_errors': self.tp_errors,
'tp_scores': self.tp_scores,
'nd_score': self.nd_score,
'eval_time': self.eval_time}
'eval_time': self.eval_time,
'cfg': self.cfg.serialize()}

@classmethod
def deserialize(cls, content):
""" Initialize from serialized dictionary. """

cfg = DetectionConfig.deserialize(content['cfg'])

metrics = cls(cfg=cfg)
metrics.add_runtime(content['eval_time'])

for detection_name, label_aps in content['label_aps'].items():
for dist_th, ap in label_aps.items():
metrics.add_label_ap(detection_name=detection_name, dist_th=float(dist_th), ap=float(ap))

for detection_name, label_tps in content['label_tp_errors'].items():
for metric_name, tp in label_tps.items():
metrics.add_label_tp(detection_name=detection_name, metric_name=metric_name, tp=float(tp))

return metrics

def __eq__(self, other):

eq = True
eq = eq and self._label_aps == other._label_aps
eq = eq and self._label_tp_errors == other._label_tp_errors
eq = eq and self.eval_time == other.eval_time
eq = eq and self.cfg == other.cfg

return eq
85 changes: 64 additions & 21 deletions python-sdk/nuscenes/eval/detection/tests/test_data_classes.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,40 @@
# nuScenes dev-kit.
# Code written by Oscar Beijbom, 2019.
# Code written by Oscar Beijbom and Alex Lang, 2019.
# Licensed under the Creative Commons [see licence.txt]

import json
import unittest

from nuscenes.eval.detection.data_classes import MetricData, EvalBox, EvalBoxes, MetricDataList
from nuscenes.eval.detection.constants import TP_METRICS
from nuscenes.eval.detection.data_classes import MetricData, EvalBox, EvalBoxes, MetricDataList, DetectionConfig, \
DetectionMetrics


class TestEvalBox(unittest.TestCase):

def test_serialization(self):
""" Test that instance serialization protocol works with json encoding. """
box = EvalBox()
recovered = EvalBox.deserialize(json.loads(json.dumps(box.serialize())))
self.assertEqual(box, recovered)


class TestEvalBoxes(unittest.TestCase):

def test_serialization(self):
""" Test that instance serialization protocol works with json encoding. """
boxes = EvalBoxes()
for i in range(10):
boxes.add_boxes(str(i), [EvalBox(), EvalBox(), EvalBox()])

recovered = EvalBoxes.deserialize(json.loads(json.dumps(boxes.serialize())))
self.assertEqual(boxes, recovered)


class TestMetricData(unittest.TestCase):

def test_serialization(self):
""" test that instance serialization protocol works with json encoding """
""" Test that instance serialization protocol works with json encoding. """
md = MetricData.random_md()
recovered = MetricData.deserialize(json.loads(json.dumps(md.serialize())))
self.assertEqual(md, recovered)
Expand All @@ -20,33 +43,53 @@ def test_serialization(self):
class TestMetricDataList(unittest.TestCase):

def test_serialization(self):
""" test that instance serialization protocol works with json encoding """
""" Test that instance serialization protocol works with json encoding. """
mdl = MetricDataList()
for i in range(10):
mdl.set('name', 0.1, MetricData.random_md())
recovered = MetricDataList.deserialize(json.loads(json.dumps(mdl.serialize())))
self.assertEqual(mdl, recovered)


class TestEvalBox(unittest.TestCase):
class TestDetectionMetrics(unittest.TestCase):

def test_serialization(self):
""" test that instance serialization protocol works with json encoding """
box = EvalBox()
recovered = EvalBox.deserialize(json.loads(json.dumps(box.serialize())))
self.assertEqual(box, recovered)


class TestEvalBoxes(unittest.TestCase):

def test_serialization(self):
""" test that instance serialization protocol works with json encoding """
boxes = EvalBoxes()
for i in range(10):
boxes.add_boxes(str(i), [EvalBox(), EvalBox(), EvalBox()])

recovered = EvalBoxes.deserialize(json.loads(json.dumps(boxes.serialize())))
self.assertEqual(boxes, recovered)
""" Test that instance serialization protocol works with json encoding. """

cfg = {
'class_range': {
'car': 1.0,
'truck': 1.0,
'bus': 1.0,
'trailer': 1.0,
'construction_vehicle': 1.0,
'pedestrian': 1.0,
'motorcycle': 1.0,
'bicycle': 1.0,
'traffic_cone': 1.0,
'barrier': 1.0
},
'dist_fcn': 'distance',
'dist_ths': [0.0, 1.0],
'dist_th_tp': 1.0,
'min_recall': 0.0,
'min_precision': 0.0,
'max_boxes_per_sample': 1,
'mean_ap_weight': 1.0
}
detect_config = DetectionConfig.deserialize(cfg)

metrics = DetectionMetrics(cfg=detect_config)

for i, name in enumerate(cfg['class_range'].keys()):
metrics.add_label_ap(name, 1.0, float(i))
for j, tp_name in enumerate(TP_METRICS):
metrics.add_label_tp(name, tp_name, float(j))

serialized = json.dumps(metrics.serialize())
deserialized = DetectionMetrics.deserialize(json.loads(serialized))

self.assertEqual(metrics, deserialized)


if __name__ == '__main__':
Expand Down

0 comments on commit 27e4829

Please sign in to comment.