Skip to content

Commit ce3da64

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Simplify transforms in point rasterizer
Summary: Update the transform step in the pointcloud rasterizer to use the `update_padded` method on `Pointclouds`. There was an inefficient step using `offset_points` which went via the packed represntation (and required unecessary additional memory). I think this was before the `update_padded` method was added to `Pointclouds`. Reviewed By: gkioxari Differential Revision: D22329166 fbshipit-source-id: 76db8a19654fb2f7807635d4f1c1729debdf3320
1 parent 876bdff commit ce3da64

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

pytorch3d/renderer/points/rasterizer.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,22 +86,18 @@ def transform(self, point_clouds, **kwargs) -> torch.Tensor:
8686
raise ValueError(msg)
8787

8888
pts_world = point_clouds.points_padded()
89-
pts_world_packed = point_clouds.points_packed()
90-
pts_screen = cameras.transform_points(pts_world, **kwargs)
91-
9289
# NOTE: Retaining view space z coordinate for now.
9390
# TODO: Remove this line when the convention for the z coordinate in
9491
# the rasterizer is decided. i.e. retain z in view space or transform
9592
# to a different range.
96-
view_transform = get_world_to_view_transform(R=cameras.R, T=cameras.T)
97-
verts_view = view_transform.transform_points(pts_world)
98-
pts_screen[..., 2] = verts_view[..., 2]
99-
100-
# Offset points of input pointcloud to reuse cached padded/packed calculations.
101-
pad_to_packed_idx = point_clouds.padded_to_packed_idx()
102-
pts_screen_packed = pts_screen.view(-1, 3)[pad_to_packed_idx, :]
103-
pts_packed_offset = pts_screen_packed - pts_world_packed
104-
point_clouds = point_clouds.offset(pts_packed_offset)
93+
pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
94+
pts_world
95+
)
96+
pts_screen = cameras.get_projection_transform(**kwargs).transform_points(
97+
pts_view
98+
)
99+
pts_screen[..., 2] = pts_view[..., 2]
100+
point_clouds = point_clouds.update_padded(pts_screen)
105101
return point_clouds
106102

107103
def forward(self, point_clouds, **kwargs) -> PointFragments:

0 commit comments

Comments
 (0)