@@ -45,7 +45,7 @@ def ravel_index(idx, dims) -> torch.Tensor:
45
45
46
46
47
47
@torch .no_grad ()
48
- def cubify (voxels , thresh , device = None ) -> Meshes :
48
+ def cubify (voxels , thresh , device = None , align : str = "topleft" ) -> Meshes :
49
49
r"""
50
50
Converts a voxel to a mesh by replacing each occupied voxel with a cube
51
51
consisting of 12 faces and 8 vertices. Shared vertices are merged, and
@@ -54,13 +54,38 @@ def cubify(voxels, thresh, device=None) -> Meshes:
54
54
voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
55
55
thresh: A scalar threshold. If a voxel occupancy is larger than
56
56
thresh, the voxel is considered occupied.
57
+ device: The device of the output meshes
58
+ align: Defines the alignment of the mesh vertices and the grid locations.
59
+ Has to be one of {"topleft", "corner", "center"}. See below for explanation.
60
+ Default is "topleft".
57
61
Returns:
58
62
meshes: A Meshes object of the corresponding meshes.
63
+
64
+
65
+ The alignment between the vertices of the cubified mesh and the voxel locations (or pixels)
66
+ is defined by the choice of `align`. We support three modes, as shown below for a 2x2 grid:
67
+
68
+ X---X---- X-------X ---------
69
+ | | | | | | | X | X |
70
+ X---X---- --------- ---------
71
+ | | | | | | | X | X |
72
+ --------- X-------X ---------
73
+
74
+ topleft corner center
75
+
76
+ In the figure, X denote the grid locations and the squares represent the added cuboids.
77
+ When `align="topleft"`, then the top left corner of each cuboid corresponds to the
78
+ pixel coordinate of the input grid.
79
+ When `align="corner"`, then the corners of the output mesh span the whole grid.
80
+ When `align="center"`, then the grid locations form the center of the cuboids.
59
81
"""
60
82
61
83
if device is None :
62
84
device = voxels .device
63
85
86
+ if align not in ["topleft" , "corner" , "center" ]:
87
+ raise ValueError ("Align mode must be one of (topleft, corner, center)." )
88
+
64
89
if len (voxels ) == 0 :
65
90
return Meshes (verts = [], faces = [])
66
91
@@ -146,7 +171,7 @@ def cubify(voxels, thresh, device=None) -> Meshes:
146
171
147
172
# boolean to linear index
148
173
# NF x 2
149
- linind = torch .nonzero (faces_idx )
174
+ linind = torch .nonzero (faces_idx , as_tuple = False )
150
175
# NF x 4
151
176
nyxz = unravel_index (linind [:, 0 ], (N , H , W , D ))
152
177
@@ -170,11 +195,19 @@ def cubify(voxels, thresh, device=None) -> Meshes:
170
195
torch .arange (H + 1 ), torch .arange (W + 1 ), torch .arange (D + 1 )
171
196
)
172
197
y = y .to (device = device , dtype = torch .float32 )
173
- y = y * 2.0 / (H - 1.0 ) - 1.0
174
198
x = x .to (device = device , dtype = torch .float32 )
175
- x = x * 2.0 / (W - 1.0 ) - 1.0
176
199
z = z .to (device = device , dtype = torch .float32 )
177
- z = z * 2.0 / (D - 1.0 ) - 1.0
200
+
201
+ if align == "center" :
202
+ x = x - 0.5
203
+ y = y - 0.5
204
+ z = z - 0.5
205
+
206
+ margin = 0.0 if align == "corner" else 1.0
207
+ y = y * 2.0 / (H - margin ) - 1.0
208
+ x = x * 2.0 / (W - margin ) - 1.0
209
+ z = z * 2.0 / (D - margin ) - 1.0
210
+
178
211
# ((H+1)(W+1)(D+1)) x 3
179
212
grid_verts = torch .stack ((x , y , z ), dim = 3 ).view (- 1 , 3 )
180
213
@@ -196,7 +229,7 @@ def cubify(voxels, thresh, device=None) -> Meshes:
196
229
idlenum = idleverts .cumsum (1 )
197
230
198
231
verts_list = [
199
- grid_verts .index_select (0 , (idleverts [n ] == 0 ).nonzero ()[:, 0 ])
232
+ grid_verts .index_select (0 , (idleverts [n ] == 0 ).nonzero (as_tuple = False )[:, 0 ])
200
233
for n in range (N )
201
234
]
202
235
faces_list = [nface - idlenum [n ][nface ] for n , nface in enumerate (faces_list )]
0 commit comments