47 changes: 46 additions & 1 deletion python-sdk/nuscenes/eval/common/utils.py
@@ -5,12 +5,57 @@

import numpy as np
from pyquaternion import Quaternion

from shapely import affinity
from shapely.geometry import Polygon
from nuscenes.eval.common.data_classes import EvalBox
from nuscenes.utils.data_classes import Box

DetectionBox = Any # Workaround as direct imports lead to cyclic dependencies.

def create_2d_polygon_from_box(bbox: EvalBox) -> Polygon:
"""
Convert an EvalBox into its 2D footprint polygon in the xy plane.
:param bbox: An EvalBox describing center, rotation and size.
:return: A 2D Polygon of the box footprint in global xy coordinates.
"""
length, width = bbox.size[0], bbox.size[1]
# Build the footprint centred on the origin, rotate it by the box yaw, then shift it to the box center.
poly_veh = Polygon(((0.5 * length, 0.5 * width), (-0.5 * length, 0.5 * width),
(-0.5 * length, -0.5 * width), (0.5 * length, -0.5 * width), (0.5 * length, 0.5 * width)))
poly_rot = affinity.rotate(poly_veh, quaternion_yaw(Quaternion(bbox.rotation)), use_radians=True)
poly_glob = affinity.translate(poly_rot, bbox.translation[0], bbox.translation[1])
return poly_glob

def bev_iou(gt_poly: Polygon, pred_poly: Polygon) -> float:
"""
Bird's Eye View intersection-over-union (IoU) between two footprint polygons (xy only).
:param gt_poly: Footprint polygon of the GT annotation.
:param pred_poly: Footprint polygon of the predicted sample.
:return: IoU in [0, 1].
"""
intersection = gt_poly.intersection(pred_poly).area
iou = intersection / (gt_poly.area + pred_poly.area - intersection)

# Guard against machine precision issues (e.g. perfect overlap can yield a value marginally above 1).
iou = min(iou, 1.0)
return iou

def bev_iou_complement(gt_box: EvalBox, pred_box: EvalBox) -> float:
"""
1 - BEV IoU between the two input boxes (xy only), so 0 means perfect overlap and 1 means no overlap.
:param gt_box: GT annotation sample.
:param pred_box: Predicted sample.
:return: 1 - IoU.
"""
# Cheap first pass before computing the IoU: if the circles enclosing the two footprints
# do not overlap, the boxes cannot intersect.
gt_radius = np.linalg.norm(0.5 * np.array([gt_box.size[0], gt_box.size[1]]))
pred_radius = np.linalg.norm(0.5 * np.array([pred_box.size[0], pred_box.size[1]]))
if center_distance(gt_box, pred_box) >= pred_radius + gt_radius:
complement = 1.0
else:
complement = 1.0 - bev_iou(create_2d_polygon_from_box(gt_box),
create_2d_polygon_from_box(pred_box))
return complement

def center_distance(gt_box: EvalBox, pred_box: EvalBox) -> float:
"""
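As a minimal illustration of what the new helpers compute (a sketch, not part of the diff), the snippet below reproduces the 1 m^2 overlap case from the new unit tests using plain shapely, mirroring the polygon construction in create_2d_polygon_from_box; the footprint helper and the box values are only for this example.

import math

from shapely import affinity
from shapely.geometry import Polygon


def footprint(cx: float, cy: float, length: float, width: float, yaw_rad: float = 0.0) -> Polygon:
    # Build an origin-centred rectangle, rotate it by the yaw angle and shift it to (cx, cy).
    poly = Polygon(((0.5 * length, 0.5 * width), (-0.5 * length, 0.5 * width),
                    (-0.5 * length, -0.5 * width), (0.5 * length, -0.5 * width)))
    poly = affinity.rotate(poly, yaw_rad, use_radians=True)
    return affinity.translate(poly, cx, cy)


box_a = footprint(1.0, 0.5, 2.0, 1.0)                           # 2 m x 1 m footprint.
box_b = footprint(0.5, 1.5, 3.0, 1.0, yaw_rad=0.5 * math.pi)    # 3 m x 1 m footprint, rotated 90 degrees.
intersection = box_a.intersection(box_b).area                   # 1.0 m^2.
print(intersection / (box_a.area + box_b.area - intersection))  # 0.25, so bev_iou_complement would be 0.75.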
4 changes: 3 additions & 1 deletion python-sdk/nuscenes/eval/detection/data_classes.py
@@ -7,7 +7,7 @@
import numpy as np

from nuscenes.eval.common.data_classes import MetricData, EvalBox
from nuscenes.eval.common.utils import center_distance
from nuscenes.eval.common.utils import center_distance, bev_iou_complement
from nuscenes.eval.detection.constants import DETECTION_NAMES, ATTRIBUTE_NAMES, TP_METRICS


@@ -74,6 +74,8 @@ def dist_fcn_callable(self):
""" Return the distance function corresponding to the dist_fcn string. """
if self.dist_fcn == 'center_distance':
return center_distance
elif self.dist_fcn == 'bev_iou_complement':
return bev_iou_complement
else:
raise Exception('Error: Unknown distance function %s!' % self.dist_fcn)

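For context (a sketch under assumptions, not part of this change), selecting the new distance function via the detection config would look roughly as below; config_factory and the attribute names come from the existing devkit, while the threshold values simply mirror the ones used in the detection test further down and are illustrative rather than recommended settings.

from nuscenes.eval.detection.config import config_factory

cfg = config_factory('detection_cvpr_2019')
cfg.dist_fcn = 'bev_iou_complement'  # Matching cost becomes 1 - BEV IoU, bounded to [0, 1].
cfg.dist_ths = [0.0, 0.999999]       # Thresholds are now on 1 - IoU instead of metres.
cfg.dist_th_tp = 0.999999            # Threshold used for the TP metrics.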
11 changes: 11 additions & 0 deletions python-sdk/nuscenes/eval/detection/tests/test_evaluate.py
@@ -149,5 +149,16 @@ def test_delta(self, eval_split, mock__get_custom_splits_file_path):
# 10. Score = 0.19449091580477748. Changed to use v1.0 mini_val split, and the equal mini_custom_val split.
self.assertAlmostEqual(metrics.nd_score, 0.19449091580477748)

# Evaluate again but use the bev_iou_complement distance function
# 1. Score = 0.16651633528966858. Measured on forked repo sbarkby/nuscenes-devkit April 22nd 2024.
cfg.dist_fcn = 'bev_iou_complement'
cfg.dist_ths = [0.0, 0.999999]
cfg.dist_th_tp = 0.999999

nusc_eval = DetectionEval(nusc, cfg, self.res_mockup, eval_set=eval_split, output_dir=self.res_eval_folder,
verbose=False)
metrics, md_list = nusc_eval.evaluate()
self.assertAlmostEqual(metrics.nd_score, 0.16651633528966858)

if __name__ == '__main__':
unittest.main()
44 changes: 42 additions & 2 deletions python-sdk/nuscenes/eval/detection/tests/test_utils.py
@@ -7,8 +7,8 @@
from numpy.testing import assert_array_almost_equal
from pyquaternion import Quaternion

from nuscenes.eval.common.utils import attr_acc, scale_iou, yaw_diff, angle_diff, center_distance, velocity_l2, \
cummean
from nuscenes.eval.common.utils import attr_acc, scale_iou, yaw_diff, angle_diff, create_2d_polygon_from_box, bev_iou, \
bev_iou_complement, center_distance, velocity_l2, cummean
from nuscenes.eval.detection.data_classes import DetectionBox


@@ -128,6 +128,46 @@ def rad(x):
period = 360
self.assertAlmostEqual(rad(180), abs(angle_diff(rad(a), rad(b), rad(period))))

def test_create_2d_polygon_from_box(self):
# Create a box rotated by 30 degrees and offset by (2, 4); check against hand-calculated values
poly = create_2d_polygon_from_box(DetectionBox(rotation=(0.96592582628, 0, 0, 0.2588190451),
translation=(2, 4, 1), size=(3, 1, 2)))
self.assertAlmostEqual(poly.exterior.coords[0][0],3.04903810568)
self.assertAlmostEqual(poly.exterior.coords[0][1],5.18301270189)
self.assertAlmostEqual(poly.exterior.coords[1][0],0.45096189432)
self.assertAlmostEqual(poly.exterior.coords[1][1],3.6830127019)
self.assertAlmostEqual(poly.exterior.coords[2][0],0.95096189432)
self.assertAlmostEqual(poly.exterior.coords[2][1],2.81698729811)
self.assertAlmostEqual(poly.exterior.coords[3][0],3.54903810568)
self.assertAlmostEqual(poly.exterior.coords[3][1],4.3169872981)

def test_bev_iou(self):
# Two boxes specified, no overlap
sa = create_2d_polygon_from_box(DetectionBox(translation=(1.0, 0.0, 1.0),
size=(2,1,1)))
sr = create_2d_polygon_from_box(DetectionBox(translation=(3.5, 0.0, 1.0),
size=(3,1,2)))
self.assertAlmostEqual(bev_iou(sa, sr), 0.0)

# Two boxes specified, one rotated by 90 degrees about the z axis; should give 1 m^2 of overlap
sa = create_2d_polygon_from_box(DetectionBox(rotation=(1, 0, 0, 0), translation=(1.0, 0.5, 2.0),
size=(2, 1, 1)))
sr = create_2d_polygon_from_box(DetectionBox(rotation=(0.70710678118, 0, 0, 0.70710678118),
translation=(0.5, 1.5, 1), size=(3, 1, 2)))
self.assertAlmostEqual(bev_iou(sa, sr), 0.25)

def test_bev_iou_complement(self):
# Two boxes specified, no overlap
sa = DetectionBox(translation=(1.0, 0.0, 1.0), size=(2,1,1))
sr = DetectionBox(translation=(3.5, 0.0, 1.0), size=(3,1,2))
self.assertAlmostEqual(bev_iou_complement(sa, sr), 1.0)

# Two boxes specified, one rotated by 90 degrees about the z axis; should give 1 m^2 of overlap
sa = DetectionBox(rotation=(1, 0, 0, 0), translation=(1.0, 0.5, 2.0), size=(2, 1, 1))
sr = DetectionBox(rotation=(0.70710678118, 0, 0, 0.70710678118),
translation=(0.5, 1.5, 1), size=(3, 1, 2))
self.assertAlmostEqual(bev_iou_complement(sa, sr), 0.75)

def test_center_distance(self):
"""Test for center_distance()."""

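As a hand check of the expected values in test_create_2d_polygon_from_box (a sketch, not part of the diff): the first asserted vertex is the corner (1.5, 0.5) of the 3 m x 1 m footprint rotated by 30 degrees about the origin and shifted by (2, 4).

import numpy as np

yaw = np.radians(30.0)
rot = np.array([[np.cos(yaw), -np.sin(yaw)],
                [np.sin(yaw), np.cos(yaw)]])
corner = rot @ np.array([1.5, 0.5]) + np.array([2.0, 4.0])
print(corner)  # ~[3.0490381, 5.1830127], matching the first asserted vertex above.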
11 changes: 9 additions & 2 deletions python-sdk/nuscenes/eval/tracking/algo.py
@@ -23,6 +23,7 @@
except ModuleNotFoundError:
raise unittest.SkipTest('Skipping test as pandas was not found!')

from nuscenes.eval.common.utils import bev_iou_complement
from nuscenes.eval.tracking.constants import MOT_METRIC_MAP, TRACKING_METRICS
from nuscenes.eval.tracking.data_classes import TrackingBox, TrackingMetricData
from nuscenes.eval.tracking.mot import MOTAccumulatorCustom
@@ -257,13 +258,19 @@ def accumulate_threshold(self, threshold: float = None) -> Tuple[pandas.DataFram

# Calculate distances.
# Note that the distance function is hard-coded to achieve significant speedups via vectorization.
assert self.dist_fcn.__name__ == 'center_distance'
if len(frame_gt) == 0 or len(frame_pred) == 0:
distances = np.ones((0, 0))
else:
elif self.dist_fcn.__name__ == 'center_distance':
gt_boxes = np.array([b.translation[:2] for b in frame_gt])
pred_boxes = np.array([b.translation[:2] for b in frame_pred])
distances = sklearn.metrics.pairwise.euclidean_distances(gt_boxes, pred_boxes)
elif self.dist_fcn.__name__ == 'bev_iou_complement':
distances = np.zeros((len(frame_gt), len(frame_pred)))
for i in range(len(frame_gt)):
for j in range(len(frame_pred)):
distances[i, j] = bev_iou_complement(frame_gt[i], frame_pred[j])
else:
raise Exception('Error: Unknown distance function %s!' % self.dist_fcn.__name__)

# Distances that are larger than the threshold won't be associated.
assert len(distances) == 0 or not np.all(np.isnan(distances))
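One possible follow-up, sketched below under the assumption that the helpers from nuscenes.eval.common.utils in this PR are available (the function pairwise_bev_iou_complement is hypothetical): the pairwise loop above rebuilds both footprint polygons for every (gt, pred) pair, so building each frame's polygons once and calling bev_iou directly trades the enclosing-circle early exit for far fewer polygon constructions.

import numpy as np

from nuscenes.eval.common.utils import bev_iou, create_2d_polygon_from_box


def pairwise_bev_iou_complement(frame_gt, frame_pred) -> np.ndarray:
    # Build every footprint once per frame instead of once per (gt, pred) pair.
    gt_polys = [create_2d_polygon_from_box(box) for box in frame_gt]
    pred_polys = [create_2d_polygon_from_box(box) for box in frame_pred]
    distances = np.ones((len(gt_polys), len(pred_polys)))
    for i, gt_poly in enumerate(gt_polys):
        for j, pred_poly in enumerate(pred_polys):
            distances[i, j] = 1.0 - bev_iou(gt_poly, pred_poly)
    return distances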
4 changes: 3 additions & 1 deletion python-sdk/nuscenes/eval/tracking/data_classes.py
@@ -6,7 +6,7 @@
import numpy as np

from nuscenes.eval.common.data_classes import MetricData, EvalBox
from nuscenes.eval.common.utils import center_distance
from nuscenes.eval.common.utils import center_distance, bev_iou_complement
from nuscenes.eval.tracking.constants import TRACKING_METRICS, AMOT_METRICS


@@ -86,6 +86,8 @@ def dist_fcn_callable(self):
""" Return the distance function corresponding to the dist_fcn string. """
if self.dist_fcn == 'center_distance':
return center_distance
elif self.dist_fcn == 'bev_iou_complement':
return bev_iou_complement
else:
raise Exception('Error: Unknown distance function %s!' % self.dist_fcn)

53 changes: 50 additions & 3 deletions python-sdk/nuscenes/eval/tracking/tests/test_evaluate.py
@@ -148,7 +148,9 @@ def random_id(instance_token: str, _add_errors: bool = False) -> str:
def basic_test(self,
eval_set: str = 'mini_val',
add_errors: bool = False,
render_curves: bool = False) -> Dict[str, Any]:
render_curves: bool = False,
dist_fcn: str = '',
dist_th_tp: float = 0.0) -> Dict[str, Any]:
"""
Run the evaluation with fixed randomness on the specified subset, with or without introducing errors in the
submission.
@@ -174,6 +176,11 @@ def basic_test(self,
json.dump(mock, f, indent=2)

cfg = config_factory('tracking_nips_2019')

# Update the distance function and TP threshold with those specified
cfg.dist_fcn = dist_fcn
cfg.dist_th_tp = dist_th_tp

nusc_eval = TrackingEval(cfg, self.res_mockup, eval_set=eval_set, output_dir=self.res_eval_folder,
nusc_version=version, nusc_dataroot=os.environ['NUSCENES'], verbose=False)
metrics = nusc_eval.main(render_curves=render_curves)
@@ -192,7 +199,8 @@ def test_delta_mock(self,
:param render_curves: Whether to render stats curves to disk.
"""
# Run the evaluation with errors.
metrics = self.basic_test(eval_set, add_errors=True, render_curves=render_curves)
metrics = self.basic_test(eval_set, add_errors=True, render_curves=render_curves,
dist_fcn='center_distance', dist_th_tp=2.0)

# Compare metrics to known solution.
if eval_set == 'mini_val':
Expand All @@ -204,6 +212,20 @@ def test_delta_mock(self,
else:
print('Skipping checks due to choice of custom eval_set: %s' % eval_set)

# Run again with the alternative bev_iou_complement dist_fcn
metrics = self.basic_test(eval_set, add_errors=True, render_curves=render_curves,
dist_fcn='bev_iou_complement', dist_th_tp=0.999999)

# Compare metrics to known solution.
if eval_set == 'mini_val':
self.assertAlmostEqual(metrics['amota'], 0.231839679131956)
self.assertAlmostEqual(metrics['amotp'], 1.3629342647309446)
self.assertAlmostEqual(metrics['motar'], 0.27918315466340504)
self.assertAlmostEqual(metrics['mota'], 0.22922560056448252)
self.assertAlmostEqual(metrics['motp'], 0.7541595548820258)
else:
print('Skipping checks due to choice of custom eval_set: %s' % eval_set)

@parameterized.expand([
('mini_val',),
('mini_custom_train',)
@@ -224,7 +246,8 @@ def test_delta_gt(self,
mock__get_custom_splits_file_path.return_value = self.splits_file_mockup

# Run the evaluation without errors.
metrics = self.basic_test(eval_set, add_errors=False, render_curves=render_curves)
metrics = self.basic_test(eval_set, add_errors=False, render_curves=render_curves,
dist_fcn='center_distance', dist_th_tp=2.0)

# Compare metrics to known solution. Do not check:
# - MT/TP (hard to figure out here).
@@ -247,6 +270,30 @@
else:
print('Skipping checks due to choice of custom eval_set: %s' % eval_set)

# Run again with the alternative bev_iou_complement dist_fcn (and a very tight TP threshold)
metrics = self.basic_test(eval_set, add_errors=False, render_curves=render_curves,
dist_fcn='bev_iou_complement', dist_th_tp=1e-6)

# Compare metrics to known solution. Do not check:
# - MT/TP (hard to figure out here).
# - AMOTA/AMOTP (unachieved recall values lead to hard unintuitive results).
if eval_set in ['mini_val', 'mini_custom_train']:
self.assertAlmostEqual(metrics['amota'], 1.0)
self.assertAlmostEqual(metrics['amotp'], 0.0, delta=1e-5)
self.assertAlmostEqual(metrics['motar'], 1.0)
self.assertAlmostEqual(metrics['recall'], 1.0)
self.assertAlmostEqual(metrics['mota'], 1.0)
self.assertAlmostEqual(metrics['motp'], 0.0, delta=1e-5)
self.assertAlmostEqual(metrics['faf'], 0.0)
self.assertAlmostEqual(metrics['ml'], 0.0)
self.assertAlmostEqual(metrics['fp'], 0.0)
self.assertAlmostEqual(metrics['fn'], 0.0)
self.assertAlmostEqual(metrics['ids'], 0.0)
self.assertAlmostEqual(metrics['frag'], 0.0)
self.assertAlmostEqual(metrics['tid'], 0.0)
self.assertAlmostEqual(metrics['lgd'], 0.0)
else:
print('Skipping checks due to choice of custom eval_set: %s' % eval_set)

if __name__ == '__main__':
unittest.main()