Skip to content

Commit e3819a4

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
update rasterizer transform method
Summary: Update the transform method in the mesh rasterizer class to use the new `update_padded` method on the `Meshes` class to directly update the mesh vertices. Also added a benchmark. Reviewed By: gkioxari Differential Revision: D21700352 fbshipit-source-id: c330e4040c681729eb2cc7bdfd92fb4a51a1a7d6
1 parent 1fb97f9 commit e3819a4

File tree

2 files changed

+53
-11
lines changed

2 files changed

+53
-11
lines changed

pytorch3d/renderer/mesh/rasterizer.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,23 +92,20 @@ def transform(self, meshes_world, **kwargs) -> torch.Tensor:
9292
msg = "Cameras must be specified either at initialization \
9393
or in the forward pass of MeshRasterizer"
9494
raise ValueError(msg)
95-
9695
verts_world = meshes_world.verts_padded()
97-
verts_world_packed = meshes_world.verts_packed()
98-
verts_screen = cameras.transform_points(verts_world, **kwargs)
9996

10097
# NOTE: Retaining view space z coordinate for now.
10198
# TODO: Revisit whether or not to transform z coordinate to [-1, 1] or
10299
# [0, 1] range.
103-
view_transform = get_world_to_view_transform(R=cameras.R, T=cameras.T)
104-
verts_view = view_transform.transform_points(verts_world)
100+
verts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
101+
verts_world
102+
)
103+
verts_screen = cameras.get_projection_transform(**kwargs).transform_points(
104+
verts_view
105+
)
105106
verts_screen[..., 2] = verts_view[..., 2]
106-
107-
# Offset verts of input mesh to reuse cached padded/packed calculations.
108-
pad_to_packed_idx = meshes_world.verts_padded_to_packed_idx()
109-
verts_screen_packed = verts_screen.view(-1, 3)[pad_to_packed_idx, :]
110-
verts_packed_offset = verts_screen_packed - verts_world_packed
111-
return meshes_world.offset_verts(verts_packed_offset)
107+
meshes_screen = meshes_world.update_padded(new_verts_padded=verts_screen)
108+
return meshes_screen
112109

113110
def forward(self, meshes_world, **kwargs) -> Fragments:
114111
"""
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
4+
from itertools import product
5+
6+
import torch
7+
from fvcore.common.benchmark import benchmark
8+
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
9+
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer
10+
from pytorch3d.utils.ico_sphere import ico_sphere
11+
12+
13+
def rasterize_transform_with_init(num_meshes: int, ico_level: int = 5, device="cuda"):
14+
# Init meshes
15+
sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes)
16+
# Init transform
17+
R, T = look_at_view_transform(1.0, 0.0, 0.0)
18+
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
19+
# Init rasterizer
20+
rasterizer = MeshRasterizer(cameras=cameras)
21+
22+
torch.cuda.synchronize()
23+
24+
def raster_fn():
25+
rasterizer.transform(sphere_meshes)
26+
torch.cuda.synchronize()
27+
28+
return raster_fn
29+
30+
31+
def bm_mesh_rasterizer_transform() -> None:
32+
if torch.cuda.is_available():
33+
kwargs_list = []
34+
num_meshes = [1, 8]
35+
ico_level = [0, 1, 3, 4]
36+
test_cases = product(num_meshes, ico_level)
37+
for case in test_cases:
38+
n, ic = case
39+
kwargs_list.append({"num_meshes": n, "ico_level": ic})
40+
benchmark(
41+
rasterize_transform_with_init,
42+
"MESH_RASTERIZER",
43+
kwargs_list,
44+
warmup_iters=1,
45+
)

0 commit comments

Comments
 (0)