Skip to content

Commit

Permalink
[Enhance] Support tensor-like operations for BaseInstance3DBoxes and …
Browse files Browse the repository at this point in the history
…BasePoints (#2501)
  • Loading branch information
Xiangxu-0103 authored May 10, 2023
1 parent 864ed34 commit 74768e2
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 30 deletions.
2 changes: 1 addition & 1 deletion mmdet3d/datasets/convert_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def convert_annos(info: dict, cam_idx: int) -> dict:
Box3DMode.LIDAR, np.linalg.inv(rect @ lidar2cam0), correct_yaw=True)
# convert gt_bboxes_3d to cam coordinates
gt_bboxes_3d = gt_bboxes_3d.convert_to(
Box3DMode.CAM, rect @ lidar2cami, correct_yaw=True).tensor.numpy()
Box3DMode.CAM, rect @ lidar2cami, correct_yaw=True).numpy()
converted_annos['location'] = gt_bboxes_3d[:, :3]
converted_annos['dimensions'] = gt_bboxes_3d[:, 3:6]
converted_annos['rotation_y'] = gt_bboxes_3d[:, 6]
Expand Down
2 changes: 1 addition & 1 deletion mmdet3d/datasets/transforms/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def _remove_close(self,
if isinstance(points, np.ndarray):
points_numpy = points
elif isinstance(points, BasePoints):
points_numpy = points.tensor.numpy()
points_numpy = points.numpy()
else:
raise NotImplementedError
x_filt = np.abs(points_numpy[:, 0]) < radius
Expand Down
13 changes: 6 additions & 7 deletions mmdet3d/datasets/transforms/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,13 +416,13 @@ def transform(self, input_dict: dict) -> dict:
gt_bboxes_2d = input_dict['gt_bboxes']
# Assume for now 3D & 2D bboxes are the same
sampled_dict = self.db_sampler.sample_all(
gt_bboxes_3d.tensor.numpy(),
gt_bboxes_3d.numpy(),
gt_labels_3d,
gt_bboxes_2d=gt_bboxes_2d,
img=img)
else:
sampled_dict = self.db_sampler.sample_all(
gt_bboxes_3d.tensor.numpy(),
gt_bboxes_3d.numpy(),
gt_labels_3d,
img=None,
ground_plane=ground_plane)
Expand All @@ -435,8 +435,7 @@ def transform(self, input_dict: dict) -> dict:
gt_labels_3d = np.concatenate([gt_labels_3d, sampled_gt_labels],
axis=0)
gt_bboxes_3d = gt_bboxes_3d.new_box(
np.concatenate(
[gt_bboxes_3d.tensor.numpy(), sampled_gt_bboxes_3d]))
np.concatenate([gt_bboxes_3d.numpy(), sampled_gt_bboxes_3d]))

points = self.remove_points_in_boxes(points, sampled_gt_bboxes_3d)
# check the points dimension
Expand Down Expand Up @@ -515,8 +514,8 @@ def transform(self, input_dict: dict) -> dict:
points = input_dict['points']

# TODO: this is inplace operation
numpy_box = gt_bboxes_3d.tensor.numpy()
numpy_points = points.tensor.numpy()
numpy_box = gt_bboxes_3d.numpy()
numpy_points = points.numpy()

