|
8 | 8 |
|
9 | 9 | from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj
|
10 | 10 | from pytorch3d.structures import Meshes, Textures, join_meshes
|
| 11 | +from pytorch3d.utils import torus |
11 | 12 |
|
12 | 13 | from common_testing import TestCaseMixin
|
13 | 14 |
|
@@ -601,15 +602,42 @@ def check_item(x, y):
|
601 | 602 | self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
|
602 | 603 | self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])
|
603 | 604 |
|
| 605 | + @staticmethod |
| 606 | + def _bm_save_obj( |
| 607 | + verts: torch.Tensor, faces: torch.Tensor, decimal_places: int |
| 608 | + ): |
| 609 | + return lambda: save_obj(StringIO(), verts, faces, decimal_places) |
| 610 | + |
| 611 | + @staticmethod |
| 612 | + def _bm_load_obj( |
| 613 | + verts: torch.Tensor, faces: torch.Tensor, decimal_places: int |
| 614 | + ): |
| 615 | + f = StringIO() |
| 616 | + save_obj(f, verts, faces, decimal_places) |
| 617 | + s = f.getvalue() |
| 618 | + # Recreate stream so it's unaffected by how it was created. |
| 619 | + return lambda: load_obj(StringIO(s)) |
| 620 | + |
604 | 621 | @staticmethod
|
605 | 622 | def bm_save_simple_obj_with_init(V: int, F: int):
|
606 |
| - verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3) |
607 |
| - faces_list = torch.tensor(F * [[1, 2, 3]]).view(-1, 3) |
608 |
| - return lambda: save_obj( |
609 |
| - StringIO(), verts_list, faces_list, decimal_places=2 |
610 |
| - ) |
| 623 | + verts = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3) |
| 624 | + faces = torch.tensor(F * [[1, 2, 3]]).view(-1, 3) |
| 625 | + return TestMeshObjIO._bm_save_obj(verts, faces, decimal_places=2) |
611 | 626 |
|
612 | 627 | @staticmethod
|
613 | 628 | def bm_load_simple_obj_with_init(V: int, F: int):
|
614 |
| - obj = "\n".join(["v 0.1 0.2 0.3"] * V + ["f 1 2 3"] * F) |
615 |
| - return lambda: load_obj(StringIO(obj)) |
| 629 | + verts = torch.tensor(V * [[0.1, 0.2, 0.3]]).view(-1, 3) |
| 630 | + faces = torch.tensor(F * [[1, 2, 3]]).view(-1, 3) |
| 631 | + return TestMeshObjIO._bm_load_obj(verts, faces, decimal_places=2) |
| 632 | + |
| 633 | + @staticmethod |
| 634 | + def bm_save_complex_obj(N: int): |
| 635 | + meshes = torus(r=0.25, R=1.0, sides=N, rings=2 * N) |
| 636 | + [verts], [faces] = meshes.verts_list(), meshes.faces_list() |
| 637 | + return TestMeshObjIO._bm_save_obj(verts, faces, decimal_places=5) |
| 638 | + |
| 639 | + @staticmethod |
| 640 | + def bm_load_complex_obj(N: int): |
| 641 | + meshes = torus(r=0.25, R=1.0, sides=N, rings=2 * N) |
| 642 | + [verts], [faces] = meshes.verts_list(), meshes.faces_list() |
| 643 | + return TestMeshObjIO._bm_load_obj(verts, faces, decimal_places=5) |
0 commit comments