|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | + |
| 3 | + |
| 4 | +from itertools import product |
| 5 | + |
| 6 | +import torch |
| 7 | +from fvcore.common.benchmark import benchmark |
| 8 | +from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform |
| 9 | +from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer |
| 10 | +from pytorch3d.utils.ico_sphere import ico_sphere |
| 11 | + |
| 12 | + |
| 13 | +def rasterize_transform_with_init(num_meshes: int, ico_level: int = 5, device="cuda"): |
| 14 | + # Init meshes |
| 15 | + sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes) |
| 16 | + # Init transform |
| 17 | + R, T = look_at_view_transform(1.0, 0.0, 0.0) |
| 18 | + cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) |
| 19 | + # Init rasterizer |
| 20 | + rasterizer = MeshRasterizer(cameras=cameras) |
| 21 | + |
| 22 | + torch.cuda.synchronize() |
| 23 | + |
| 24 | + def raster_fn(): |
| 25 | + rasterizer.transform(sphere_meshes) |
| 26 | + torch.cuda.synchronize() |
| 27 | + |
| 28 | + return raster_fn |
| 29 | + |
| 30 | + |
| 31 | +def bm_mesh_rasterizer_transform() -> None: |
| 32 | + if torch.cuda.is_available(): |
| 33 | + kwargs_list = [] |
| 34 | + num_meshes = [1, 8] |
| 35 | + ico_level = [0, 1, 3, 4] |
| 36 | + test_cases = product(num_meshes, ico_level) |
| 37 | + for case in test_cases: |
| 38 | + n, ic = case |
| 39 | + kwargs_list.append({"num_meshes": n, "ico_level": ic}) |
| 40 | + benchmark( |
| 41 | + rasterize_transform_with_init, |
| 42 | + "MESH_RASTERIZER", |
| 43 | + kwargs_list, |
| 44 | + warmup_iters=1, |
| 45 | + ) |
0 commit comments