Skip to content

Commit 5bf944c

Browse files
authored
[Feature] Implement portable eval script for server (#24)
* Implement portable eval script for server * lint format
1 parent 1adf462 commit 5bf944c

File tree

1 file changed

+371
-0
lines changed

1 file changed

+371
-0
lines changed

tools/eval_script_portable.py

Lines changed: 371 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,371 @@
1+
# Copyright (c) OpenRobotLab. All rights reserved.
2+
import argparse
3+
from typing import Union
4+
5+
import mmengine
6+
import numpy as np
7+
import torch
8+
from mmengine.logging import print_log
9+
from pytorch3d.ops import box3d_overlap
10+
from pytorch3d.transforms import euler_angles_to_matrix
11+
from terminaltables import AsciiTable
12+
13+
14+
def rotation_3d_in_euler(points, angles, return_mat=False, clockwise=False):
15+
"""Rotate points by angles according to axis.
16+
17+
Args:
18+
points (np.ndarray | torch.Tensor | list | tuple ):
19+
Points of shape (N, M, 3).
20+
angles (np.ndarray | torch.Tensor | list | tuple):
21+
Vector of angles in shape (N, 3)
22+
return_mat: Whether or not return the rotation matrix (transposed).
23+
Defaults to False.
24+
clockwise: Whether the rotation is clockwise. Defaults to False.
25+
26+
Raises:
27+
ValueError: when the axis is not in range [0, 1, 2], it will
28+
raise value error.
29+
30+
Returns:
31+
(torch.Tensor | np.ndarray): Rotated points in shape (N, M, 3).
32+
"""
33+
batch_free = len(points.shape) == 2
34+
if batch_free:
35+
points = points[None]
36+
37+
if len(angles.shape) == 1:
38+
angles = angles.expand(points.shape[:1] + (3, ))
39+
# angles = torch.full(points.shape[:1], angles)
40+
41+
assert len(points.shape) == 3 and len(angles.shape) == 2 \
42+
and points.shape[0] == angles.shape[0], f'Incorrect shape of points ' \
43+
f'angles: {points.shape}, {angles.shape}'
44+
45+
assert points.shape[-1] in [2, 3], \
46+
f'Points size should be 2 or 3 instead of {points.shape[-1]}'
47+
48+
rot_mat_T = euler_angles_to_matrix(angles, 'ZXY') # N, 3,3
49+
rot_mat_T = rot_mat_T.transpose(-2, -1)
50+
51+
if clockwise:
52+
raise NotImplementedError('clockwise')
53+
54+
if points.shape[0] == 0:
55+
points_new = points
56+
else:
57+
points_new = torch.bmm(points, rot_mat_T)
58+
59+
if batch_free:
60+
points_new = points_new.squeeze(0)
61+
62+
if return_mat:
63+
if batch_free:
64+
rot_mat_T = rot_mat_T.squeeze(0)
65+
return points_new, rot_mat_T
66+
else:
67+
return points_new
68+
69+
70+
class EulerDepthInstance3DBoxes:
71+
"""3D boxes of instances in Depth coordinates.
72+
73+
We keep the "Depth" coordinate system definition in MMDet3D just for
74+
clarification of the points coordinates and the flipping augmentation.
75+
76+
Coordinates in Depth:
77+
78+
.. code-block:: none
79+
80+
up z y front (alpha=0.5*pi)
81+
^ ^
82+
| /
83+
| /
84+
0 ------> x right (alpha=0)
85+
86+
The relative coordinate of bottom center in a Depth box is (0.5, 0.5, 0),
87+
and the yaw is around the z axis, thus the rotation axis=2.
88+
The yaw is 0 at the positive direction of x axis, and decreases from
89+
the positive direction of x to the positive direction of y.
90+
Also note that rotation of DepthInstance3DBoxes is counterclockwise,
91+
which is reverse to the definition of the yaw angle (clockwise).
92+
93+
Attributes:
94+
tensor (torch.Tensor): Float matrix of N x box_dim.
95+
box_dim (int): Integer indicates the dimension of a box
96+
Each row is (x, y, z, x_size, y_size, z_size, alpha, beta, gamma).
97+
with_yaw (bool): If True, the value of yaw will be set to 0 as minmax
98+
boxes.
99+
"""
100+
101+
def __init__(self,
102+
tensor,
103+
box_dim=9,
104+
with_yaw=True,
105+
origin=(0.5, 0.5, 0.5)):
106+
107+
if isinstance(tensor, torch.Tensor):
108+
device = tensor.device
109+
else:
110+
device = torch.device('cpu')
111+
tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
112+
if tensor.numel() == 0:
113+
# Use reshape, so we don't end up creating a new tensor that
114+
# does not depend on the inputs (and consequently confuses jit)
115+
tensor = tensor.reshape((0, box_dim)).to(dtype=torch.float32,
116+
device=device)
117+
assert tensor.dim() == 2 and tensor.size(-1) == box_dim, tensor.size()
118+
119+
if tensor.shape[-1] == 6:
120+
# If the dimension of boxes is 6, we expand box_dim by padding
121+
# (0, 0, 0) as a fake euler angle.
122+
assert box_dim == 6
123+
fake_rot = tensor.new_zeros(tensor.shape[0], 3)
124+
tensor = torch.cat((tensor, fake_rot), dim=-1)
125+
self.box_dim = box_dim + 3
126+
elif tensor.shape[-1] == 7:
127+
assert box_dim == 7
128+
fake_euler = tensor.new_zeros(tensor.shape[0], 2)
129+
tensor = torch.cat((tensor, fake_euler), dim=-1)
130+
self.box_dim = box_dim + 2
131+
else:
132+
assert tensor.shape[-1] == 9
133+
self.box_dim = box_dim
134+
self.tensor = tensor.clone()
135+
136+
self.origin = origin
137+
if origin != (0.5, 0.5, 0.5):
138+
dst = self.tensor.new_tensor((0.5, 0.5, 0.5))
139+
src = self.tensor.new_tensor(origin)
140+
self.tensor[:, :3] += self.tensor[:, 3:6] * (dst - src)
141+
self.with_yaw = with_yaw
142+
143+
def __len__(self) -> int:
144+
"""int: Number of boxes in the current object."""
145+
return self.tensor.shape[0]
146+
147+
def __getitem__(self, item: Union[int, slice, np.ndarray, torch.Tensor]):
148+
"""
149+
Args:
150+
item (int or slice or np.ndarray or Tensor): Index of boxes.
151+
152+
Note:
153+
The following usage are allowed:
154+
155+
1. `new_boxes = boxes[3]`: Return a `Boxes` that contains only one
156+
box.
157+
2. `new_boxes = boxes[2:10]`: Return a slice of boxes.
158+
3. `new_boxes = boxes[vector]`: Where vector is a
159+
torch.BoolTensor with `length = len(boxes)`. Nonzero elements in
160+
the vector will be selected.
161+
162+
Note that the returned Boxes might share storage with this Boxes,
163+
subject to PyTorch's indexing semantics.
164+
165+
Returns:
166+
:obj:`BaseInstance3DBoxes`: A new object of
167+
:class:`BaseInstance3DBoxes` after indexing.
168+
"""
169+
original_type = type(self)
170+
if isinstance(item, int):
171+
return original_type(self.tensor[item].view(1, -1),
172+
box_dim=self.box_dim,
173+
with_yaw=self.with_yaw)
174+
b = self.tensor[item]
175+
assert b.dim() == 2, \
176+
f'Indexing on Boxes with {item} failed to return a matrix!'
177+
return original_type(b, box_dim=self.box_dim, with_yaw=self.with_yaw)
178+
179+
@property
180+
def dims(self) -> torch.Tensor:
181+
"""Tensor: Size dimensions of each box in shape (N, 3)."""
182+
return self.tensor[:, 3:6]
183+
184+
@classmethod
185+
def overlaps(cls, boxes1, boxes2, mode='iou', eps=1e-4):
186+
"""Calculate 3D overlaps of two boxes.
187+
188+
Note:
189+
This function calculates the overlaps between ``boxes1`` and
190+
``boxes2``, ``boxes1`` and ``boxes2`` should be in the same type.
191+
192+
Args:
193+
boxes1 (:obj:`EulerInstance3DBoxes`): Boxes 1 contain N boxes.
194+
boxes2 (:obj:`EulerInstance3DBoxes`): Boxes 2 contain M boxes.
195+
mode (str): Mode of iou calculation. Defaults to 'iou'.
196+
eps (bool): Epsilon. Defaults to 1e-4.
197+
198+
Returns:
199+
torch.Tensor: Calculated 3D overlaps of the boxes.
200+
"""
201+
assert isinstance(boxes1, EulerDepthInstance3DBoxes)
202+
assert isinstance(boxes2, EulerDepthInstance3DBoxes)
203+
assert type(boxes1) == type(boxes2), '"boxes1" and "boxes2" should' \
204+
f'be in the same type, got {type(boxes1)} and {type(boxes2)}.'
205+
206+
assert mode in ['iou']
207+
208+
rows = len(boxes1)
209+
cols = len(boxes2)
210+
if rows * cols == 0:
211+
return boxes1.tensor.new(rows, cols)
212+
213+
corners1 = boxes1.corners
214+
corners2 = boxes2.corners
215+
_, iou3d = box3d_overlap(corners1, corners2, eps=eps)
216+
return iou3d
217+
218+
@property
219+
def corners(self):
220+
"""torch.Tensor: Coordinates of corners of all the boxes
221+
in shape (N, 8, 3).
222+
223+
Convert the boxes to corners in clockwise order, in form of
224+
``(x0y0z0, x0y0z1, x0y1z1, x0y1z0, x1y0z0, x1y0z1, x1y1z1, x1y1z0)``
225+
226+
.. code-block:: none
227+
228+
up z
229+
front y ^
230+
/ |
231+
/ |
232+
(x0, y1, z1) + ----------- + (x1, y1, z1)
233+
/| / |
234+
/ | / |
235+
(x0, y0, z1) + ----------- + + (x1, y1, z0)
236+
| / . | /
237+
| / origin | /
238+
(x0, y0, z0) + ----------- + --------> right x
239+
(x1, y0, z0)
240+
"""
241+
if self.tensor.numel() == 0:
242+
return torch.empty([0, 8, 3], device=self.tensor.device)
243+
244+
dims = self.dims
245+
corners_norm = torch.from_numpy(
246+
np.stack(np.unravel_index(np.arange(8), [2] * 3),
247+
axis=1)).to(device=dims.device, dtype=dims.dtype)
248+
249+
corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
250+
# use relative origin
251+
assert self.origin == (0.5, 0.5, 0.5), \
252+
'self.origin != (0.5, 0.5, 0.5) needs to be checked!'
253+
corners_norm = corners_norm - dims.new_tensor(self.origin)
254+
corners = dims.view([-1, 1, 3]) * corners_norm.reshape([1, 8, 3])
255+
256+
# rotate
257+
corners = rotation_3d_in_euler(corners, self.tensor[:, 6:])
258+
259+
corners += self.tensor[:, :3].view(-1, 1, 3)
260+
return corners
261+
262+
263+
def parse_args():
264+
parser = argparse.ArgumentParser(
265+
description='MMDet3D test (and eval) a model')
266+
parser.add_argument('results_file', help='the results pkl file')
267+
parser.add_argument('ann_file', help='annoations json file')
268+
269+
parser.add_argument('--iou_thr',
270+
type=list,
271+
default=[0.25, 0.5],
272+
help='the IoU threshold during evaluation')
273+
274+
args = parser.parse_args()
275+
return args
276+
277+
278+
def ground_eval(gt_annos, det_annos, iou_thr):
279+
280+
assert len(det_annos) == len(gt_annos)
281+
282+
pred = {}
283+
gt = {}
284+
285+
object_types = [
286+
'Easy', 'Hard', 'View-Dep', 'View-Indep', 'Unique', 'Multi', 'Overall'
287+
]
288+
289+
for t in iou_thr:
290+
for object_type in object_types:
291+
pred.update({object_type + '@' + str(t): 0})
292+
gt.update({object_type + '@' + str(t): 1e-14})
293+
294+
for sample_id in range(len(det_annos)):
295+
det_anno = det_annos[sample_id]
296+
gt_anno = gt_annos[sample_id]['ann_info']
297+
298+
bboxes = det_anno['bboxes_3d']
299+
gt_bboxes = gt_anno['gt_bboxes_3d']
300+
bboxes = EulerDepthInstance3DBoxes(bboxes, origin=(0.5, 0.5, 0.5))
301+
gt_bboxes = EulerDepthInstance3DBoxes(gt_bboxes,
302+
origin=(0.5, 0.5, 0.5))
303+
scores = bboxes.tensor.new_tensor(
304+
det_anno['scores_3d']) # (num_query, )
305+
306+
view_dep = gt_anno['is_view_dep']
307+
hard = gt_anno['is_hard']
308+
unique = gt_anno['is_unique']
309+
310+
box_index = scores.argsort(dim=-1, descending=True)[:10]
311+
top_bboxes = bboxes[box_index]
312+
313+
iou = top_bboxes.overlaps(top_bboxes, gt_bboxes) # (num_query, 1)
314+
315+
for t in iou_thr:
316+
threshold = iou > t
317+
found = int(threshold.any())
318+
if view_dep:
319+
gt['View-Dep@' + str(t)] += 1
320+
pred['View-Dep@' + str(t)] += found
321+
else:
322+
gt['View-Indep@' + str(t)] += 1
323+
pred['View-Indep@' + str(t)] += found
324+
if hard:
325+
gt['Hard@' + str(t)] += 1
326+
pred['Hard@' + str(t)] += found
327+
else:
328+
gt['Easy@' + str(t)] += 1
329+
pred['Easy@' + str(t)] += found
330+
if unique:
331+
gt['Unique@' + str(t)] += 1
332+
pred['Unique@' + str(t)] += found
333+
else:
334+
gt['Multi@' + str(t)] += 1
335+
pred['Multi@' + str(t)] += found
336+
337+
gt['Overall@' + str(t)] += 1
338+
pred['Overall@' + str(t)] += found
339+
340+
header = ['Type']
341+
header.extend(object_types)
342+
ret_dict = {}
343+
344+
for t in iou_thr:
345+
table_columns = [['results']]
346+
for object_type in object_types:
347+
metric = object_type + '@' + str(t)
348+
value = pred[metric] / max(gt[metric], 1)
349+
ret_dict[metric] = value
350+
table_columns.append([f'{value:.4f}'])
351+
352+
table_data = [header]
353+
table_rows = list(zip(*table_columns))
354+
table_data += table_rows
355+
table = AsciiTable(table_data)
356+
table.inner_footing_row_border = True
357+
print_log('\n' + table.table)
358+
359+
return ret_dict
360+
361+
362+
def main():
363+
args = parse_args()
364+
preds = mmengine.load(args.results_file)['results']
365+
annotations = mmengine.load(args.ann_file)
366+
assert len(preds) == len(annotations)
367+
ground_eval(annotations, preds, args.iou_thr)
368+
369+
370+
if __name__ == '__main__':
371+
main()

0 commit comments

Comments
 (0)