Skip to content

Commit e053d7c

Browse files
megluyagaofacebook-github-bot
authored andcommitted
Adding join_mesh in pytorch3d.structures.meshes
Summary: Adding a function in pytorch3d.structures.meshes to join multiple meshes into a Meshes object representing a single mesh. The function currently ignores all textures. Reviewed By: nikhilaravi Differential Revision: D21876908 fbshipit-source-id: 448602857e9d3d3f774d18bb4e93076f78329823
1 parent 4b78e95 commit e053d7c

File tree

6 files changed

+132
-3
lines changed

6 files changed

+132
-3
lines changed

pytorch3d/structures/meshes.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3-
from typing import List
3+
from typing import List, Union
44

55
import torch
66

@@ -1539,3 +1539,28 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
15391539

15401540
tex = Textures(**kwargs)
15411541
return Meshes(verts=verts, faces=faces, textures=tex)
1542+
1543+
1544+
def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes:
1545+
"""
1546+
Joins a batch of meshes in the form of a Meshes object or a list of Meshes
1547+
objects as a single mesh. If the input is a list, the Meshes objects in the list
1548+
must all be on the same device. This version ignores all textures in the input mehses.
1549+
1550+
Args:
1551+
meshes: Meshes object that contains a batch of meshes or a list of Meshes objects
1552+
1553+
Returns:
1554+
new Meshes object containing a single mesh
1555+
"""
1556+
if isinstance(meshes, List):
1557+
meshes = join_meshes_as_batch(meshes, include_textures=False)
1558+
1559+
if len(meshes) == 1:
1560+
return meshes
1561+
verts = meshes.verts_packed() # (sum(V_n), 3)
1562+
# Offset automatically done by faces_packed
1563+
faces = meshes.faces_packed() # (sum(F_n), 3)
1564+
1565+
mesh = Meshes(verts=verts.unsqueeze(0), faces=faces.unsqueeze(0))
1566+
return mesh
25.9 KB
Loading
21.3 KB
Loading
21.2 KB
Loading

tests/test_obj_io.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
_bilinear_interpolation_vectorized,
1515
)
1616
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
17+
from pytorch3d.structures.meshes import join_mesh
1718
from pytorch3d.utils import torus
1819

1920

@@ -648,6 +649,42 @@ def check_item(x, y):
648649
self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
649650
self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])
650651

652+
def test_join_meshes(self):
653+
"""
654+
Test that join_mesh joins single meshes and the corresponding values are
655+
consistent with the single meshes.
656+
"""
657+
658+
# Load cow mesh.
659+
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
660+
cow_obj = DATA_DIR / "cow_mesh/cow.obj"
661+
662+
cow_mesh = load_objs_as_meshes([cow_obj])
663+
cow_verts, cow_faces = cow_mesh.get_mesh_verts_faces(0)
664+
# Join a batch of three single meshes and check that the values are consistent
665+
# with the individual meshes.
666+
cow_mesh3 = join_mesh([cow_mesh, cow_mesh, cow_mesh])
667+
668+
def check_item(x, y, offset):
669+
self.assertClose(torch.cat([x, x + offset, x + 2 * offset], dim=1), y)
670+
671+
check_item(cow_mesh.verts_padded(), cow_mesh3.verts_padded(), 0)
672+
check_item(cow_mesh.faces_padded(), cow_mesh3.faces_padded(), cow_mesh._V)
673+
674+
# Test the joining of meshes of different sizes.
675+
teapot_obj = DATA_DIR / "teapot.obj"
676+
teapot_mesh = load_objs_as_meshes([teapot_obj])
677+
teapot_verts, teapot_faces = teapot_mesh.get_mesh_verts_faces(0)
678+
679+
mix_mesh = join_mesh([cow_mesh, teapot_mesh])
680+
mix_verts, mix_faces = mix_mesh.get_mesh_verts_faces(0)
681+
self.assertEqual(len(mix_mesh), 1)
682+
683+
self.assertClose(mix_verts[: cow_mesh._V], cow_verts)
684+
self.assertClose(mix_faces[: cow_mesh._F], cow_faces)
685+
self.assertClose(mix_verts[cow_mesh._V :], teapot_verts)
686+
self.assertClose(mix_faces[cow_mesh._F :], teapot_faces + cow_mesh._V)
687+
651688
@staticmethod
652689
def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
653690
return lambda: save_obj(StringIO(), verts, faces, decimal_places)

