Skip to content

Commit db3c12a

Browse files
Darijan Gudelj, facebook-github-bot
authored and committed
arbitrary shape input to voxel_grids
Summary: Add the ability to process arbitrary point shapes `[n_grids, ..., 3]` instead of only `[n_grids, n_points, 3]`. Reviewed By: bottler Differential Revision: D39574373 fbshipit-source-id: 0a9ecafe9ea58cd8f909644de43a1185ecf934f4
1 parent 6ae6ff9 commit db3c12a

File tree

2 files changed

+93
-55
lines changed

2 files changed

+93
-55
lines changed

pytorch3d/implicitron/models/implicit_function/voxel_grid.py

Lines changed: 70 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import ClassVar, Dict, Optional, Tuple, Type
2020

2121
import torch
22+
from omegaconf import DictConfig
2223
from pytorch3d.implicitron.tools.config import (
2324
Configurable,
2425
registry,
@@ -81,12 +82,12 @@ def evaluate_world(
8182
8283
Arguments:
8384
points (torch.Tensor): tensor of points that you want to query
84-
of a form (n_grids, n_points, 3)
85+
of a form (n_grids, ..., 3)
8586
grid_values: an object of type Class.values_type which has tensors as
8687
members which have shapes derived from the get_shapes() method
8788
locator: a VolumeLocator object
8889
Returns:
89-
torch.Tensor: shape (n_grids, n_points, n_features)
90+
torch.Tensor: shape (n_grids, ..., n_features)
9091
"""
9192
points_local = locator.world_to_local_coords(points)
9293
return self.evaluate_local(points_local, grid_values)
@@ -100,11 +101,11 @@ def evaluate_local(
100101
101102
Arguments:
102103
points (torch.Tensor): tensor of points that you want to query
103-
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
104+
of a form (n_grids, ..., 3), in a normalized form (coordinates are in [-1, 1])
104105
grid_values: an object of type VMFactorizedVoxelGrid.values_type which has tensors
105106
as members which have shapes derived from the get_shapes() method
106107
Returns:
107-
torch.Tensor: shape (n_grids, n_points, n_features)
108+
torch.Tensor: shape (n_grids, ..., n_features)
108109
"""
109110
raise NotImplementedError()
110111

@@ -117,11 +118,28 @@ def get_shapes(self) -> Dict[str, Tuple]:
117118
a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
118119
replace the shapes in the dictionary with tensors of those shapes and add the
119120
first 'batch' dimension. If the required shape is (a, b) and you want to
120-
have g grids than the tensor that replaces the shape should have the
121+
have g grids then the tensor that replaces the shape should have the
121122
shape (g, a, b).
122123
"""
123124
raise NotImplementedError()
124125

126+
@staticmethod
127+
def get_output_dim(args: DictConfig) -> int:
128+
"""
129+
Given all the arguments of the grid's __init__, returns output's last dimension length.
130+
131+
In particular, if self.evaluate_world or self.evaluate_local
132+
are called with `points` of shape (n_grids, n_points, 3),
133+
their output will be of shape
134+
(n_grids, n_points, grid.get_output_dim()).
135+
136+
Args:
137+
args: DictConfig which would be used to initialize the object
138+
Returns:
139+
output's last dimension length
140+
"""
141+
return args["n_features"]
142+
125143

126144
@dataclass
127145
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
@@ -149,19 +167,23 @@ def evaluate_local(
149167
150168
Arguments:
151169
points (torch.Tensor): tensor of points that you want to query
152-
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
170+
of a form (n_grids, ..., 3), in a normalized form (coordinates are in [-1, 1])
153171
grid_values: an object of type values_type which has tensors as
154172
members which have shapes derived from the get_shapes() method
155173
Returns:
156-
torch.Tensor: shape (n_grids, n_points, n_features)
174+
torch.Tensor: shape (n_grids, ..., n_features)
157175
"""
158-
return interpolate_volume(
176+
# flatten points to (n_grids, n_points_total, 3) from (n_grids, ..., 3)
177+
recorded_shape = points.shape
178+
points = points.view(points.shape[0], -1, points.shape[-1])
179+
interpolated = interpolate_volume(
159180
points,
160181
grid_values.voxel_grid,
161182
align_corners=self.align_corners,
162183
padding_mode=self.padding,
163184
mode=self.mode,
164185
)
186+
return interpolated.view(*recorded_shape[:-1], -1)
165187

166188
def get_shapes(self) -> Dict[str, Tuple]:
167189
return {"voxel_grid": (self.n_features, *self.resolution)}
@@ -202,10 +224,11 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
202224
Members:
203225
n_components: number of vector triplets, higher number gives better approximation.
204226
matrix_reduction: how to transform components. If matrix_reduction=True result
205-
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by the
206-
basis_matrix of shape (n_grids, n_components, n_features). If
207-
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
208-
is summed along the rows to get (n_grids, n_points, 1).
227+
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
228+
by the basis_matrix of shape (n_grids, n_components, n_features). If
229+
matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
230+
is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
231+
to return to starting shape (n_grids, ..., 1).
209232
"""
210233

211234
# the type of grid_values argument needed to run evaluate_local()
@@ -219,8 +242,8 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
219242
def evaluate_local(
220243
self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues
221244
) -> torch.Tensor:
222-
def factor(i):
223-
axis = ["x", "y", "z"][i]
245+
def factor(axis):
246+
i = {"x": 0, "y": 1, "z": 2}[axis]
224247
index = points[..., i, None]
225248
vector = getattr(grid_values, "vector_components_" + axis)
226249
return interpolate_line(
@@ -231,17 +254,25 @@ def factor(i):
231254
mode=self.mode,
232255
)
233256

257+
# flatten points to (n_grids, n_points_total, 3) from (n_grids, ..., 3)
258+
recorded_shape = points.shape
259+
points = points.view(points.shape[0], -1, points.shape[-1])
260+
234261
# collect points from all the vectors and multiply them out
235-
mult = factor(0) * factor(1) * factor(2)
262+
mult = factor("x") * factor("y") * factor("z")
236263

237264
# reduce the result from
238-
# (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
265+
# (n_grids, n_points_total, n_components) to (n_grids, n_points_total, n_features)
239266
if grid_values.basis_matrix is not None:
240-
# (n_grids, n_points, n_features) =
241-
# (n_grids, n_points, total_n_components) x (total_n_components, n_features)
242-
return torch.bmm(mult, grid_values.basis_matrix)
243-
244-
return mult.sum(axis=-1, keepdim=True)
267+
# (n_grids, n_points_total, n_features) =
268+
# (n_grids, n_points_total, total_n_components) @
269+
# (n_grids, total_n_components, n_features)
270+
result = torch.bmm(mult, grid_values.basis_matrix)
271+
else:
272+
# (n_grids, n_points_total, 1) from (n_grids, n_points_total, n_components)
273+
result = mult.sum(axis=-1, keepdim=True)
274+
# (n_grids, ..., n_features)
275+
return result.view(*recorded_shape[:-1], -1)
245276

246277
def get_shapes(self) -> Dict[str, Tuple[int, int]]:
247278
if self.matrix_reduction is False and self.n_features != 1:
@@ -308,10 +339,11 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
308339
coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
309340
either n_components or distribution_of_components, you cannot specify both.
310341
matrix_reduction: how to transform components. If matrix_reduction=True result
311-
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by
312-
the basis_matrix of shape (n_grids, n_components, n_features). If
313-
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
314-
is summed along the rows to get (n_grids, n_points, 1).
342+
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
343+
by the basis_matrix of shape (n_grids, n_components, n_features). If
344+
matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
345+
is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
346+
to return to starting shape (n_grids, ..., 1).
315347
"""
316348

317349
# the type of grid_values argument needed to run evaluate_local()
@@ -326,6 +358,10 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
326358
def evaluate_local(
327359
self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues
328360
) -> torch.Tensor:
361+
# flatten points to (n_grids, n_points_total, 3) from (n_grids, ..., 3)
362+
recorded_shape = points.shape
363+
points = points.view(points.shape[0], -1, points.shape[-1])
364+
329365
# collect points from matrices and vectors and multiply them
330366
a = interpolate_plane(
331367
points[..., :2],
@@ -375,9 +411,13 @@ def evaluate_local(
375411
# (n_grids, n_points, n_features) =
376412
# (n_grids, n_points, total_n_components) x
377413
# (n_grids, total_n_components, n_features)
378-
return torch.bmm(feats, grid_values.basis_matrix)
379-
# pyre-ignore[28]
380-
return feats.sum(axis=-1, keepdim=True)
414+
result = torch.bmm(feats, grid_values.basis_matrix)
415+
else:
416+
# pyre-ignore[28]
417+
# (n_grids, n_points, 1) from (n_grids, n_points, n_components)
418+
result = feats.sum(axis=-1, keepdim=True)
419+
# (n_grids, ..., n_features)
420+
return result.view(*recorded_shape[:-1], -1)
381421

382422
def get_shapes(self) -> Dict[str, Tuple]:
383423
if self.matrix_reduction is False and self.n_features != 1:
@@ -494,9 +534,9 @@ def forward(self, points: torch.Tensor) -> torch.Tensor:
494534
495535
Args:
496536
points (torch.Tensor): tensor of points that you want to query
497-
of a form (n_points, 3)
537+
of a form (..., 3)
498538
Returns:
499-
torch.Tensor of shape (n_points, n_features)
539+
torch.Tensor of shape (..., n_features)
500540
"""
501541
locator = VolumeLocator(
502542
batch_size=1,

tests/implicitron/test_voxel_grids.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,17 @@ def test_my_code(self):
3838
return
3939

4040
def get_random_normalized_points(
41-
self, n_grids, n_points, dimension=3
41+
self, n_grids, n_points=None, dimension=3
4242
) -> torch.Tensor:
43+
middle_shape = torch.randint(1, 4, tuple(torch.randint(1, 5, size=(1,))))
4344
# create random query points
44-
return torch.rand(n_grids, n_points, dimension) * 2 - 1
45+
return (
46+
torch.rand(
47+
n_grids, *(middle_shape if n_points is None else [n_points]), dimension
48+
)
49+
* 2
50+
- 1
51+
)
4552

4653
def _test_query_with_constant_init_cp(
4754
self,
@@ -50,7 +57,6 @@ def _test_query_with_constant_init_cp(
5057
n_components: int,
5158
resolution: Tuple[int],
5259
value: float = 1,
53-
n_points: int = 1,
5460
) -> None:
5561
# set everything to 'value' and do query for elements; the result should
5662
# be of shape (n_grids, ..., n_features) and be filled with n_components
@@ -65,12 +71,11 @@ def _test_query_with_constant_init_cp(
6571
params = grid.values_type(
6672
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
6773
)
68-
74+
points = self.get_random_normalized_points(n_grids)
6975
assert torch.allclose(
70-
grid.evaluate_local(
71-
self.get_random_normalized_points(n_grids, n_points), params
72-
),
73-
torch.ones(n_grids, n_points, n_features) * n_components * value,
76+
grid.evaluate_local(points, params),
77+
torch.ones(n_grids, *points.shape[1:-1], n_features) * n_components * value,
78+
rtol=0.0001,
7479
)
7580

7681
def _test_query_with_constant_init_vm(
@@ -98,11 +103,10 @@ def _test_query_with_constant_init_vm(
98103
expected_element = (
99104
n_components * value if distribution is None else sum(distribution) * value
100105
)
106+
points = self.get_random_normalized_points(n_grids)
101107
assert torch.allclose(
102-
grid.evaluate_local(
103-
self.get_random_normalized_points(n_grids, n_points), params
104-
),
105-
torch.ones(n_grids, n_points, n_features) * expected_element,
108+
grid.evaluate_local(points, params),
109+
torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
106110
)
107111

108112
def _test_query_with_constant_init_full(
@@ -121,53 +125,48 @@ def _test_query_with_constant_init_full(
121125
)
122126

123127
expected_element = value
128+
points = self.get_random_normalized_points(n_grids)
124129
assert torch.allclose(
125-
grid.evaluate_local(
126-
self.get_random_normalized_points(n_grids, n_points), params
127-
),
128-
torch.ones(n_grids, n_points, n_features) * expected_element,
130+
grid.evaluate_local(points, params),
131+
torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
129132
)
130133

131134
def test_query_with_constant_init(self):
132135
with self.subTest("Full"):
133136
self._test_query_with_constant_init_full(
134-
n_grids=5, n_features=6, resolution=(3, 4, 5), n_points=3
137+
n_grids=5, n_features=6, resolution=(3, 4, 5)
135138
)
136139
with self.subTest("Full with 1 in dimensions"):
137140
self._test_query_with_constant_init_full(
138-
n_grids=5, n_features=1, resolution=(33, 41, 1), n_points=4
141+
n_grids=5, n_features=1, resolution=(33, 41, 1)
139142
)
140143
with self.subTest("CP"):
141144
self._test_query_with_constant_init_cp(
142145
n_grids=5,
143146
n_features=6,
144147
n_components=7,
145148
resolution=(3, 4, 5),
146-
n_points=2,
147149
)
148150
with self.subTest("CP with 1 in dimensions"):
149151
self._test_query_with_constant_init_cp(
150152
n_grids=2,
151153
n_features=1,
152154
n_components=3,
153155
resolution=(3, 1, 1),
154-
n_points=4,
155156
)
156157
with self.subTest("VM with symetric distribution"):
157158
self._test_query_with_constant_init_vm(
158159
n_grids=6,
159160
n_features=9,
160161
resolution=(2, 12, 2),
161162
n_components=12,
162-
n_points=3,
163163
)
164164
with self.subTest("VM with distribution"):
165165
self._test_query_with_constant_init_vm(
166166
n_grids=5,
167167
n_features=1,
168168
resolution=(5, 9, 7),
169169
distribution=(33, 41, 1),
170-
n_points=7,
171170
)
172171

173172
def test_query_with_zero_init(self):
@@ -177,7 +176,6 @@ def test_query_with_zero_init(self):
177176
n_features=6,
178177
n_components=7,
179178
resolution=(3, 2, 5),
180-
n_points=3,
181179
value=0,
182180
)
183181
with self.subTest("Query testing with zero init VMFactorizedVoxelGrid"):
@@ -186,12 +184,11 @@ def test_query_with_zero_init(self):
186184
n_features=9,
187185
resolution=(2, 11, 3),
188186
n_components=3,
189-
n_points=3,
190187
value=0,
191188
)
192189
with self.subTest("Query testing with zero init FullResolutionVoxelGrid"):
193190
self._test_query_with_constant_init_full(
194-
n_grids=4, n_features=2, resolution=(3, 3, 5), n_points=3, value=0
191+
n_grids=4, n_features=2, resolution=(3, 3, 5), value=0
195192
)
196193

197194
def setUp(self):
@@ -324,6 +321,7 @@ def test_interpolation(self):
324321
padding_mode="zeros",
325322
mode="bilinear",
326323
),
324+
rtol=0.0001,
327325
)
328326

329327
def test_floating_point_query(self):

0 commit comments

Comments
 (0)