-
Notifications
You must be signed in to change notification settings - Fork 25
/
model.py
131 lines (102 loc) · 5.8 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import numpy as np
import torch
import torch.nn as nn
from models.mean_vfe import MeanVFE
from models.spconv_backbone_voxelnext import VoxelResBackBone8xVoxelNeXt
from models.voxelnext_head import VoxelNeXtHead
from utils.image_projection import _proj_voxel_image
from segment_anything import SamPredictor, sam_model_registry
class VoxelNeXt(nn.Module):
    """Container for the three VoxelNeXt stages built from one config object.

    The class only constructs and stores the submodules (``voxelization``,
    ``backbone_3d``, ``dense_head``); it defines no ``forward`` of its own.
    """

    def __init__(self, model_cfg):
        """Build the submodules from ``model_cfg``.

        Args:
            model_cfg: Config object read through ``.get``. Keys consulted:
                ``INPUT_CHANNELS`` (default 5), ``GRID_SIZE`` (default
                ``[1440, 1440, 40]``), ``CLASS_NAMES``, ``KERNEL_SIZE_HEAD``
                (default 1), ``POINT_CLOUD_RANGE``, ``VOXEL_SIZE``,
                ``CLASS_NAMES_EACH_HEAD``, ``SEPARATE_HEAD_CFG`` and
                ``POST_PROCESSING``.
        """
        super().__init__()

        # Geometry is stored as plain tensor attributes (not parameters/buffers).
        default_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
        default_voxel = [0.075, 0.075, 0.2]
        self.point_cloud_range = torch.Tensor(model_cfg.get('POINT_CLOUD_RANGE', default_range))
        self.voxel_size = torch.Tensor(model_cfg.get('VOXEL_SIZE', default_voxel))

        # Backbone settings.
        num_input_channels = model_cfg.get('INPUT_CHANNELS', 5)
        voxel_grid = np.array(model_cfg.get('GRID_SIZE', [1440, 1440, 40]))

        # Detection-head settings, passed positionally in the order the head expects.
        head_args = (
            model_cfg.get('CLASS_NAMES'),
            self.point_cloud_range,
            self.voxel_size,
            model_cfg.get('KERNEL_SIZE_HEAD', 1),
            model_cfg.get('CLASS_NAMES_EACH_HEAD'),
            model_cfg.get('SEPARATE_HEAD_CFG'),
            model_cfg.get('POST_PROCESSING'),
        )

        # Submodules are created in the same order as before so that module
        # registration order is unchanged.
        self.voxelization = MeanVFE()
        self.backbone_3d = VoxelResBackBone8xVoxelNeXt(num_input_channels, voxel_grid)
        self.dense_head = VoxelNeXtHead(*head_args)
