Skip to content

Commit d91c1d3

Browse files
patricklabatutfacebook-github-bot
authored andcommitted
Add more complex mesh I/O benchmarks
Summary: Add more complex mesh I/O benchmarks: simple yet non-trivial procedural donut mesh Reviewed By: nikhilaravi Differential Revision: D20390726 fbshipit-source-id: b28b7e3a7f1720823c6bd24faabf688bb0127b7d
1 parent 327868b commit d91c1d3

File tree

3 files changed

+95
-16
lines changed

3 files changed

+95
-16
lines changed

tests/bm_mesh_io.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,29 @@ def bm_save_load() -> None:
3636
simple_kwargs_list,
3737
warmup_iters=1,
3838
)
39+
40+
complex_kwargs_list = [{"N": 8}, {"N": 32}, {"N": 128}]
41+
benchmark(
42+
TestMeshObjIO.bm_load_complex_obj,
43+
"LOAD_COMPLEX_OBJ",
44+
complex_kwargs_list,
45+
warmup_iters=1,
46+
)
47+
benchmark(
48+
TestMeshObjIO.bm_save_complex_obj,
49+
"SAVE_COMPLEX_OBJ",
50+
complex_kwargs_list,
51+
warmup_iters=1,
52+
)
53+
benchmark(
54+
TestMeshPlyIO.bm_load_complex_ply,
55+
"LOAD_COMPLEX_PLY",
56+
complex_kwargs_list,
57+
warmup_iters=1,
58+
)
59+
benchmark(
60+
TestMeshPlyIO.bm_save_complex_ply,
61+
"SAVE_COMPLEX_PLY",
62+
complex_kwargs_list,
63+
warmup_iters=1,
64+
)

tests/test_obj_io.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj
1010
from pytorch3d.structures import Meshes, Textures, join_meshes
11+
from pytorch3d.utils import torus
1112

1213
from common_testing import TestCaseMixin
1314

@@ -601,15 +602,42 @@ def check_item(x, y):
601602
self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
602603
self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])
603604

605+
@staticmethod
606+
def _bm_save_obj(
607+
verts: torch.Tensor, faces: torch.Tensor, decimal_places: int
608+
):
609+
return lambda: save_obj(StringIO(), verts, faces, decimal_places)
610+
611+
@staticmethod
612+
def _bm_load_obj(
613+
verts: torch.Tensor, faces: torch.Tensor, decimal_places: int
614+
):
615+
f = StringIO()
616+
save_obj(f, verts, faces, decimal_places)
617+
s = f.getvalue()
618+
# Recreate stream so it's unaffected by how it was created.
619+
return lambda: load_obj(StringIO(s))
620+
604621
@staticmethod
605622
def bm_save_simple_obj_with_init(V: int, F: int):
606-
verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3)
607-
faces_list = torch.tensor(F * [[1, 2, 3]]).view(-1, 3)
608-
return lambda: save_obj(
609-
StringIO(), verts_list, faces_list, decimal_places=2
610-
)
623+
verts = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3)
624+
faces = torch.tensor(F * [[1, 2, 3]]).view(-1, 3)
625+
return TestMeshObjIO._bm_save_obj(verts, faces, decimal_places=2)
611626

612627
@staticmethod
613628
def bm_load_simple_obj_with_init(V: int, F: int):
614-
obj = "\n".join(["v 0.1 0.2 0.3"] * V + ["f 1 2 3"] * F)
615-
return lambda: load_obj(StringIO(obj))
629+
verts = torch.tensor(V * [[0.1, 0.2, 0.3]]).view(-1, 3)
630+
faces = torch.tensor(F * [[1, 2, 3]]).view(-1, 3)
631+
return TestMeshObjIO._bm_load_obj(verts, faces, decimal_places=2)
632+
633+
@staticmethod
634+
def bm_save_complex_obj(N: int):
635+
meshes = torus(r=0.25, R=1.0, sides=N, rings=2 * N)
636+
[verts], [faces] = meshes.verts_list(), meshes.faces_list()
637+
return TestMeshObjIO._bm_save_obj(verts, faces, decimal_places=5)
638+
639+
@staticmethod
640+
def bm_load_complex_obj(N: int):
641+
meshes = torus(r=0.25, R=1.0, sides=N, rings=2 * N)
642+
[verts], [faces] = meshes.verts_list(), meshes.faces_list()
643+
return TestMeshObjIO._bm_load_obj(verts, faces, decimal_places=5)

tests/test_ply_io.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77

88
from pytorch3d.io.ply_io import _load_ply_raw, load_ply, save_ply
9+
from pytorch3d.utils import torus
910

1011
from common_testing import TestCaseMixin
1112

@@ -407,19 +408,43 @@ def test_bad_ply_syntax(self):
407408
load_ply(StringIO("\n".join(lines2)))
408409

409410
@staticmethod
410-
def bm_save_simple_ply_with_init(V: int, F: int):
411-
verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3)
412-
faces_list = torch.tensor(F * [[0, 1, 2]]).view(-1, 3)
411+
def _bm_save_ply(
412+
verts: torch.Tensor, faces: torch.Tensor, decimal_places: int
413+
):
413414
return lambda: save_ply(
414-
StringIO(), verts_list, faces_list, decimal_places=2
415+
StringIO(), verts, faces, decimal_places=decimal_places
415416
)
416417

418+
@staticmethod
419+
def _bm_load_ply(
420+
verts: torch.Tensor, faces: torch.Tensor, decimal_places: int
421+
):
422+
f = StringIO()
423+
save_ply(f, verts, faces, decimal_places)
424+
s = f.getvalue()
425+
# Recreate stream so it's unaffected by how it was created.
426+
return lambda: load_ply(StringIO(s))
427+
428+
@staticmethod
429+
def bm_save_simple_ply_with_init(V: int, F: int):
430+
verts = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3)
431+
faces = torch.tensor(F * [[0, 1, 2]]).view(-1, 3)
432+
return TestMeshPlyIO._bm_save_ply(verts, faces, decimal_places=2)
433+
417434
@staticmethod
418435
def bm_load_simple_ply_with_init(V: int, F: int):
419436
verts = torch.tensor([[0.1, 0.2, 0.3]]).expand(V, 3)
420437
faces = torch.tensor([[0, 1, 2]], dtype=torch.int64).expand(F, 3)
421-
ply_file = StringIO()
422-
save_ply(ply_file, verts=verts, faces=faces)
423-
ply = ply_file.getvalue()
424-
# Recreate stream so it's unaffected by how it was created.
425-
return lambda: load_ply(StringIO(ply))
438+
return TestMeshPlyIO._bm_load_ply(verts, faces, decimal_places=2)
439+
440+
@staticmethod
441+
def bm_save_complex_ply(N: int):
442+
meshes = torus(r=0.25, R=1.0, sides=N, rings=2 * N)
443+
[verts], [faces] = meshes.verts_list(), meshes.faces_list()
444+
return TestMeshPlyIO._bm_save_ply(verts, faces, decimal_places=5)
445+
446+
@staticmethod
447+
def bm_load_complex_ply(N: int):
448+
meshes = torus(r=0.25, R=1.0, sides=N, rings=2 * N)
449+
[verts], [faces] = meshes.verts_list(), meshes.faces_list()
450+
return TestMeshPlyIO._bm_load_ply(verts, faces, decimal_places=5)

0 commit comments

Comments
 (0)