Skip to content

Commit 8fe65d5

Browse files
bottlerfacebook-github-bot
authored andcommitted
Single function to load meshes from OBJs. join_meshes.
Summary: Create the textures and the Meshes object from OBJ files in a single call. There is functionality in OBJ files (like normals) which is ignored by this function. Reviewed By: gkioxari Differential Revision: D19691699 fbshipit-source-id: e26442ed80ff231b65b17d6c54c9d41e22b4e4a3
1 parent 23bb279 commit 8fe65d5

File tree

8 files changed

+218
-44
lines changed

8 files changed

+218
-44
lines changed

docs/notes/meshes_io.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=texture_image)
5555
# Initialise the mesh with textures
5656
meshes = Meshes(verts=[verts], faces=[faces.verts_idx], textures=tex)
5757
```
58+
59+
The `load_objs_as_meshes` function provides this procedure.
60+
5861
## PLY
5962

6063
Ply files are flexible in the way they store additional information, pytorch3d

docs/tutorials/render_textured_meshes.ipynb

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
"from skimage.io import imread\n",
8888
"\n",
8989
"# Util function for loading meshes\n",
90-
"from pytorch3d.io import load_obj\n",
90+
"from pytorch3d.io import load_objs_as_meshes\n",
9191
"\n",
9292
"# Data structures and functions for rendering\n",
9393
"from pytorch3d.structures import Meshes, Textures\n",
@@ -232,25 +232,8 @@
232232
"obj_filename = os.path.join(DATA_DIR, \"cow_mesh/cow.obj\")\n",
233233
"\n",
234234
"# Load obj file\n",
235-
"verts, faces, aux = load_obj(obj_filename)\n",
236-
"faces_idx = faces.verts_idx.to(device)\n",
237-
"verts = verts.to(device)\n",
238-
"\n",
239-
"# Get textures from the outputs of the load_obj function\n",
240-
"# the `aux` variable contains the texture maps and vertex uv coordinates. \n",
241-
"# Refer to the `obj_io.load_obj` function for full API reference. \n",
242-
"# Here we only have one texture map for the whole mesh. \n",
243-
"verts_uvs = aux.verts_uvs[None, ...].to(device) # (N, V, 2)\n",
244-
"faces_uvs = faces.textures_idx[None, ...].to(device) # (N, F, 3)\n",
245-
"tex_maps = aux.texture_images\n",
246-
"texture_image = list(tex_maps.values())[0]\n",
247-
"texture_image = texture_image[None, ...].to(device) # (N, H, W, 3)\n",
248-
"\n",
249-
"# Create a textures object\n",
250-
"tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=texture_image)\n",
251-
"\n",
252-
"# Create a meshes object with textures\n",
253-
"mesh = Meshes(verts=[verts], faces=[faces_idx], textures=tex)"
235+
"mesh = load_objs_as_meshes([obj_filename], device=device)\n",
236+
"texture_image=mesh.textures.maps_padded()"
254237
]
255238
},
256239
{

pytorch3d/io/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

33

4-
from .obj_io import load_obj, save_obj
4+
from .obj_io import load_obj, load_objs_as_meshes, save_obj
55
from .ply_io import load_ply, save_ply
66

77
__all__ = [k for k in globals().keys() if not k.startswith("_")]

pytorch3d/io/obj_io.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from fvcore.common.file_io import PathManager
1414
from PIL import Image
1515

16+
from pytorch3d.structures import Meshes, Textures, join_meshes
17+
1618

1719
def _read_image(file_name: str, format=None):
1820
"""
@@ -90,7 +92,7 @@ def _open_file(f):
9092

9193
def load_obj(f_obj, load_textures=True):
9294
"""
93-
Load a mesh and textures from a .obj and .mtl file.
95+
Load a mesh from a .obj file and optionally textures from a .mtl file.
9496
Currently this handles verts, faces, vertex texture uv coordinates, normals,
9597
texture images and material reflectivity values.
9698
@@ -208,6 +210,44 @@ def load_obj(f_obj, load_textures=True):
208210
f_obj.close()
209211

210212

213+
def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
214+
"""
215+
Load meshes from a list of .obj files using the load_obj function, and
216+
return them as a Meshes object. This only works for meshes which have a
217+
single texture image for the whole mesh. See the load_obj function for more
218+
details. material_colors and normals are not stored.
219+
220+
Args:
221+
f: A list of file-like objects (with methods read, readline, tell,
222+
and seek), pathlib paths or strings containing file names.
223+
device: Desired device of returned Meshes. Default:
224+
uses the current device for the default tensor type.
225+
load_textures: Boolean indicating whether material files are loaded
226+
227+
Returns:
228+
New Meshes object.
229+
"""
230+
mesh_list = []
231+
for f_obj in files:
232+
verts, faces, aux = load_obj(f_obj, load_textures=load_textures)
233+
verts = verts.to(device)
234+
tex = None
235+
tex_maps = aux.texture_images
236+
if tex_maps is not None and len(tex_maps) > 0:
237+
verts_uvs = aux.verts_uvs[None, ...].to(device) # (1, V, 2)
238+
faces_uvs = faces.textures_idx[None, ...].to(device) # (1, F, 3)
239+
image = list(tex_maps.values())[0].to(device)[None]
240+
tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=image)
241+
242+
mesh = Meshes(
243+
verts=[verts], faces=[faces.verts_idx.to(device)], textures=tex
244+
)
245+
mesh_list.append(mesh)
246+
if len(mesh_list) == 1:
247+
return mesh_list[0]
248+
return join_meshes(mesh_list)
249+
250+
211251
def _parse_face(
212252
line,
213253
material_idx,

pytorch3d/structures/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3-
from .meshes import Meshes
3+
from .meshes import Meshes, join_meshes
44
from .textures import Textures
55
from .utils import (
66
list_to_packed,

pytorch3d/structures/meshes.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
33

4+
from typing import List
45
import torch
56

67
from pytorch3d import _C
@@ -1365,3 +1366,77 @@ def extend(self, N: int):
13651366
if self.textures is not None:
13661367
tex = self.textures.extend(N)
13671368
return Meshes(verts=new_verts_list, faces=new_faces_list, textures=tex)
1369+
1370+
1371+
def join_meshes(meshes: List[Meshes], include_textures: bool = True):
1372+
"""
1373+
Merge multiple Meshes objects, i.e. concatenate the meshes objects. They
1374+
must all be on the same device. If include_textures is true, they must all
1375+
be compatible, either all or none having textures, and all the Textures
1376+
objects having the same members. If include_textures is False, textures are
1377+
ignored.
1378+
1379+
Args:
1380+
meshes: list of meshes.
1381+
include_textures: (bool) whether to try to join the textures.
1382+
1383+
Returns:
1384+
new Meshes object containing all the meshes from all the inputs.
1385+
"""
1386+
if isinstance(meshes, Meshes):
1387+
# Meshes objects can be iterated and produce single Meshes. We avoid
1388+
# letting join_meshes(mesh1, mesh2) silently do the wrong thing.
1389+
raise ValueError("Wrong first argument to join_meshes.")
1390+
verts = [v for mesh in meshes for v in mesh.verts_list()]
1391+
faces = [f for mesh in meshes for f in mesh.faces_list()]
1392+
if len(meshes) == 0 or not include_textures:
1393+
return Meshes(verts=verts, faces=faces)
1394+
1395+
if meshes[0].textures is None:
1396+
if any(mesh.textures is not None for mesh in meshes):
1397+
raise ValueError("Inconsistent textures in join_meshes.")
1398+
return Meshes(verts=verts, faces=faces)
1399+
1400+
if any(mesh.textures is None for mesh in meshes):
1401+
raise ValueError("Inconsistent textures in join_meshes.")
1402+
1403+
# Now we know there are multiple meshes and they have textures to merge.
1404+
first = meshes[0].textures
1405+
kwargs = {}
1406+
if first.maps_padded() is not None:
1407+
if any(mesh.textures.maps_padded() is None for mesh in meshes):
1408+
raise ValueError("Inconsistent maps_padded in join_meshes.")
1409+
maps = [m for mesh in meshes for m in mesh.textures.maps_padded()]
1410+
kwargs["maps"] = maps
1411+
elif any(mesh.textures.maps_padded() is not None for mesh in meshes):
1412+
raise ValueError("Inconsistent maps_padded in join_meshes.")
1413+
1414+
if first.verts_uvs_padded() is not None:
1415+
if any(mesh.textures.verts_uvs_padded() is None for mesh in meshes):
1416+
raise ValueError("Inconsistent verts_uvs_padded in join_meshes.")
1417+
uvs = [uv for mesh in meshes for uv in mesh.textures.verts_uvs_list()]
1418+
V = max(uv.shape[0] for uv in uvs)
1419+
kwargs["verts_uvs"] = struct_utils.list_to_padded(uvs, (V, 2), -1)
1420+
elif any(mesh.textures.verts_uvs_padded() is not None for mesh in meshes):
1421+
raise ValueError("Inconsistent verts_uvs_padded in join_meshes.")
1422+
1423+
if first.faces_uvs_padded() is not None:
1424+
if any(mesh.textures.faces_uvs_padded() is None for mesh in meshes):
1425+
raise ValueError("Inconsistent faces_uvs_padded in join_meshes.")
1426+
uvs = [uv for mesh in meshes for uv in mesh.textures.faces_uvs_list()]
1427+
F = max(uv.shape[0] for uv in uvs)
1428+
kwargs["faces_uvs"] = struct_utils.list_to_padded(uvs, (F, 3), -1)
1429+
elif any(mesh.textures.faces_uvs_padded() is not None for mesh in meshes):
1430+
raise ValueError("Inconsistent faces_uvs_padded in join_meshes.")
1431+
1432+
if first.verts_rgb_padded() is not None:
1433+
if any(mesh.textures.verts_rgb_padded() is None for mesh in meshes):
1434+
raise ValueError("Inconsistent verts_rgb_padded in join_meshes.")
1435+
rgb = [i for mesh in meshes for i in mesh.textures.verts_rgb_list()]
1436+
V = max(i.shape[0] for i in rgb)
1437+
kwargs["verts_rgb"] = struct_utils.list_to_padded(rgb, (V, 3))
1438+
elif any(mesh.textures.verts_rgb_padded() is not None for mesh in meshes):
1439+
raise ValueError("Inconsistent verts_rgb_padded in join_meshes.")
1440+
1441+
tex = Textures(**kwargs)
1442+
return Meshes(verts=verts, faces=faces, textures=tex)

tests/test_obj_io.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77
from pathlib import Path
88
import torch
99

10-
from pytorch3d.io import load_obj, save_obj
10+
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj
11+
from pytorch3d.structures import Meshes, Textures, join_meshes
1112

13+
from common_testing import TestCaseMixin
1214

13-
class TestMeshObjIO(unittest.TestCase):
15+
16+
class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
1417
def test_load_obj_simple(self):
1518
obj_file = "\n".join(
1619
[
@@ -517,6 +520,88 @@ def test_load_obj_missing_mtl_noload(self):
517520
self.assertTrue(aux.material_colors is None)
518521
self.assertTrue(aux.texture_images is None)
519522

523+
def test_join_meshes(self):
524+
"""
525+
Test that join_meshes and load_objs_as_meshes are consistent with single
526+
meshes.
527+
"""
528+
529+
def check_triple(mesh, mesh3):
530+
"""
531+
Verify that mesh3 is three copies of mesh.
532+
"""
533+
534+
def check_item(x, y):
535+
self.assertEqual(x is None, y is None)
536+
if x is not None:
537+
self.assertClose(torch.cat([x, x, x]), y)
538+
539+
check_item(mesh.verts_padded(), mesh3.verts_padded())
540+
check_item(mesh.faces_padded(), mesh3.faces_padded())
541+
if mesh.textures is not None:
542+
check_item(
543+
mesh.textures.maps_padded(), mesh3.textures.maps_padded()
544+
)
545+
check_item(
546+
mesh.textures.faces_uvs_padded(),
547+
mesh3.textures.faces_uvs_padded(),
548+
)
549+
check_item(
550+
mesh.textures.verts_uvs_padded(),
551+
mesh3.textures.verts_uvs_padded(),
552+
)
553+
check_item(
554+
mesh.textures.verts_rgb_padded(),
555+
mesh3.textures.verts_rgb_padded(),
556+
)
557+
558+
DATA_DIR = (
559+
Path(__file__).resolve().parent.parent / "docs/tutorials/data"
560+
)
561+
obj_filename = DATA_DIR / "cow_mesh/cow.obj"
562+
563+
mesh = load_objs_as_meshes([obj_filename])
564+
mesh3 = load_objs_as_meshes([obj_filename, obj_filename, obj_filename])
565+
check_triple(mesh, mesh3)
566+
self.assertTupleEqual(
567+
mesh.textures.maps_padded().shape, (1, 1024, 1024, 3)
568+
)
569+
570+
mesh_notex = load_objs_as_meshes([obj_filename], load_textures=False)
571+
mesh3_notex = load_objs_as_meshes(
572+
[obj_filename, obj_filename, obj_filename], load_textures=False
573+
)
574+
check_triple(mesh_notex, mesh3_notex)
575+
self.assertIsNone(mesh_notex.textures)
576+
577+
verts = torch.randn((4, 3), dtype=torch.float32)
578+
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
579+
vert_tex = torch.tensor(
580+
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
581+
)
582+
tex = Textures(verts_rgb=vert_tex[None, :])
583+
mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=tex)
584+
mesh_rgb3 = join_meshes([mesh_rgb, mesh_rgb, mesh_rgb])
585+
check_triple(mesh_rgb, mesh_rgb3)
586+
587+
teapot_obj = DATA_DIR / "teapot.obj"
588+
mesh_teapot = load_objs_as_meshes([teapot_obj])
589+
teapot_verts, teapot_faces = mesh_teapot.get_mesh_verts_faces(0)
590+
mix_mesh = load_objs_as_meshes(
591+
[obj_filename, teapot_obj], load_textures=False
592+
)
593+
self.assertEqual(len(mix_mesh), 2)
594+
self.assertClose(mix_mesh.verts_list()[0], mesh.verts_list()[0])
595+
self.assertClose(mix_mesh.faces_list()[0], mesh.faces_list()[0])
596+
self.assertClose(mix_mesh.verts_list()[1], teapot_verts)
597+
self.assertClose(mix_mesh.faces_list()[1], teapot_faces)
598+
599+
cow3_tea = join_meshes([mesh3, mesh_teapot], include_textures=False)
600+
self.assertEqual(len(cow3_tea), 4)
601+
check_triple(mesh_notex, cow3_tea[:3])
602+
self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
603+
self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])
604+
520605
@staticmethod
521606
def save_obj_with_init(V: int, F: int):
522607
verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3)

tests/test_rendering_meshes.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
from PIL import Image
1313

14-
from pytorch3d.io import load_obj
14+
from pytorch3d.io import load_objs_as_meshes
1515
from pytorch3d.renderer.cameras import (
1616
OpenGLPerspectiveCameras,
1717
look_at_view_transform,
@@ -274,21 +274,7 @@ def test_texture_map(self):
274274
obj_filename = DATA_DIR / "cow_mesh/cow.obj"
275275

276276
# Load mesh + texture
277-
verts, faces, aux = load_obj(obj_filename)
278-
faces_idx = faces.verts_idx.to(device)
279-
verts = verts.to(device)
280-
texture_uvs = aux.verts_uvs
281-
materials = aux.material_colors
282-
tex_maps = aux.texture_images
283-
284-
# tex_maps is a dictionary of material names as keys and texture images
285-
# as values. Only need the images for this example.
286-
textures = Textures(
287-
maps=list(tex_maps.values()),
288-
faces_uvs=faces.textures_idx.to(torch.int64).to(device)[None, :],
289-
verts_uvs=texture_uvs.to(torch.float32).to(device)[None, :],
290-
)
291-
mesh = Meshes(verts=[verts], faces=[faces_idx], textures=textures)
277+
mesh = load_objs_as_meshes([obj_filename], device=device)
292278

293279
# Init rasterizer settings
294280
R, T = look_at_view_transform(2.7, 10, 20)
@@ -333,9 +319,11 @@ def test_texture_map(self):
333319
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
334320

335321
# Check grad exists
336-
verts = verts.clone()
322+
[verts] = mesh.verts_list()
337323
verts.requires_grad = True
338-
mesh = Meshes(verts=[verts], faces=[faces_idx], textures=textures)
339-
images = renderer(mesh)
324+
mesh2 = Meshes(
325+
verts=[verts], faces=mesh.faces_list(), textures=mesh.textures
326+
)
327+
images = renderer(mesh2)
340328
images[0, ...].sum().backward()
341329
self.assertIsNotNone(verts.grad)

0 commit comments

Comments
 (0)