Skip to content

Commit a61c937

Browse files
gkioxarifacebook-github-bot
authored andcommitted
add align modes for cubify
Summary: Add alignment modes for cubify operation. Reviewed By: nikhilaravi Differential Revision: D21393199 fbshipit-source-id: 7022044e591229a6ed5efc361fd3215e65f43f86
1 parent 8fc28ba commit a61c937

File tree

2 files changed

+290
-232
lines changed

2 files changed

+290
-232
lines changed

pytorch3d/ops/cubify.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def ravel_index(idx, dims) -> torch.Tensor:
4545

4646

4747
@torch.no_grad()
48-
def cubify(voxels, thresh, device=None) -> Meshes:
48+
def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
4949
r"""
5050
Converts a voxel to a mesh by replacing each occupied voxel with a cube
5151
consisting of 12 faces and 8 vertices. Shared vertices are merged, and
@@ -54,13 +54,38 @@ def cubify(voxels, thresh, device=None) -> Meshes:
5454
voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
5555
thresh: A scalar threshold. If a voxel occupancy is larger than
5656
thresh, the voxel is considered occupied.
57+
device: The device of the output meshes
58+
align: Defines the alignment of the mesh vertices and the grid locations.
59+
Has to be one of {"topleft", "corner", "center"}. See below for explanation.
60+
Default is "topleft".
5761
Returns:
5862
meshes: A Meshes object of the corresponding meshes.
63+
64+
65+
The alignment between the vertices of the cubified mesh and the voxel locations (or pixels)
66+
is defined by the choice of `align`. We support three modes, as shown below for a 2x2 grid:
67+
68+
X---X---- X-------X ---------
69+
| | | | | | | X | X |
70+
X---X---- --------- ---------
71+
| | | | | | | X | X |
72+
--------- X-------X ---------
73+
74+
topleft corner center
75+
76+
In the figure, X denote the grid locations and the squares represent the added cuboids.
77+
When `align="topleft"`, then the top left corner of each cuboid corresponds to the
78+
pixel coordinate of the input grid.
79+
When `align="corner"`, then the corners of the output mesh span the whole grid.
80+
When `align="center"`, then the grid locations form the center of the cuboids.
5981
"""
6082

6183
if device is None:
6284
device = voxels.device
6385

86+
if align not in ["topleft", "corner", "center"]:
87+
raise ValueError("Align mode must be one of (topleft, corner, center).")
88+
6489
if len(voxels) == 0:
6590
return Meshes(verts=[], faces=[])
6691

@@ -146,7 +171,7 @@ def cubify(voxels, thresh, device=None) -> Meshes:
146171

147172
# boolean to linear index
148173
# NF x 2
149-
linind = torch.nonzero(faces_idx)
174+
linind = torch.nonzero(faces_idx, as_tuple=False)
150175
# NF x 4
151176
nyxz = unravel_index(linind[:, 0], (N, H, W, D))
152177

@@ -170,11 +195,19 @@ def cubify(voxels, thresh, device=None) -> Meshes:
170195
torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
171196
)
172197
y = y.to(device=device, dtype=torch.float32)
173-
y = y * 2.0 / (H - 1.0) - 1.0
174198
x = x.to(device=device, dtype=torch.float32)
175-
x = x * 2.0 / (W - 1.0) - 1.0
176199
z = z.to(device=device, dtype=torch.float32)
177-
z = z * 2.0 / (D - 1.0) - 1.0
200+
201+
if align == "center":
202+
x = x - 0.5
203+
y = y - 0.5
204+
z = z - 0.5
205+
206+
margin = 0.0 if align == "corner" else 1.0
207+
y = y * 2.0 / (H - margin) - 1.0
208+
x = x * 2.0 / (W - margin) - 1.0
209+
z = z * 2.0 / (D - margin) - 1.0
210+
178211
# ((H+1)(W+1)(D+1)) x 3
179212
grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3)
180213

@@ -196,7 +229,7 @@ def cubify(voxels, thresh, device=None) -> Meshes:
196229
idlenum = idleverts.cumsum(1)
197230

198231
verts_list = [
199-
grid_verts.index_select(0, (idleverts[n] == 0).nonzero()[:, 0])
232+
grid_verts.index_select(0, (idleverts[n] == 0).nonzero(as_tuple=False)[:, 0])
200233
for n in range(N)
201234
]
202235
faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]

0 commit comments

Comments
 (0)