Commit 68e06c5: refactor

xiaoli committed Sep 16, 2021
1 parent cf2f128 commit 68e06c5
Showing 6 changed files with 62 additions and 42 deletions.
27 changes: 27 additions & 0 deletions python-sdk/nuscenes/eval/tracking/configs/tracking_nips_2019.json
@@ -1,4 +1,31 @@
 {
+  "tracking_names": [
+    "bicycle",
+    "bus",
+    "car",
+    "motorcycle",
+    "pedestrian",
+    "trailer",
+    "truck"
+  ],
+  "pretty_tracking_names": {
+    "bicycle": "Bicycle",
+    "bus": "Bus",
+    "car": "Car",
+    "motorcycle": "Motorcycle",
+    "pedestrian": "Pedestrian",
+    "trailer": "Trailer",
+    "truck": "Truck"
+  },
+  "tracking_colors": {
+    "bicycle": "C9",
+    "bus": "C2",
+    "car": "C0",
+    "motorcycle": "C6",
+    "pedestrian": "C5",
+    "trailer": "C3",
+    "truck": "C1"
+  },
   "class_range": {
     "car": 50,
     "truck": 50,
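
The three new blocks can be read straight from this file with the standard library; a minimal sketch (the path is the one shown in this commit, relative to the repository root):

    import json

    # Path of the config file touched by this commit; adjust to your checkout.
    path = 'python-sdk/nuscenes/eval/tracking/configs/tracking_nips_2019.json'

    with open(path) as f:
        cfg = json.load(f)

    print(cfg['tracking_names'])                # ['bicycle', 'bus', 'car', ...]
    print(cfg['pretty_tracking_names']['car'])  # 'Car'
    print(cfg['tracking_colors']['car'])        # 'C0', a matplotlib color-cycle name
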
21 changes: 0 additions & 21 deletions python-sdk/nuscenes/eval/tracking/constants.py
@@ -1,33 +1,12 @@
 # nuScenes dev-kit.
 # Code written by Holger Caesar, Caglayan Dicle and Oscar Beijbom, 2019.
 
-TRACKING_NAMES = ['bicycle', 'bus', 'car', 'motorcycle', 'pedestrian', 'trailer', 'truck']
-
 AMOT_METRICS = ['amota', 'amotp']
 INTERNAL_METRICS = ['recall', 'motar', 'gt']
 LEGACY_METRICS = ['mota', 'motp', 'mt', 'ml', 'faf', 'tp', 'fp', 'fn', 'ids', 'frag', 'tid', 'lgd']
 TRACKING_METRICS = [*AMOT_METRICS, *INTERNAL_METRICS, *LEGACY_METRICS]
 
-PRETTY_TRACKING_NAMES = {
-    'bicycle': 'Bicycle',
-    'bus': 'Bus',
-    'car': 'Car',
-    'motorcycle': 'Motorcycle',
-    'pedestrian': 'Pedestrian',
-    'trailer': 'Trailer',
-    'truck': 'Truck'
-}
-
-TRACKING_COLORS = {
-    'bicycle': 'C9',  # Differs from detection.
-    'bus': 'C2',
-    'car': 'C0',
-    'motorcycle': 'C6',
-    'pedestrian': 'C5',
-    'trailer': 'C3',
-    'truck': 'C1'
-}
 
 # Define mapping for metrics averaged over classes.
 AVG_METRIC_MAP = {  # Mapping from average metric name to individual per-threshold metric name.
     'amota': 'motar',
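
With the module-level constants removed, downstream code that imported them needs to read the same values off the config object instead. A before/after sketch, using the config_factory import that appears in the updated test below:

    # Before this commit:
    # from nuscenes.eval.tracking.constants import TRACKING_NAMES, TRACKING_COLORS

    # After this commit the same values live on the tracking config:
    from nuscenes.eval.common.config import config_factory

    cfg = config_factory('tracking_nips_2019')
    tracking_names = cfg.tracking_names        # replaces TRACKING_NAMES
    pretty_names = cfg.pretty_tracking_names   # replaces PRETTY_TRACKING_NAMES
    tracking_colors = cfg.tracking_colors      # replaces TRACKING_COLORS
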
23 changes: 18 additions & 5 deletions python-sdk/nuscenes/eval/tracking/data_classes.py
@@ -1,19 +1,22 @@
 # nuScenes dev-kit.
 # Code written by Holger Caesar, Caglayan Dicle and Oscar Beijbom, 2019.
 
-from typing import Dict, Tuple, Any
+from typing import Any, Dict, List, Tuple
 
 import numpy as np
 
 from nuscenes.eval.common.data_classes import MetricData, EvalBox
 from nuscenes.eval.common.utils import center_distance
-from nuscenes.eval.tracking.constants import TRACKING_NAMES, TRACKING_METRICS, AMOT_METRICS
+from nuscenes.eval.tracking.constants import TRACKING_METRICS, AMOT_METRICS
 
 
 class TrackingConfig:
     """ Data class that specifies the tracking evaluation settings. """
 
     def __init__(self,
+                 tracking_names: List[str],
+                 pretty_tracking_names: Dict[str, str],
+                 tracking_colors: Dict[str, str],
                  class_range: Dict[str, int],
                  dist_fcn: str,
                  dist_th_tp: float,
@@ -22,8 +25,12 @@ def __init__(self,
                  metric_worst: Dict[str, float],
                  num_thresholds: int):
 
-        assert set(class_range.keys()) == set(TRACKING_NAMES), "Class count mismatch."
-
+        assert set(class_range.keys()) == set(tracking_names), "Class count mismatch."
+        global TRACKING_NAMES
+        TRACKING_NAMES = tracking_names
+        self.tracking_names = tracking_names
+        self.pretty_tracking_names = pretty_tracking_names
+        self.tracking_colors = tracking_colors
         self.class_range = class_range
         self.dist_fcn = dist_fcn
         self.dist_th_tp = dist_th_tp
@@ -45,6 +52,9 @@ def __eq__(self, other):
     def serialize(self) -> dict:
         """ Serialize instance into json-friendly format. """
         return {
+            'tracking_names': self.tracking_names,
+            'pretty_tracking_names': self.pretty_tracking_names,
+            'tracking_colors': self.tracking_colors,
             'class_range': self.class_range,
             'dist_fcn': self.dist_fcn,
             'dist_th_tp': self.dist_th_tp,
@@ -57,7 +67,10 @@ def serialize(self) -> dict:
     @classmethod
     def deserialize(cls, content: dict):
         """ Initialize from serialized dictionary. """
-        return cls(content['class_range'],
+        return cls(content['tracking_names'],
+                   content['pretty_tracking_names'],
+                   content['tracking_colors'],
+                   content['class_range'],
                    content['dist_fcn'],
                    content['dist_th_tp'],
                    content['min_recall'],
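
Two details worth noting here: the global TRACKING_NAMES statement rebinds a module-level name inside data_classes.py (whose import of that name from constants was just removed), so code in this module that still references TRACKING_NAMES picks up the configured list once a config is constructed; and since serialize() now emits the three new keys in the same order deserialize() consumes them, a config should survive a round trip. A small sanity-check sketch:

    from nuscenes.eval.common.config import config_factory
    from nuscenes.eval.tracking.data_classes import TrackingConfig

    cfg = config_factory('tracking_nips_2019')

    # TrackingConfig defines __eq__ (see the hunk above), so the
    # serialize/deserialize round trip should compare equal.
    assert TrackingConfig.deserialize(cfg.serialize()) == cfg
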
7 changes: 3 additions & 4 deletions python-sdk/nuscenes/eval/tracking/evaluate.py
@@ -188,12 +188,11 @@ def savepath(name):
             return os.path.join(self.plot_dir, name + '.pdf')
 
         # Plot a summary.
-        summary_plot(md_list, min_recall=self.cfg.min_recall, savepath=savepath('summary'))
+        summary_plot(self.cfg, md_list, savepath=savepath('summary'))
 
         # For each metric, plot all the classes in one diagram.
         for metric_name in LEGACY_METRICS:
-            recall_metric_curve(md_list, metric_name,
-                                self.cfg.min_recall, savepath=savepath('%s' % metric_name))
+            recall_metric_curve(self.cfg, md_list, metric_name, savepath=savepath('%s' % metric_name))
 
     def main(self, render_curves: bool = True) -> Dict[str, Any]:
         """
@@ -230,7 +229,7 @@ def main(self, render_curves: bool = True) -> Dict[str, Any]:
     # Settings.
     parser = argparse.ArgumentParser(description='Evaluate nuScenes tracking results.',
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('result_path', type=str, help='The submission as a JSON file.')
+    parser.add_argument('--result_path', type=str, help='The submission as a JSON file.')
     parser.add_argument('--output_dir', type=str, default='~/nuscenes-metrics',
                         help='Folder to store result metrics, graphs and example visualizations.')
     parser.add_argument('--eval_set', type=str, default='val',
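
Note that result_path changes from a positional argument to a --result_path flag, so existing command lines need updating, roughly:

    Before:  python python-sdk/nuscenes/eval/tracking/evaluate.py submission.json
    After:   python python-sdk/nuscenes/eval/tracking/evaluate.py --result_path submission.json

Since the flag is declared without required=True, argparse now defaults it to None when omitted, which the script presumably still has to guard against.
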
21 changes: 11 additions & 10 deletions python-sdk/nuscenes/eval/tracking/render.py
@@ -10,21 +10,21 @@
 from pyquaternion import Quaternion
 
 from nuscenes.eval.common.render import setup_axis
-from nuscenes.eval.tracking.constants import TRACKING_COLORS, PRETTY_TRACKING_NAMES
 from nuscenes.eval.tracking.data_classes import TrackingBox, TrackingMetricDataList
 from nuscenes.utils.data_classes import Box
+from nuscenes.eval.tracking.data_classes import TrackingConfig
 
 Axis = Any
 
 
-def summary_plot(md_list: TrackingMetricDataList,
-                 min_recall: float,
+def summary_plot(cfg: TrackingConfig,
+                 md_list: TrackingMetricDataList,
                  ncols: int = 2,
                  savepath: str = None) -> None:
     """
     Creates a summary plot with which includes all traditional metrics for each class.
+    :param cfg: A TrackingConfig object.
     :param md_list: TrackingMetricDataList instance.
-    :param min_recall: Minimum recall value.
     :param ncols: How many columns the resulting plot should have.
     :param savepath: If given, saves the the rendering here instead of displaying.
     """
@@ -38,7 +38,7 @@ def summary_plot(md_list: TrackingMetricDataList,
     for ind, metric_name in enumerate(rel_metrics):
         row = ind // ncols
         col = np.mod(ind, ncols)
-        recall_metric_curve(md_list, metric_name, min_recall, ax=axes[row, col])
+        recall_metric_curve(cfg, md_list, metric_name, ax=axes[row, col])
 
     # Set layout with little white space and save to disk.
     plt.tight_layout()
@@ -47,19 +47,20 @@
     plt.close()
 
 
-def recall_metric_curve(md_list: TrackingMetricDataList,
+def recall_metric_curve(cfg: TrackingConfig,
+                        md_list: TrackingMetricDataList,
                         metric_name: str,
-                        min_recall: float,
                         savepath: str = None,
                         ax: Axis = None) -> None:
     """
     Plot the recall versus metric curve for the given metric.
+    :param cfg: A TrackingConfig object.
     :param md_list: TrackingMetricDataList instance.
     :param metric_name: The name of the metric to plot.
-    :param min_recall: Minimum recall value.
     :param savepath: If given, saves the the rendering here instead of displaying.
     :param ax: Axes onto which to render or None to create a new axis.
     """
+    min_recall = cfg.min_recall  # Minimum recall value from config.
     # Setup plot.
     if ax is None:
         _, ax = plt.subplots(1, 1, figsize=(7.5, 5))
@@ -84,8 +85,8 @@ def recall_metric_curve(md_list: TrackingMetricDataList,
 
         ax.plot(recalls,
                 values,
-                label='%s' % PRETTY_TRACKING_NAMES[tracking_name],
-                color=TRACKING_COLORS[tracking_name])
+                label='%s' % cfg.pretty_tracking_names[tracking_name],
+                color=cfg.tracking_colors[tracking_name])
 
     # Scale count statistics and FAF logarithmically.
     if metric_name in ['mt', 'ml', 'faf', 'tp', 'fp', 'fn', 'ids', 'frag']:
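
With the config passed explicitly, callers hand it in as the first argument and min_recall travels inside it. A usage sketch under that assumption (the metric name 'mota' is one of the LEGACY_METRICS from constants.py above):

    from nuscenes.eval.common.config import config_factory
    from nuscenes.eval.tracking.data_classes import TrackingMetricDataList
    from nuscenes.eval.tracking.render import recall_metric_curve, summary_plot

    def render_all(md_list: TrackingMetricDataList) -> None:
        """Render the summary and one single-metric curve with the new cfg-first signatures."""
        cfg = config_factory('tracking_nips_2019')
        summary_plot(cfg, md_list, savepath='summary.pdf')
        recall_metric_curve(cfg, md_list, 'mota', savepath='mota.pdf')
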
5 changes: 3 additions & 2 deletions python-sdk/nuscenes/eval/tracking/tests/test_evaluate.py
@@ -14,7 +14,6 @@
 
 from nuscenes import NuScenes
 from nuscenes.eval.common.config import config_factory
-from nuscenes.eval.tracking.constants import TRACKING_NAMES
 from nuscenes.eval.tracking.evaluate import TrackingEval
 from nuscenes.eval.tracking.utils import category_to_tracking_name
 from nuscenes.utils.splits import create_splits_scenes
@@ -41,10 +40,12 @@ def _mock_submission(nusc: NuScenes,
         :param split: Dataset split to use.
         :param add_errors: Whether to use GT or add errors to it.
         """
+        # Get config.
+        cfg = config_factory('tracking_nips_2019')
 
         def random_class(category_name: str, _add_errors: bool = False) -> Optional[str]:
             # Alter 10% of the valid labels.
-            class_names = sorted(TRACKING_NAMES)
+            class_names = sorted(cfg.tracking_names)
             tmp = category_to_tracking_name(category_name)
 
             if tmp is None:
