Skip to content

Commit fb97ab1

Browse files
bottlerfacebook-github-bot
authored andcommitted
getitem for textures
Summary: Make Meshes.__getitem__ carry texture information to the new mesh. Reviewed By: gkioxari Differential Revision: D20283976 fbshipit-source-id: d9ee0580c11ac5b4384df9d8158a07e6eb8d00fe
1 parent 5a1d714 commit fb97ab1

File tree

3 files changed

+71
-2
lines changed

3 files changed

+71
-2
lines changed

pytorch3d/structures/meshes.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,10 +415,12 @@ def __getitem__(self, index):
415415
else:
416416
raise IndexError(index)
417417

418+
textures = None if self.textures is None else self.textures[index]
419+
418420
if torch.is_tensor(verts) and torch.is_tensor(faces):
419-
return Meshes(verts=[verts], faces=[faces])
421+
return Meshes(verts=[verts], faces=[faces], textures=textures)
420422
elif isinstance(verts, list) and isinstance(faces, list):
421-
return Meshes(verts=verts, faces=faces)
423+
return Meshes(verts=verts, faces=faces, textures=textures)
422424
else:
423425
raise ValueError("(verts, faces) not defined correctly")
424426

pytorch3d/structures/textures.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,14 @@ def __init__(
115115
self._verts_rgb_padded = verts_rgb
116116
self._maps_padded = maps
117117
self._num_faces_per_mesh = None
118+
self._set_num_faces_per_mesh()
118119

120+
def _set_num_faces_per_mesh(self) -> None:
121+
"""
122+
Determines and sets the number of textured faces for each mesh.
123+
"""
119124
if self._faces_uvs_padded is not None:
125+
faces_uvs = self._faces_uvs_padded
120126
self._num_faces_per_mesh = faces_uvs.gt(-1).all(-1).sum(-1).tolist()
121127

122128
def clone(self):
@@ -134,6 +140,18 @@ def to(self, device):
134140
setattr(self, k, v.to(device))
135141
return self
136142

143+
def __getitem__(self, index):
144+
other = Textures()
145+
for key in dir(self):
146+
value = getattr(self, key)
147+
if torch.is_tensor(value):
148+
if isinstance(index, int):
149+
setattr(other, key, value[index][None])
150+
else:
151+
setattr(other, key, value[index])
152+
other._set_num_faces_per_mesh()
153+
return other
154+
137155
def faces_uvs_padded(self) -> torch.Tensor:
138156
return self._faces_uvs_padded
139157

tests/test_texturing.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,55 @@ def test_clone(self):
167167
self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
168168
self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)
169169

170+
def test_getitem(self):
171+
N = 5
172+
V = 20
173+
source = {
174+
"maps": torch.rand(size=(N, 16, 16, 3)),
175+
"faces_uvs": torch.randint(size=(N, 10, 3), low=0, high=V),
176+
"verts_uvs": torch.rand((N, V, 2)),
177+
}
178+
tex = Textures(
179+
maps=source["maps"],
180+
faces_uvs=source["faces_uvs"],
181+
verts_uvs=source["verts_uvs"],
182+
)
183+
184+
verts = torch.rand(size=(N, V, 3))
185+
faces = torch.randint(size=(N, 10, 3), high=V)
186+
meshes = Meshes(verts=verts, faces=faces, textures=tex)
187+
188+
def tryindex(index):
189+
tex2 = tex[index]
190+
meshes2 = meshes[index]
191+
tex_from_meshes = meshes2.textures
192+
for item in source:
193+
basic = source[item][index]
194+
from_texture = getattr(tex2, item + "_padded")()
195+
from_meshes = getattr(tex_from_meshes, item + "_padded")()
196+
if isinstance(index, int):
197+
basic = basic[None]
198+
self.assertClose(basic, from_texture)
199+
self.assertClose(basic, from_meshes)
200+
self.assertEqual(
201+
from_texture.ndim, getattr(tex, item + "_padded")().ndim
202+
)
203+
if item == "faces_uvs":
204+
faces_uvs_list = tex_from_meshes.faces_uvs_list()
205+
self.assertEqual(basic.shape[0], len(faces_uvs_list))
206+
for i, faces_uvs in enumerate(faces_uvs_list):
207+
self.assertClose(faces_uvs, basic[i])
208+
209+
tryindex(2)
210+
tryindex(slice(0, 2, 1))
211+
index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
212+
tryindex(index)
213+
index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
214+
tryindex(index)
215+
index = torch.tensor([1, 2], dtype=torch.int64)
216+
tryindex(index)
217+
tryindex([2, 4])
218+
170219
def test_to(self):
171220
V = 20
172221
tex = Textures(

0 commit comments

Comments
 (0)