Skip to content

Commit e6bc960

Browse files
davnov134 and facebook-github-bot
authored and committed
Raysampling
Summary: Implements 3 basic raysamplers. Reviewed By: nikhilaravi Differential Revision: D24110643 fbshipit-source-id: eb67d0e56773c7871ebdcb23e7e520302dc1b3c9
1 parent 1f9cf91 commit e6bc960

File tree

6 files changed

+880
-1
lines changed

6 files changed

+880
-1
lines changed

pytorch3d/renderer/__init__.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,16 @@
2020
look_at_rotation,
2121
look_at_view_transform,
2222
)
23-
from .implicit import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
23+
from .implicit import (
24+
AbsorptionOnlyRaymarcher,
25+
EmissionAbsorptionRaymarcher,
26+
GridRaysampler,
27+
MonteCarloRaysampler,
28+
NDCGridRaysampler,
29+
RayBundle,
30+
ray_bundle_to_ray_points,
31+
ray_bundle_variables_to_ray_points,
32+
)
2433
from .lighting import DirectionalLights, PointLights, diffuse, specular
2534
from .materials import Materials
2635
from .mesh import (
+6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

33
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
4+
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
5+
from .utils import (
6+
RayBundle,
7+
ray_bundle_to_ray_points,
8+
ray_bundle_variables_to_ray_points,
9+
)
410

511

612
__all__ = [k for k in globals().keys() if not k.startswith("_")]
+320
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
import torch
3+
4+
from ..cameras import CamerasBase
5+
from .utils import RayBundle
6+
7+
8+
"""
9+
This file defines three raysampling techniques:
10+
- GridRaysampler which can be used to sample rays from pixels of an image grid
11+
- NDCGridRaysampler which can be used to sample rays from pixels of an image grid,
12+
which follows the pytorch3d convention for image grid coordinates
13+
- MonteCarloRaysampler which randomly selects image pixels and emits rays from them
14+
"""
15+
16+
17+
class GridRaysampler(torch.nn.Module):
    """
    Emits one ray per node of a regular rectangular image grid (for every
    camera of a batch) and places a fixed number of points on each ray.
    The points of a ray have uniformly-spaced z-coordinates that run from a
    predefined minimum depth to a predefined maximum depth.

    Conceptually, the raysampler starts from the following 3D coordinate grid:
    ```
           / min_x, min_y, max_depth -------------- / max_x, min_y, max_depth
          /                                        /|
         /                                        / |     ^
        / min_depth                    min_depth /  |     |
        min_x ----------------------------- max_x   |     | image
        min_y                               min_y   |     | height
        |                                       |   |     |
        |                                       |   |     v
        |                                       |   |
        |                                       |   / max_x, max_y,     ^
        |                                       |  /  max_depth        /
        min_x                               max_x /                   / n_pts_per_ray
        max_y ----------------------------- max_y/ min_depth         v
                  < --- image_width --- >
    ```

    Ray points are obtained by taking each 3D grid point `[x, y, depth]` and
    unprojecting it with `cameras.unproject_points([x, y, depth])`, where
    `cameras` is supplied to the `forward` function.

    This implementation is generic: it places no constraint on the image grid
    coordinate convention. `NDCGridRaysampler` is the special case of
    `GridRaysampler` that follows the PyTorch3D coordinate conventions.
    """

    def __init__(
        self,
        min_x: float,
        max_x: float,
        min_y: float,
        max_y: float,
        image_width: int,
        image_height: int,
        n_pts_per_ray: int,
        min_depth: float,
        max_depth: float,
    ):
        """
        Args:
            min_x: The leftmost x-coordinate of each ray's source pixel's center.
            max_x: The rightmost x-coordinate of each ray's source pixel's center.
            min_y: The topmost y-coordinate of each ray's source pixel's center.
            max_y: The bottommost y-coordinate of each ray's source pixel's center.
            image_width: The horizontal size of the image grid.
            image_height: The vertical size of the image grid.
            n_pts_per_ray: The number of points sampled along each ray.
            min_depth: The minimum depth of a ray-point.
            max_depth: The maximum depth of a ray-point.
        """
        super().__init__()
        self._n_pts_per_ray = n_pts_per_ray
        self._min_depth = min_depth
        self._max_depth = max_depth

        # Pixel-center coordinates along each image axis.
        ys = torch.linspace(min_y, max_y, image_height, dtype=torch.float32)
        xs = torch.linspace(min_x, max_x, image_width, dtype=torch.float32)

        # meshgrid yields (row, column)-indexed tensors of shape (H, W); the
        # last dimension of the stored grid is ordered as (x, y).
        grid_y, grid_x = torch.meshgrid(ys, xs)
        self.register_buffer("_xy_grid", torch.stack((grid_x, grid_y), dim=-1))

    def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
        """
        Args:
            cameras: A batch of `batch_size` cameras from which the rays are emitted.
        Returns:
            A named tuple RayBundle with the following fields:
            origins: A tensor of shape
                `(batch_size, image_height, image_width, 3)`
                denoting the locations of ray origins in the world coordinates.
            directions: A tensor of shape
                `(batch_size, image_height, image_width, 3)`
                denoting the directions of each ray in the world coordinates.
            lengths: A tensor of shape
                `(batch_size, image_height, image_width, n_pts_per_ray)`
                containing the z-coordinate (=depth) of each ray in world units.
            xys: A tensor of shape
                `(batch_size, image_height, image_width, 2)`
                containing the 2D image coordinates of each ray.
        """
        # The batch size is read off the leading dimension of the rotations.
        n_cameras = cameras.R.shape[0]  # pyre-ignore
        target_device = cameras.device

        # Broadcast the shared (H, W, 2) grid to (B, H, W, 2) without copying.
        grid = self._xy_grid.to(target_device)  # pyre-ignore
        batched_grid = grid.unsqueeze(0).expand(n_cameras, *grid.shape)

        return _xy_to_ray_bundle(
            cameras,
            batched_grid,
            self._min_depth,
            self._max_depth,
            self._n_pts_per_ray,
        )
128+
129+
130+
class NDCGridRaysampler(GridRaysampler):
    """
    A `GridRaysampler` whose image grid follows the screen conventions of the
    `Meshes` and `Pointclouds` renderers: the border of the
    leftmost / rightmost / topmost / bottommost pixel has coordinates
    1.0 / -1.0 / 1.0 / -1.0 respectively.

    As in the parent class, rays are regularly distributed over a batch of
    rectangular image grids, and the points along each ray have
    uniformly-spaced z-coordinates between a predefined minimum and
    maximum depth.
    """

    def __init__(
        self,
        image_width: int,
        image_height: int,
        n_pts_per_ray: int,
        min_depth: float,
        max_depth: float,
    ):
        """
        Args:
            image_width: The horizontal size of the image grid.
            image_height: The vertical size of the image grid.
            n_pts_per_ray: The number of points sampled along each ray.
            min_depth: The minimum depth of a ray-point.
            max_depth: The maximum depth of a ray-point.
        """
        # The grid spans 2 units per axis, so half of a pixel is 1 / size.
        # Rays are emitted from pixel centers, which sit half a pixel inside
        # the +/-1.0 borders.
        x_margin = 1.0 / image_width
        y_margin = 1.0 / image_height
        super().__init__(
            min_x=1.0 - x_margin,
            max_x=-1.0 + x_margin,
            min_y=1.0 - y_margin,
            max_y=-1.0 + y_margin,
            image_width=image_width,
            image_height=image_height,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=min_depth,
            max_depth=max_depth,
        )
170+
171+
172+
class MonteCarloRaysampler(torch.nn.Module):
    """
    Draws a fixed number of pixel locations uniformly at random inside the
    given xy bounds and emits one ray from each of them.
    Along every ray a fixed number of points is sampled at uniformly-spaced
    z-coordinates ranging between a predefined minimum and maximum depth.
    """

    def __init__(
        self,
        min_x: float,
        max_x: float,
        min_y: float,
        max_y: float,
        n_rays_per_image: int,
        n_pts_per_ray: int,
        min_depth: float,
        max_depth: float,
    ):
        """
        Args:
            min_x: The smallest x-coordinate of each ray's source pixel.
            max_x: The largest x-coordinate of each ray's source pixel.
            min_y: The smallest y-coordinate of each ray's source pixel.
            max_y: The largest y-coordinate of each ray's source pixel.
            n_rays_per_image: The number of rays randomly sampled in each camera.
            n_pts_per_ray: The number of points sampled along each ray.
            min_depth: The minimum depth of each ray-point.
            max_depth: The maximum depth of each ray-point.
        """
        super().__init__()
        self._min_x = min_x
        self._max_x = max_x
        self._min_y = min_y
        self._max_y = max_y
        self._n_rays_per_image = n_rays_per_image
        self._n_pts_per_ray = n_pts_per_ray
        self._min_depth = min_depth
        self._max_depth = max_depth

    def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
        """
        Args:
            cameras: A batch of `batch_size` cameras from which the rays are emitted.
        Returns:
            A named tuple RayBundle with the following fields:
            origins: A tensor of shape
                `(batch_size, n_rays_per_image, 3)`
                denoting the locations of ray origins in the world coordinates.
            directions: A tensor of shape
                `(batch_size, n_rays_per_image, 3)`
                denoting the directions of each ray in the world coordinates.
            lengths: A tensor of shape
                `(batch_size, n_rays_per_image, n_pts_per_ray)`
                containing the z-coordinate (=depth) of each ray in world units.
            xys: A tensor of shape
                `(batch_size, n_rays_per_image, 2)`
                containing the 2D image coordinates of each ray.
        """
        # The batch size is read off the leading dimension of the rotations.
        n_cameras = cameras.R.shape[0]  # pyre-ignore
        target_device = cameras.device

        # One column of random coordinates per image axis, each of shape
        # (batch_size, n_rays_per_image, 1). The x column is drawn before the
        # y column, which fixes the order in which the RNG is consumed.
        column_shape = (n_cameras, self._n_rays_per_image, 1)
        axis_bounds = (
            (self._min_x, self._max_x),
            (self._min_y, self._max_y),
        )
        columns = []
        for low, high in axis_bounds:
            unit_samples = torch.rand(
                size=column_shape,
                dtype=torch.float32,
                device=target_device,
            )
            # Rescale the [0, 1) samples to [low, high).
            columns.append(unit_samples * (high - low) + low)

        # (batch_size, n_rays_per_image, 2), last dimension ordered as (x, y).
        rays_xy = torch.cat(columns, dim=2)

        return _xy_to_ray_bundle(
            cameras,
            rays_xy,
            self._min_depth,
            self._max_depth,
            self._n_pts_per_ray,
        )
258+
259+
260+
def _xy_to_ray_bundle(
    cameras: CamerasBase,
    xy_grid: torch.Tensor,
    min_depth: float,
    max_depth: float,
    n_pts_per_ray: int,
) -> RayBundle:
    """
    Extends the `xy_grid` input of shape `(batch_size, ..., 2)` to rays.
    This adds to each xy location in the grid a vector of `n_pts_per_ray` depths
    uniformly spaced between `min_depth` and `max_depth`.

    The extended grid is then unprojected with `cameras` to yield
    ray origins, directions and depths.

    Args:
        cameras: The cameras used to unproject the grid; only
            `cameras.unproject_points` is called here.
        xy_grid: A tensor of shape `(batch_size, ..., 2)` with the 2D image
            coordinates of each ray. It does not have to be contiguous.
        min_depth: The depth of the first point of each ray.
        max_depth: The depth of the last point of each ray.
        n_pts_per_ray: The number of depth values generated per ray.

    Returns:
        A `RayBundle` built, in this order, from:
            - ray origins of shape `(batch_size, ..., 3)`,
            - ray directions of shape `(batch_size, ..., 3)`,
            - per-ray depths of shape `(batch_size, ..., n_pts_per_ray)`,
            - the unmodified `xy_grid`.
    """
    batch_size = xy_grid.shape[0]
    spatial_size = xy_grid.shape[1:-1]
    n_rays_per_image = spatial_size.numel()  # pyre-ignore

    # ray z-coords: the same depth vector is shared (via expand, no copy)
    # by every ray of every batch element
    depths = torch.linspace(
        min_depth, max_depth, n_pts_per_ray, dtype=xy_grid.dtype, device=xy_grid.device
    )
    rays_zs = depths[None, None].expand(batch_size, n_rays_per_image, n_pts_per_ray)

    # Flatten the spatial dimensions. `reshape` is used instead of `view` so
    # that grids with a view-incompatible memory layout (e.g. permuted or
    # sliced tensors) are accepted; it still returns a view whenever possible.
    xy_flat = xy_grid.reshape(batch_size, n_rays_per_image, 2)

    # make two sets of points at a constant depth=1 and 2:
    # the first `n_rays_per_image` rows lie on the depth=1 plane,
    # the following `n_rays_per_image` rows on the depth=2 plane
    plane_depths = torch.cat(
        (
            xy_flat.new_ones(batch_size, n_rays_per_image, 1),  # pyre-ignore
            2.0 * xy_flat.new_ones(batch_size, n_rays_per_image, 1),
        ),
        dim=1,
    )
    to_unproject = torch.cat(
        (torch.cat((xy_flat, xy_flat), dim=1), plane_depths),
        dim=-1,
    )

    # unproject the points
    unprojected = cameras.unproject_points(to_unproject)  # pyre-ignore

    # split the two planes back
    rays_plane_1_world = unprojected[:, :n_rays_per_image]
    rays_plane_2_world = unprojected[:, n_rays_per_image:]

    # directions are the differences between the two planes of points
    rays_directions_world = rays_plane_2_world - rays_plane_1_world

    # origins are given by subtracting the ray directions from the first plane,
    # i.e. origin + 1 * direction is the depth=1 point and
    # origin + 2 * direction is the depth=2 point
    rays_origins_world = rays_plane_1_world - rays_directions_world

    # restore the original spatial layout of the grid
    return RayBundle(
        rays_origins_world.reshape(batch_size, *spatial_size, 3),
        rays_directions_world.reshape(batch_size, *spatial_size, 3),
        rays_zs.reshape(batch_size, *spatial_size, n_pts_per_ray),
        xy_grid,
    )

0 commit comments

Comments
 (0)