class Model(nn.Module):
    """Pick a 3D box for a 2D point prompt by pairing SAM with VoxelNeXt.

    SAM produces a 2D mask for the prompt; VoxelNeXt produces scored 3D boxes
    whose voxel positions are projected into the image, and the box that best
    matches the mask is returned.
    """

    def __init__(self, model_cfg, device="cuda"):
        """Load the SAM and VoxelNeXt checkpoints named in ``model_cfg``.

        Args:
            model_cfg: Config object read through ``.get``. Keys consulted
                here: ``SAM_TYPE``, ``SAM_CHECKPOINT`` and
                ``VOXELNEXT_CHECKPOINT``; the object is also passed to
                ``VoxelNeXt``.
            device: Device both networks are moved to.
        """
        super().__init__()
        sam_type = model_cfg.get('SAM_TYPE', "vit_b")
        sam_checkpoint = model_cfg.get('SAM_CHECKPOINT', "/data/sam_vit_b_01ec64.pth")
        sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint).to(device=device)
        self.sam_predictor = SamPredictor(sam)

        voxelnext_checkpoint = model_cfg.get('VOXELNEXT_CHECKPOINT', "/data/voxelnext_nuscenes_kernel1.pth")
        # torch.load unpickles the file, so only trusted checkpoints belong here.
        # Loading onto the CPU makes this independent of the device the
        # checkpoint was saved from; load_state_dict copies into the model's
        # own (already moved) parameters.
        model_dict = torch.load(voxelnext_checkpoint, map_location="cpu")
        self.voxelnext = VoxelNeXt(model_cfg).to(device=device)
        self.voxelnext.load_state_dict(model_dict)
        # NOTE(review): self.voxelnext is not switched to eval mode here and
        # autograd is not disabled — confirm callers do so for inference.

        # Backbone outputs cached per image_id (see point_embedding).
        self.point_features = {}
        self.device = device

    def image_embedding(self, image):
        """Hand ``image`` to the SAM predictor so later prompts refer to it."""
        self.sam_predictor.set_image(image)

    def point_embedding(self, data_dict, image_id):
        """Run VoxelNeXt on one point cloud, caching backbone output by ``image_id``.

        On the first call for an ``image_id`` the entries ``'voxels'``,
        ``'voxel_num_points'`` and ``'voxel_coords'`` of ``data_dict`` are
        converted to tensors in place, voxelized and passed through the 3D
        backbone; the resulting dict is cached. Later calls with the same
        ``image_id`` ignore ``data_dict`` and reuse the cached dict, so the
        caller's input is neither re-processed nor mutated again.

        Args:
            data_dict: Dict holding the three voxel entries listed above.
            image_id: Hashable cache key.

        Returns:
            Tuple ``(pred_dicts, voxel_coords)`` where ``pred_dicts`` is the
            dense head's output and ``voxel_coords`` are the ``'out_voxels'``
            rows selected by ``pred_dicts[0]['voxel_ids']``, scaled by the
            head's ``feature_map_stride``.
        """
        if image_id not in self.point_features:
            for key in ('voxels', 'voxel_num_points', 'voxel_coords'):
                data_dict[key] = torch.Tensor(data_dict[key]).to(self.device)
            data_dict = self.voxelnext.voxelization(data_dict)

            # Prepend a zero column to the coordinates (batch_size is fixed to 1).
            coords = data_dict['voxel_coords']
            batch_column = torch.zeros((coords.shape[0], 1), device=coords.device, dtype=coords.dtype)
            data_dict['voxel_coords'] = torch.cat([batch_column, coords], dim=1)
            data_dict['batch_size'] = 1

            self.point_features[image_id] = self.voxelnext.backbone_3d(data_dict)

        features = self.point_features[image_id]
        pred_dicts = self.voxelnext.dense_head(features)
        voxel_ids = pred_dicts[0]['voxel_ids'].squeeze(-1)
        voxel_coords = features['out_voxels'][voxel_ids] * self.voxelnext.dense_head.feature_map_stride
        return pred_dicts, voxel_coords

    @staticmethod
    def _voxels_in_mask(points_image, mask):
        """Return a CPU bool tensor flagging projected points that hit True mask pixels.

        ``points_image`` is indexed as ``(2, N)`` with row 0 = x (column) and
        row 1 = y (row); ``mask`` is a 2D CPU bool tensor.
        """
        # .int() truncates toward zero, exactly as the per-point loop used to.
        pixels = points_image.permute(1, 0).int().cpu()
        px, py = pixels[:, 0].long(), pixels[:, 1].long()
        height, width = mask.shape[0], mask.shape[1]
        in_bounds = (px >= 0) & (py >= 0) & (px < width) & (py < height)
        inside = torch.zeros(px.shape[0], dtype=torch.bool)
        # Only in-bounds points are looked up; the rest stay False.
        inside[in_bounds] = mask[py[in_bounds], px[in_bounds]]
        return inside

    def generate_3D_box(self, lidar2img_rt, mask, voxel_coords, pred_dicts, quality_score=0.1):
        """Choose the predicted 3D box that best matches a 2D mask.

        Selection rule:
          1. Discard boxes whose score is not above ``quality_score``.
          2. If any remaining box projects onto a True mask pixel, return the
             highest-scoring one of those.
          3. Otherwise return the remaining box whose projection is closest
             to the centroid of the mask.

        Args:
            lidar2img_rt: Projection argument forwarded to ``_proj_voxel_image``.
            mask: 2D boolean mask (array or tensor) indexed as ``[row, col]``.
            voxel_coords: Voxel coordinates of the predictions.
            pred_dicts: Head output; ``pred_dicts[0]`` must provide
                ``'pred_scores'`` and ``'pred_boxes'``.
            quality_score: Minimum score (exclusive) for a box to be considered.

        Returns:
            The selected row of ``pred_boxes``, or ``None`` when no box clears
            the score threshold or when the fallback is needed but the mask
            has no True pixel.
        """
        scores = pred_dicts[0]['pred_scores']
        boxes = pred_dicts[0]['pred_boxes']

        # Cheap check first: without a confident box nothing can be selected,
        # so the projection below would be wasted work.
        confident = scores > quality_score
        if not confident.any():
            print("no high quality 3D box related.")
            return None

        device = voxel_coords.device
        points_image, _ = _proj_voxel_image(
            voxel_coords,
            lidar2img_rt,
            self.voxelnext.voxel_size.to(device),
            self.voxelnext.point_cloud_range.to(device),
        )

        mask_cpu = torch.as_tensor(mask).to(device="cpu", dtype=torch.bool)
        inside = self._voxels_in_mask(points_image, mask_cpu).to(confident.device)

        candidates = inside & confident
        if candidates.any():
            best = scores[candidates].argmax()
            return boxes[candidates][best]

        # Fallback: nearest confident box to the mask centroid, in (x, y) pixels.
        rows, cols = torch.nonzero(mask_cpu, as_tuple=True)
        if rows.numel() == 0:
            # An empty mask has no centroid; the distances would all be NaN.
            print("mask is empty, no 3D box selected.")
            return None
        mask_center = torch.stack([cols.float().mean(), rows.float().mean()]).to(boxes.device).unsqueeze(1)
        dist = ((points_image - mask_center) ** 2).sum(0)
        nearest = dist[confident].argmin()
        return boxes[confident][nearest]

    def forward(self, image, point_dict, prompt_point, lidar2img_rt, image_id, quality_score=0.1):
        """Segment the prompted object in ``image`` and pick its 3D box.

        Args:
            image: Image passed to the SAM predictor.
            point_dict: Voxel dict passed to ``point_embedding``.
            prompt_point: Prompt coordinates passed to SAM as ``point_coords``;
                every point is labelled as foreground (label 1).
            lidar2img_rt: Projection argument forwarded to ``generate_3D_box``.
            image_id: Cache key for the point-cloud features.
            quality_score: Score threshold forwarded to ``generate_3D_box``.

        Returns:
            Tuple ``(mask, box3d)``: the first mask returned by SAM and the
            selected 3D box (``None`` if no box qualified).
        """
        self.image_embedding(image)
        pred_dicts, voxel_coords = self.point_embedding(point_dict, image_id)

        # One foreground label per prompt point (a single point gives [1]).
        point_labels = None if prompt_point is None else np.ones(len(prompt_point), dtype=int)
        masks, _, _ = self.sam_predictor.predict(point_coords=prompt_point, point_labels=point_labels)
        mask = masks[0]

        box3d = self.generate_3D_box(lidar2img_rt, mask, voxel_coords, pred_dicts, quality_score=quality_score)
        return mask, box3d
if __name__ == '__main__':
    # Smoke test: build the dataset and model from their YAML configs and run
    # the point-cloud branch on the first sample.
    # NOTE(review): cfg_from_yaml_file, cfg and NuScenesDataset are not
    # imported in the visible part of this file — confirm they are in scope.
    cfg_dataset = 'nuscenes_dataset.yaml'
    cfg_model = 'voxelnext.yaml'

    dataset_cfg = cfg_from_yaml_file(cfg_dataset, cfg)
    model_cfg = cfg_from_yaml_file(cfg_model, cfg)

    nuscenes_dataset = NuScenesDataset(dataset_cfg)
    model = Model(model_cfg)

    index = 0
    data_dict = nuscenes_dataset._get_points(index)
    # point_embedding requires an image_id cache key; the sample index serves
    # as one here (the call previously omitted it and raised TypeError).
    model.point_embedding(data_dict, image_id=index)