-
Notifications
You must be signed in to change notification settings - Fork 11
/
warp_func.py
67 lines (54 loc) · 2.85 KB
/
warp_func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn.functional as F
import numpy as np
def homo_warping(src_fea, src_proj, ref_proj, depth_values):
    """Warp source-view features into the reference view using per-pixel depths.

    Every reference pixel (x, y) is lifted to homogeneous coordinates, scaled
    by its depth, pushed through ``src_proj @ inverse(ref_proj)`` and the
    source features are sampled bilinearly at the resulting pixel location.
    Samples that fall outside the source image are zero.

    Args:
        src_fea: source features, [B, C, H, W].
        src_proj: source projection matrices, [B, 4, 4].
        ref_proj: reference projection matrices, [B, 4, 4]; inverted here.
        depth_values: one depth per reference pixel, e.g. [B, 1, H, W] or
            [B, H, W] (anything that reshapes to [-1, 1, 1, H*W] and
            broadcasts against the batch).

    Returns:
        Warped source features, [B, C, H, W].
    """
    batch, channels = src_fea.shape[0], src_fea.shape[1]
    height, width = src_fea.shape[2], src_fea.shape[3]
    device = src_fea.device

    # NOTE(review): the whole sampling grid is built under no_grad (same layout
    # as the reference MVSNet implementation), so gradients only reach src_fea.
    # Confirm no caller expects gradients w.r.t. depth_values or the matrices.
    with torch.no_grad():
        proj = torch.matmul(src_proj, torch.inverse(ref_proj))
        rot = proj[:, :3, :3]  # [B, 3, 3]
        trans = proj[:, :3, 3:4]  # [B, 3, 1]

        # Row-major pixel coordinates built by broadcasting; this avoids
        # torch.meshgrid, which warns on recent PyTorch when called without
        # an explicit `indexing` argument.
        rows = torch.arange(0, height, dtype=torch.float32, device=device)
        cols = torch.arange(0, width, dtype=torch.float32, device=device)
        y = rows.view(height, 1).expand(height, width).reshape(height * width)
        x = cols.view(1, width).expand(height, width).reshape(height * width)
        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]
        xyz = xyz.unsqueeze(0).expand(batch, -1, -1)  # [B, 3, H*W]

        rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]
        # reshape (not view) so non-contiguous depth maps are accepted too.
        depth = depth_values.reshape(-1, 1, 1, height * width)
        proj_xyz = rot_xyz.unsqueeze(2) * depth + trans.view(batch, 3, 1, 1)  # [B, 3, 1, H*W]
        proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]  # [B, 2, 1, H*W]

        # Normalise so that -1 / +1 are the centres of the first / last pixel
        # (the align_corners=True convention). max(..., 1) keeps a single-row
        # or single-column input from dividing by zero.
        proj_x_normalized = proj_xy[:, 0, :, :] / (max(width - 1, 1) / 2) - 1
        proj_y_normalized = proj_xy[:, 1, :, :] / (max(height - 1, 1) / 2) - 1
        grid = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)  # [B, 1, H*W, 2]
        grid = grid.view(batch, height, width, 2)

    # align_corners=True must be explicit: the normalisation above assumes it,
    # while the grid_sample default has been False since PyTorch 1.3.
    warped_src_fea = F.grid_sample(src_fea, grid, mode='bilinear',
                                   padding_mode='zeros', align_corners=True)
    return warped_src_fea.view(batch, channels, height, width)
def generate_points_from_depth(depth, proj):
    '''
    Back-project a depth map through the inverse of ``proj``.

    Each pixel (x, y) with depth d becomes ``R @ (x, y, 1) * d + t``, where R
    is the upper-left 3x3 block and t the first three rows of the last column
    of ``inverse(proj)``.

    :param depth: (B, 1, H, W)
    :param proj: (B, 4, 4)
    :return: point_cloud (B, 3, H, W)
    '''
    n_batch = depth.shape[0]
    n_rows, n_cols = depth.shape[2], depth.shape[3]
    n_pixels = n_rows * n_cols
    device = depth.device

    inverse = torch.inverse(proj)
    rotation = inverse[:, :3, :3]  # [B, 3, 3]
    translation = inverse[:, :3, 3:4]  # [B, 3, 1]

    # Pixel coordinates in row-major order, flattened to H*W.
    row_idx = torch.arange(0, n_rows, dtype=torch.float32, device=device)
    col_idx = torch.arange(0, n_cols, dtype=torch.float32, device=device)
    grid_y, grid_x = torch.meshgrid([row_idx, col_idx])
    flat_y = grid_y.contiguous().view(n_pixels)
    flat_x = grid_x.contiguous().view(n_pixels)

    # Homogeneous pixel coordinates, one copy per batch element.
    pixels = torch.stack((flat_x, flat_y, torch.ones_like(flat_x)))  # [3, H*W]
    pixels = pixels.unsqueeze(0).repeat(n_batch, 1, 1)  # [B, 3, H*W]

    rays = torch.matmul(rotation, pixels)  # [B, 3, H*W]
    points = rays * depth.view(n_batch, 1, -1) + translation.view(n_batch, 3, 1)  # [B, 3, H*W]
    return points.view(n_batch, 3, n_rows, n_cols)