19
19
from typing import ClassVar , Dict , Optional , Tuple , Type
20
20
21
21
import torch
22
+ from omegaconf import DictConfig
22
23
from pytorch3d .implicitron .tools .config import (
23
24
Configurable ,
24
25
registry ,
@@ -81,12 +82,12 @@ def evaluate_world(
81
82
82
83
Arguments:
83
84
points (torch.Tensor): tensor of points that you want to query
84
- of a form (n_grids, n_points , 3)
85
+ of a form (n_grids, ... , 3)
85
86
grid_values: an object of type Class.values_type which has tensors as
86
87
members which have shapes derived from the get_shapes() method
87
88
locator: a VolumeLocator object
88
89
Returns:
89
- torch.Tensor: shape (n_grids, n_points , n_features)
90
+ torch.Tensor: shape (n_grids, ... , n_features)
90
91
"""
91
92
points_local = locator .world_to_local_coords (points )
92
93
return self .evaluate_local (points_local , grid_values )
@@ -100,11 +101,11 @@ def evaluate_local(
100
101
101
102
Arguments:
102
103
points (torch.Tensor): tensor of points that you want to query
103
- of a form (n_points , 3), in a normalized form (coordinates are in [-1, 1])
104
+ of a form (n_grids, ... , 3), in a normalized form (coordinates are in [-1, 1])
104
105
grid_values: an object of type VMFactorizedVoxelGrid.values_type which has tensors
105
106
as members which have shapes derived from the get_shapes() method
106
107
Returns:
107
- torch.Tensor: shape (n_grids, n_points , n_features)
108
+ torch.Tensor: shape (n_grids, ... , n_features)
108
109
"""
109
110
raise NotImplementedError ()
110
111
@@ -117,11 +118,28 @@ def get_shapes(self) -> Dict[str, Tuple]:
117
118
a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
118
119
replace the shapes in the dictionary with tensors of those shapes and add the
119
120
first 'batch' dimension. If the required shape is (a, b) and you want to
120
- have g grids than the tensor that replaces the shape should have the
121
+ have g grids then the tensor that replaces the shape should have the
121
122
shape (g, a, b).
122
123
"""
123
124
raise NotImplementedError ()
124
125
126
+ @staticmethod
127
+ def get_output_dim (args : DictConfig ) -> int :
128
+ """
129
+ Given all the arguments of the grid's __init__, returns output's last dimension length.
130
+
131
+ In particular, if self.evaluate_world or self.evaluate_local
132
+ are called with `points` of shape (n_grids, n_points, 3),
133
+ their output will be of shape
134
+ (n_grids, n_points, grid.get_output_dim()).
135
+
136
+ Args:
137
+ args: DictConfig which would be used to initialize the object
138
+ Returns:
139
+ output's last dimension length
140
+ """
141
+ return args ["n_features" ]
142
+
125
143
126
144
@dataclass
127
145
class FullResolutionVoxelGridValues (VoxelGridValuesBase ):
@@ -149,19 +167,23 @@ def evaluate_local(
149
167
150
168
Arguments:
151
169
points (torch.Tensor): tensor of points that you want to query
152
- of a form (n_points , 3), in a normalized form (coordinates are in [-1, 1])
170
+ of a form (... , 3), in a normalized form (coordinates are in [-1, 1])
153
171
grid_values: an object of type values_type which has tensors as
154
172
members which have shapes derived from the get_shapes() method
155
173
Returns:
156
- torch.Tensor: shape (n_grids, n_points , n_features)
174
+ torch.Tensor: shape (n_grids, ... , n_features)
157
175
"""
158
- return interpolate_volume (
176
+ # (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
177
+ recorded_shape = points .shape
178
+ points = points .view (points .shape [0 ], - 1 , points .shape [- 1 ])
179
+ interpolated = interpolate_volume (
159
180
points ,
160
181
grid_values .voxel_grid ,
161
182
align_corners = self .align_corners ,
162
183
padding_mode = self .padding ,
163
184
mode = self .mode ,
164
185
)
186
+ return interpolated .view (* recorded_shape [:- 1 ], - 1 )
165
187
166
188
def get_shapes (self ) -> Dict [str , Tuple ]:
167
189
return {"voxel_grid" : (self .n_features , * self .resolution )}
@@ -202,10 +224,11 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
202
224
Members:
203
225
n_components: number of vector triplets, higher number gives better approximation.
204
226
matrix_reduction: how to transform components. If matrix_reduction=True result
205
- matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by the
206
- basis_matrix of shape (n_grids, n_components, n_features). If
207
- matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
208
- is summed along the rows to get (n_grids, n_points, 1).
227
+ matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
228
+ by the basis_matrix of shape (n_grids, n_components, n_features). If
229
+ matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
230
+ is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
231
+ to return to starting shape (n_grids, ..., 1).
209
232
"""
210
233
211
234
# the type of grid_values argument needed to run evaluate_local()
@@ -219,8 +242,8 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
219
242
def evaluate_local (
220
243
self , points : torch .Tensor , grid_values : CPFactorizedVoxelGridValues
221
244
) -> torch .Tensor :
222
- def factor (i ):
223
- axis = [ "x" , "y" , "z" ][ i ]
245
+ def factor (axis ):
246
+ i = { "x" : 0 , "y" : 1 , "z" : 2 }[ axis ]
224
247
index = points [..., i , None ]
225
248
vector = getattr (grid_values , "vector_components_" + axis )
226
249
return interpolate_line (
@@ -231,17 +254,25 @@ def factor(i):
231
254
mode = self .mode ,
232
255
)
233
256
257
+ # (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
258
+ recorded_shape = points .shape
259
+ points = points .view (points .shape [0 ], - 1 , points .shape [- 1 ])
260
+
234
261
# collect points from all the vectors and multipy them out
235
- mult = factor (0 ) * factor (1 ) * factor (2 )
262
+ mult = factor ("x" ) * factor ("y" ) * factor ("z" )
236
263
237
264
# reduce the result from
238
- # (n_grids, n_points , n_components) to (n_grids, n_points , n_features)
265
+ # (n_grids, n_points_total , n_components) to (n_grids, n_points_total , n_features)
239
266
if grid_values .basis_matrix is not None :
240
- # (n_grids, n_points, n_features) =
241
- # (n_grids, n_points, total_n_components) x (total_n_components, n_features)
242
- return torch .bmm (mult , grid_values .basis_matrix )
243
-
244
- return mult .sum (axis = - 1 , keepdim = True )
267
+ # (n_grids, n_points_total, n_features) =
268
+ # (n_grids, n_points_total, total_n_components) @
269
+ # (n_grids, total_n_components, n_features)
270
+ result = torch .bmm (mult , grid_values .basis_matrix )
271
+ else :
272
+ # (n_grids, n_points_total, 1) from (n_grids, n_points_total, n_features)
273
+ result = mult .sum (axis = - 1 , keepdim = True )
274
+ # (n_grids, ..., n_features)
275
+ return result .view (* recorded_shape [:- 1 ], - 1 )
245
276
246
277
def get_shapes (self ) -> Dict [str , Tuple [int , int ]]:
247
278
if self .matrix_reduction is False and self .n_features != 1 :
@@ -308,10 +339,11 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
308
339
coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
309
340
either n_components or distribution_of_components, you cannot specify both.
310
341
matrix_reduction: how to transform components. If matrix_reduction=True result
311
- matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by
312
- the basis_matrix of shape (n_grids, n_components, n_features). If
313
- matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
314
- is summed along the rows to get (n_grids, n_points, 1).
342
+ matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
343
+ by the basis_matrix of shape (n_grids, n_components, n_features). If
344
+ matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
345
+ is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
346
+ to return to starting shape (n_grids, ..., 1).
315
347
"""
316
348
317
349
# the type of grid_values argument needed to run evaluate_local()
@@ -326,6 +358,10 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
326
358
def evaluate_local (
327
359
self , points : torch .Tensor , grid_values : VMFactorizedVoxelGridValues
328
360
) -> torch .Tensor :
361
+ # (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
362
+ recorded_shape = points .shape
363
+ points = points .view (points .shape [0 ], - 1 , points .shape [- 1 ])
364
+
329
365
# collect points from matrices and vectors and multiply them
330
366
a = interpolate_plane (
331
367
points [..., :2 ],
@@ -375,9 +411,13 @@ def evaluate_local(
375
411
# (n_grids, n_points, n_features) =
376
412
# (n_grids, n_points, total_n_components) x
377
413
# (n_grids, total_n_components, n_features)
378
- return torch .bmm (feats , grid_values .basis_matrix )
379
- # pyre-ignore[28]
380
- return feats .sum (axis = - 1 , keepdim = True )
414
+ result = torch .bmm (feats , grid_values .basis_matrix )
415
+ else :
416
+ # pyre-ignore[28]
417
+ # (n_grids, n_points, 1) from (n_grids, n_points, n_features)
418
+ result = feats .sum (axis = - 1 , keepdim = True )
419
+ # (n_grids, ..., n_features)
420
+ return result .view (* recorded_shape [:- 1 ], - 1 )
381
421
382
422
def get_shapes (self ) -> Dict [str , Tuple ]:
383
423
if self .matrix_reduction is False and self .n_features != 1 :
@@ -494,9 +534,9 @@ def forward(self, points: torch.Tensor) -> torch.Tensor:
494
534
495
535
Args:
496
536
points (torch.Tensor): tensor of points that you want to query
497
- of a form (n_points , 3)
537
+ of a form (... , 3)
498
538
Returns:
499
- torch.Tensor of shape (n_points , n_features)
539
+ torch.Tensor of shape (... , n_features)
500
540
"""
501
541
locator = VolumeLocator (
502
542
batch_size = 1 ,
0 commit comments