tests/test_render_meshes.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
TexturedSoftPhongShader,
2727
)
2828
from pytorch3d.renderer.mesh.texturing import Textures
29-
from pytorch3d.structures.meshes import Meshes
29+
from pytorch3d.structures.meshes import Meshes, join_mesh
3030
from pytorch3d.utils.ico_sphere import ico_sphere
3131

3232

@@ -176,7 +176,7 @@ def test_simple_sphere_batched(self):
176176
# Init renderer
177177
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
178178
shaders = {
179-
"phong": HardGouraudShader,
179+
"phong": HardPhongShader,
180180
"gouraud": HardGouraudShader,
181181
"flat": HardFlatShader,
182182
}
@@ -369,3 +369,70 @@ def test_texture_map(self):
369369
)
370370

371371
self.assertClose(rgb, image_ref, atol=0.05)
372+
373+
def test_joined_spheres(self):
374+
"""
375+
Test a list of Meshes can be joined as a single mesh and
376+
the single mesh is rendered correctly with Phong, Gouraud
377+
and Flat Shaders.
378+
"""
379+
device = torch.device("cuda:0")
380+
381+
# Init mesh with vertex textures.
382+
# Initialize a list containing two ico spheres of different sizes.
383+
sphere_list = [ico_sphere(3, device), ico_sphere(4, device)]
384+
# [(42 verts, 80 faces), (162 verts, 320 faces)]
385+
# The scale the vertices need to be set at to resize the spheres
386+
scales = [0.25, 1]
387+
# The distance the spheres ought to be offset horizontally to prevent overlap.
388+
offsets = [1.2, -0.3]
389+
# Initialize a list containing the adjusted sphere meshes.
390+
sphere_mesh_list = []
391+
for i in range(len(sphere_list)):
392+
verts = sphere_list[i].verts_padded() * scales[i]
393+
verts[0, :, 0] += offsets[i]
394+
sphere_mesh_list.append(
395+
Meshes(verts=verts, faces=sphere_list[i].faces_padded())
396+
)
397+
joined_sphere_mesh = join_mesh(sphere_mesh_list)
398+
joined_sphere_mesh.textures = Textures(
399+
verts_rgb=torch.ones_like(joined_sphere_mesh.verts_padded())
400+
)
401+
402+
# Init rasterizer settings
403+
R, T = look_at_view_transform(2.7, 0.0, 0.0)
404+
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
405+
raster_settings = RasterizationSettings(
406+
image_size=512, blur_radius=0.0, faces_per_pixel=1
407+
)
408+
409+
# Init shader settings
410+
materials = Materials(device=device)
411+
lights = PointLights(device=device)
412+
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
413+
blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
414+
415+
# Init renderer
416+
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
417+
shaders = {
418+
"phong": HardPhongShader,
419+
"gouraud": HardGouraudShader,
420+
"flat": HardFlatShader,
421+
}
422+
for (name, shader_init) in shaders.items():
423+
shader = shader_init(
424+
lights=lights,
425+
cameras=cameras,
426+
materials=materials,
427+
blend_params=blend_params,
428+
)
429+
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
430+
image = renderer(joined_sphere_mesh)
431+
rgb = image[..., :3].squeeze().cpu()
432+
if DEBUG:
433+
file_name = "DEBUG_joined_spheres_%s.png" % name
434+
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
435+
DATA_DIR / file_name
436+
)
437+
image_ref = load_rgb_image("test_joined_spheres_%s.png" % name, DATA_DIR)
438+
self.assertClose(rgb, image_ref, atol=0.05)

0 commit comments

Comments
 (0)