noise_per_object_v3_(
numpy_box,
Expand Down Expand Up @@ -1547,7 +1546,7 @@ def transform(self, results: dict) -> dict:
# Extend points with seg and mask fields
map_fields2dim = []
start_dim = original_dim
points_numpy = points.tensor.numpy()
points_numpy = points.numpy()
extra_channel = [points_numpy]
for idx, key in enumerate(results['pts_mask_fields']):
map_fields2dim.append((key, idx + start_dim))
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/evaluation/metrics/kitti_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,8 @@ def convert_valid_bboxes(self, box_dict: dict, info: dict) -> dict:
return dict(
bbox=box_2d_preds[valid_inds, :].numpy(),
pred_box_type_3d=type(box_preds),
box3d_camera=box_preds_camera[valid_inds].tensor.numpy(),
box3d_lidar=box_preds_lidar[valid_inds].tensor.numpy(),
box3d_camera=box_preds_camera[valid_inds].numpy(),
box3d_lidar=box_preds_lidar[valid_inds].numpy(),
scores=scores[valid_inds].numpy(),
label_preds=labels[valid_inds].numpy(),
sample_idx=sample_idx)
Expand Down
8 changes: 4 additions & 4 deletions mmdet3d/evaluation/metrics/waymo_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,8 @@ def merge_multi_view_boxes(self, box_dict_per_frame: List[dict],
# Note: bbox is meaningless in final evaluation, set to 0
merged_box_dict = dict(
bbox=np.zeros([box_preds_lidar.tensor.shape[0], 4]),
box3d_camera=box_preds_camera.tensor.numpy(),
box3d_lidar=box_preds_lidar.tensor.numpy(),
box3d_camera=box_preds_camera.numpy(),
box3d_lidar=box_preds_lidar.numpy(),
scores=scores.numpy(),
label_preds=labels.numpy(),
sample_idx=box_dict['sample_idx'],
Expand Down Expand Up @@ -694,8 +694,8 @@ def convert_valid_bboxes(self, box_dict: dict, info: dict) -> dict:
return dict(
bbox=box_2d_preds[valid_inds, :].numpy(),
pred_box_type_3d=type(box_preds),
box3d_camera=box_preds_camera[valid_inds].tensor.numpy(),
box3d_lidar=box_preds_lidar[valid_inds].tensor.numpy(),
box3d_camera=box_preds_camera[valid_inds].numpy(),
box3d_lidar=box_preds_lidar[valid_inds].numpy(),
scores=scores[valid_inds].numpy(),
label_preds=labels[valid_inds].numpy(),
sample_idx=sample_idx)
Expand Down
46 changes: 45 additions & 1 deletion mmdet3d/structures/bbox_3d/base_box3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class BaseInstance3DBoxes:
boxes.
"""

YAW_AXIS: int = 0

def __init__(
self,
tensor: Union[Tensor, np.ndarray, Sequence[Sequence[float]]],
Expand Down Expand Up @@ -77,6 +79,11 @@ def __init__(
src = self.tensor.new_tensor(origin)
self.tensor[:, :3] += self.tensor[:, 3:6] * (dst - src)

@property
def shape(self) -> torch.Size:
"""torch.Size: Shape of boxes."""
return self.tensor.shape

@property
def volume(self) -> Tensor:
"""Tensor: A vector with volume of each box in shape (N, )."""
Expand Down Expand Up @@ -406,6 +413,10 @@ def cat(cls, boxes_list: Sequence['BaseInstance3DBoxes']
with_yaw=boxes_list[0].with_yaw)
return cat_boxes

def numpy(self) -> np.ndarray:
"""Reload ``numpy`` from self.tensor."""
return self.tensor.numpy()

def to(self, device: Union[str, torch.device], *args,
**kwargs) -> 'BaseInstance3DBoxes':
"""Convert current boxes to a specific device.
Expand All @@ -415,14 +426,36 @@ def to(self, device: Union[str, torch.device], *args,
Returns:
:obj:`BaseInstance3DBoxes`: A new boxes object on the specific
device.
device.
"""
original_type = type(self)
return original_type(
self.tensor.to(device, *args, **kwargs),
box_dim=self.box_dim,
with_yaw=self.with_yaw)

def cpu(self) -> 'BaseInstance3DBoxes':
"""Convert current boxes to cpu device.
Returns:
:obj:`BaseInstance3DBoxes`: A new boxes object on the cpu device.
"""
original_type = type(self)
return original_type(
self.tensor.cpu(), box_dim=self.box_dim, with_yaw=self.with_yaw)

def cuda(self, *args, **kwargs) -> 'BaseInstance3DBoxes':
"""Convert current boxes to cuda device.
Returns:
:obj:`BaseInstance3DBoxes`: A new boxes object on the cuda device.
"""
original_type = type(self)
return original_type(
self.tensor.cuda(*args, **kwargs),
box_dim=self.box_dim,
with_yaw=self.with_yaw)

def clone(self) -> 'BaseInstance3DBoxes':
"""Clone the boxes.
Expand All @@ -434,6 +467,17 @@ def clone(self) -> 'BaseInstance3DBoxes':
return original_type(
self.tensor.clone(), box_dim=self.box_dim, with_yaw=self.with_yaw)

def detach(self) -> 'BaseInstance3DBoxes':
"""Detach the boxes.
Returns:
:obj:`BaseInstance3DBoxes`: Box object with the same properties as
self.
"""
original_type = type(self)
return original_type(
self.tensor.detach(), box_dim=self.box_dim, with_yaw=self.with_yaw)

@property
def device(self) -> torch.device:
"""torch.device: The device of the boxes are on."""
Expand Down
40 changes: 40 additions & 0 deletions mmdet3d/structures/points/base_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,10 @@ def cat(cls, points_list: Sequence['BasePoints']) -> 'BasePoints':
attribute_dims=points_list[0].attribute_dims)
return cat_points

def numpy(self) -> np.ndarray:
"""Reload ``numpy`` from self.tensor."""
return self.tensor.numpy()

def to(self, device: Union[str, torch.device], *args,
**kwargs) -> 'BasePoints':
"""Convert current points to a specific device.
Expand All @@ -433,6 +437,30 @@ def to(self, device: Union[str, torch.device], *args,
points_dim=self.points_dim,
attribute_dims=self.attribute_dims)

def cpu(self) -> 'BasePoints':
"""Convert current points to cpu device.
Returns:
:obj:`BasePoints`: A new points object on the cpu device.
"""
original_type = type(self)
return original_type(
self.tensor.cpu(),
points_dim=self.points_dim,
attribute_dims=self.attribute_dims)

def cuda(self, *args, **kwargs) -> 'BasePoints':
"""Convert current points to cuda device.
Returns:
:obj:`BasePoints`: A new points object on the cuda device.
"""
original_type = type(self)
return original_type(
self.tensor.cuda(*args, **kwargs),
points_dim=self.points_dim,
attribute_dims=self.attribute_dims)

def clone(self) -> 'BasePoints':
"""Clone the points.
Expand All @@ -445,6 +473,18 @@ def clone(self) -> 'BasePoints':
points_dim=self.points_dim,
attribute_dims=self.attribute_dims)

def detach(self) -> 'BasePoints':
"""Detach the points.
Returns:
:obj:`BasePoints`: Point object with the same properties as self.
"""
original_type = type(self)
return original_type(
self.tensor.detach(),
points_dim=self.points_dim,
attribute_dims=self.attribute_dims)

@property
def device(self) -> torch.device:
"""torch.device: The device of the points are on."""
Expand Down
2 changes: 1 addition & 1 deletion projects/PETR/petr/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def transform(self, results):

# TODO: support translation
if not self.reverse_angle:
gt_bboxes_3d = results['gt_bboxes_3d'].tensor.numpy()
gt_bboxes_3d = results['gt_bboxes_3d'].numpy()
gt_bboxes_3d[:, 6] -= 2 * rot_angle
results['gt_bboxes_3d'] = LiDARInstance3DBoxes(
gt_bboxes_3d, box_dim=9)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_structures/test_bbox/test_box3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def test_lidar_boxes3d():
points_np, rot_mat_T_np = boxes.rotate(0.13603681398218053, points_np)
lidar_points = LiDARPoints(points_np)
lidar_points, rot_mat_T_np = boxes.rotate(rot_mat, lidar_points)
points_np = lidar_points.tensor.numpy()
points_np = lidar_points.numpy()

assert np.allclose(points_np, expected_points_np, 1e-3)
assert np.allclose(rot_mat_T_np, expected_rot_mat_T_np, 1e-3)
Expand Down Expand Up @@ -667,8 +667,8 @@ def test_boxes_conversion():
assert torch.allclose(lidar_to_cam_box, camera_boxes.tensor)

# test numpy convert
cam_to_lidar_box = Box3DMode.convert(camera_boxes.tensor.numpy(),
Box3DMode.CAM, Box3DMode.LIDAR,
cam_to_lidar_box = Box3DMode.convert(camera_boxes.numpy(), Box3DMode.CAM,
Box3DMode.LIDAR,
rt_mat.inverse().numpy())
assert np.allclose(cam_to_lidar_box, expected_tensor.numpy())

Expand Down Expand Up @@ -931,7 +931,7 @@ def test_camera_boxes3d():
torch.tensor(-0.13603681398218053), points_np)
camera_points = CameraPoints(points_np, points_dim=4)
camera_points, rot_mat_T_np = boxes.rotate(rot_mat, camera_points)
points_np = camera_points.tensor.numpy()
points_np = camera_points.numpy()
assert np.allclose(points_np, expected_points_np, 1e-3)
assert np.allclose(rot_mat_T_np, expected_rot_mat_T_np, 1e-3)

Expand Down Expand Up @@ -1338,7 +1338,7 @@ def test_depth_boxes3d():
points_np, rot_mat_T_np = boxes.rotate(-0.022998953275003075, points_np)
depth_points = DepthPoints(points_np, points_dim=4)
depth_points, rot_mat_T_np = boxes.rotate(rot_mat, depth_points)
points_np = depth_points.tensor.numpy()
points_np = depth_points.numpy()
expected_rot_mat_T_np = expected_rot_mat_T_np.T
assert torch.allclose(boxes.tensor, expected_tensor, 1e-3)
assert np.allclose(points_np, expected_points_np, 1e-3)
Expand Down
8 changes: 4 additions & 4 deletions tools/dataset_converters/create_gt_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ def create_groundtruth_database(dataset_class_name,
example = dataset.pipeline(data_info)
annos = example['ann_info']
image_idx = example['sample_idx']
points = example['points'].tensor.numpy()
gt_boxes_3d = annos['gt_bboxes_3d'].tensor.numpy()
points = example['points'].numpy()
gt_boxes_3d = annos['gt_bboxes_3d'].numpy()
names = [dataset.metainfo['classes'][i] for i in annos['gt_labels_3d']]
group_dict = dict()
if 'group_ids' in annos:
Expand Down Expand Up @@ -406,8 +406,8 @@ def create_single(self, input_dict):
example = self.pipeline(input_dict)
annos = example['ann_info']
image_idx = example['sample_idx']
points = example['points'].tensor.numpy()
gt_boxes_3d = annos['gt_bboxes_3d'].tensor.numpy()
points = example['points'].numpy()
gt_boxes_3d = annos['gt_bboxes_3d'].numpy()
names = [
self.dataset.metainfo['classes'][i] for i in annos['gt_labels_3d']
]
Expand Down
4 changes: 2 additions & 2 deletions tools/deployment/mmdet3d_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ def postprocess(self, data):
for pts_index, result in enumerate(data):
output.append([])
if 'pts_bbox' in result.keys():
pred_bboxes = result['pts_bbox']['boxes_3d'].tensor.numpy()
pred_bboxes = result['pts_bbox']['boxes_3d'].numpy()
pred_scores = result['pts_bbox']['scores_3d'].numpy()
else:
pred_bboxes = result['boxes_3d'].tensor.numpy()
pred_bboxes = result['boxes_3d'].numpy()
pred_scores = result['scores_3d'].numpy()

index = pred_scores > self.threshold
Expand Down
4 changes: 2 additions & 2 deletions tools/deployment/test_torchserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def main(args):
model_result, _ = inference_detector(model, args.pcd)
# filter the 3d bboxes whose scores > 0.5
if 'pts_bbox' in model_result[0].keys():
pred_bboxes = model_result[0]['pts_bbox']['boxes_3d'].tensor.numpy()
pred_bboxes = model_result[0]['pts_bbox']['boxes_3d'].numpy()
pred_scores = model_result[0]['pts_bbox']['scores_3d'].numpy()
else:
pred_bboxes = model_result[0]['boxes_3d'].tensor.numpy()
pred_bboxes = model_result[0]['boxes_3d'].numpy()
pred_scores = model_result[0]['scores_3d'].numpy()
model_result = pred_bboxes[pred_scores > 0.5]

Expand Down

0 comments on commit 74768e2

Please sign in to comment.