@@ -166,21 +166,25 @@ def get_output_dim(args: DictConfig) -> int:
166
166
def change_resolution (
167
167
self ,
168
168
grid_values : VoxelGridValuesBase ,
169
- epoch : int ,
170
169
* ,
170
+ epoch : Optional [int ] = None ,
171
+ grid_values_with_wanted_resolution : Optional [VoxelGridValuesBase ] = None ,
171
172
mode : str = "linear" ,
172
173
align_corners : bool = True ,
173
174
antialias : bool = False ,
174
175
) -> Tuple [VoxelGridValuesBase , bool ]:
175
176
"""
176
- Changes resolution of tensors in `grid_values` to match the `wanted_resolution`.
177
+ Changes resolution of tensors in `grid_values` to match the
178
+ `grid_values_with_wanted_resolution` or resolution on wanted epoch.
177
179
178
180
Args:
179
181
epoch: current training epoch, used to see if the grid needs regridding
180
182
grid_values: instance of self.values_type which contains
181
183
the voxel grid which will be interpolated to create the new grid
182
184
epoch: epoch which is used to get the resolution of the new
183
185
`grid_values` using `self.resolution_changes`.
186
+ grid_values_with_wanted_resolution: `VoxelGridValuesBase` to whose resolution
187
+ to interpolate grid_values
184
188
align_corners: as for torch.nn.functional.interpolate
185
189
mode: as for torch.nn.functional.interpolate
186
190
'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'.
@@ -195,8 +199,12 @@ def change_resolution(
195
199
- new voxel grid_values of desired resolution, of type self.values_type
196
200
- True if regridding has happened.
197
201
"""
198
- if epoch not in self .resolution_changes :
199
- return grid_values , False
202
+
203
+ if (epoch is None ) == (grid_values_with_wanted_resolution is None ):
204
+ raise ValueError (
205
+ "Exactly one of `epoch` or "
206
+ "`grid_values_with_wanted_resolution` has to be defined."
207
+ )
200
208
201
209
if mode not in ("nearest" , "bicubic" , "linear" , "area" , "nearest-exact" ):
202
210
raise ValueError (
@@ -219,11 +227,28 @@ def change_individual_resolution(tensor, wanted_resolution):
219
227
recompute_scale_factor = False ,
220
228
)
221
229
222
- wanted_shapes = self .get_shapes (epoch = epoch )
223
- params = {
224
- name : change_individual_resolution (getattr (grid_values , name ), shape [1 :])
225
- for name , shape in wanted_shapes .items ()
226
- }
230
+ if epoch is not None :
231
+ if epoch not in self .resolution_changes :
232
+ return grid_values , False
233
+
234
+ wanted_shapes = self .get_shapes (epoch = epoch )
235
+ params = {
236
+ name : change_individual_resolution (
237
+ getattr (grid_values , name ), shape [1 :]
238
+ )
239
+ for name , shape in wanted_shapes .items ()
240
+ }
241
+ else :
242
+ params = {
243
+ name : (
244
+ change_individual_resolution (
245
+ getattr (grid_values , name ), tensor .shape [2 :]
246
+ )
247
+ if tensor is not None
248
+ else None
249
+ )
250
+ for name , tensor in vars (grid_values_with_wanted_resolution ).items ()
251
+ }
227
252
# pyre-ignore[29]
228
253
return self .values_type (** params ), True
229
254
@@ -239,6 +264,82 @@ def get_align_corners(self) -> bool:
239
264
"""
240
265
return self .align_corners
241
266
267
+ def crop_world (
268
+ self ,
269
+ min_point_world : torch .Tensor ,
270
+ max_point_world : torch .Tensor ,
271
+ grid_values : VoxelGridValuesBase ,
272
+ volume_locator : VolumeLocator ,
273
+ ) -> VoxelGridValuesBase :
274
+ """
275
+ Crops the voxel grid based on minimum and maximum occupied point in
276
+ world coordinates. After cropping all 8 corner points are preserved in
277
+ the voxel grid. This is achieved by preserving all the voxels needed to
278
+ calculate the point.
279
+
280
+ +--------B
281
+ / /|
282
+ / / |
283
+ +--------+ | <==== Bounding box represented by points A and B:
284
+ | | | - B has x, y and z coordinates bigger or equal
285
+ | | + to all other points of the object
286
+ | | / - A has x, y and z coordinates smaller or equal
287
+ | |/ to all other points of the object
288
+ A--------+
289
+
290
+ Args:
291
+ min_point_world: torch.Tensor of shape (3,). Has x, y and z coordinates
292
+ smaller or equal to all other occupied points. Point A from the
293
+ picture above.
294
+ max_point_world: torch.Tensor of shape (3,). Has x, y and z coordinates
295
+ bigger or equal to all other occupied points. Point B from the
296
+ picture above.
297
+ grid_values: instance of self.values_type which contains
298
+ the voxel grid which will be cropped to create the new grid
299
+ volume_locator: VolumeLocator object used to convert world to local
300
+ cordinates
301
+ Returns:
302
+ instance of self.values_type which has volume cropped to desired size.
303
+ """
304
+ min_point_local = volume_locator .world_to_local_coords (min_point_world [None ])[0 ]
305
+ max_point_local = volume_locator .world_to_local_coords (max_point_world [None ])[0 ]
306
+ return self .crop_local (min_point_local , max_point_local , grid_values )
307
+
308
+ def crop_local (
309
+ self ,
310
+ min_point_local : torch .Tensor ,
311
+ max_point_local : torch .Tensor ,
312
+ grid_values : VoxelGridValuesBase ,
313
+ ) -> VoxelGridValuesBase :
314
+ """
315
+ Crops the voxel grid based on minimum and maximum occupied point in local
316
+ coordinates. After cropping both min and max point are preserved in the voxel
317
+ grid. This is achieved by preserving all the voxels needed to calculate the point.
318
+
319
+ +--------B
320
+ / /|
321
+ / / |
322
+ +--------+ | <==== Bounding box represented by points A and B:
323
+ | | | - B has x, y and z coordinates bigger or equal
324
+ | | + to all other points of the object
325
+ | | / - A has x, y and z coordinates smaller or equal
326
+ | |/ to all other points of the object
327
+ A--------+
328
+
329
+ Args:
330
+ min_point_local: torch.Tensor of shape (3,). Has x, y and z coordinates
331
+ smaller or equal to all other occupied points. Point A from the
332
+ picture above. All elements in [-1, 1].
333
+ max_point_local: torch.Tensor of shape (3,). Has x, y and z coordinates
334
+ bigger or equal to all other occupied points. Point B from the
335
+ picture above. All elements in [-1, 1].
336
+ grid_values: instance of self.values_type which contains
337
+ the voxel grid which will be cropped to create the new grid
338
+ Returns:
339
+ instance of self.values_type which has volume cropped to desired size.
340
+ """
341
+ raise NotImplementedError ()
342
+
242
343
243
344
@dataclass
244
345
class FullResolutionVoxelGridValues (VoxelGridValuesBase ):
@@ -288,6 +389,34 @@ def get_shapes(self, epoch: int) -> Dict[str, Tuple]:
288
389
width , height , depth = self .get_resolution (epoch )
289
390
return {"voxel_grid" : (self .n_features , width , height , depth )}
290
391
392
+ # pyre-ignore[14]
393
+ def crop_local (
394
+ self ,
395
+ min_point_local : torch .Tensor ,
396
+ max_point_local : torch .Tensor ,
397
+ grid_values : FullResolutionVoxelGridValues ,
398
+ ) -> FullResolutionVoxelGridValues :
399
+ assert torch .all (min_point_local < max_point_local )
400
+ min_point_local = torch .clamp (min_point_local , - 1 , 1 )
401
+ max_point_local = torch .clamp (max_point_local , - 1 , 1 )
402
+ _ , _ , width , height , depth = grid_values .voxel_grid .shape
403
+ resolution = grid_values .voxel_grid .new_tensor ([width , height , depth ])
404
+ min_point_local01 = (min_point_local + 1 ) / 2
405
+ max_point_local01 = (max_point_local + 1 ) / 2
406
+
407
+ if self .align_corners :
408
+ minx , miny , minz = torch .floor (min_point_local01 * (resolution - 1 )).long ()
409
+ maxx , maxy , maxz = torch .ceil (max_point_local01 * (resolution - 1 )).long ()
410
+ else :
411
+ minx , miny , minz = torch .floor (min_point_local01 * resolution - 0.5 ).long ()
412
+ maxx , maxy , maxz = torch .ceil (max_point_local01 * resolution - 0.5 ).long ()
413
+
414
+ return FullResolutionVoxelGridValues (
415
+ voxel_grid = grid_values .voxel_grid [
416
+ :, :, minx : maxx + 1 , miny : maxy + 1 , minz : maxz + 1
417
+ ]
418
+ )
419
+
291
420
292
421
@dataclass
293
422
class CPFactorizedVoxelGridValues (VoxelGridValuesBase ):
@@ -388,6 +517,37 @@ def get_shapes(self, epoch: int) -> Dict[str, Tuple[int, int]]:
388
517
shape_dict ["basis_matrix" ] = (self .n_components , self .n_features )
389
518
return shape_dict
390
519
520
+ # pyre-ignore[14]
521
+ def crop_local (
522
+ self ,
523
+ min_point_local : torch .Tensor ,
524
+ max_point_local : torch .Tensor ,
525
+ grid_values : CPFactorizedVoxelGridValues ,
526
+ ) -> CPFactorizedVoxelGridValues :
527
+ assert torch .all (min_point_local < max_point_local )
528
+ min_point_local = torch .clamp (min_point_local , - 1 , 1 )
529
+ max_point_local = torch .clamp (max_point_local , - 1 , 1 )
530
+ _ , _ , width = grid_values .vector_components_x .shape
531
+ _ , _ , height = grid_values .vector_components_y .shape
532
+ _ , _ , depth = grid_values .vector_components_z .shape
533
+ resolution = grid_values .vector_components_x .new_tensor ([width , height , depth ])
534
+ min_point_local01 = (min_point_local + 1 ) / 2
535
+ max_point_local01 = (max_point_local + 1 ) / 2
536
+
537
+ if self .align_corners :
538
+ minx , miny , minz = torch .floor (min_point_local01 * (resolution - 1 )).long ()
539
+ maxx , maxy , maxz = torch .ceil (max_point_local01 * (resolution - 1 )).long ()
540
+ else :
541
+ minx , miny , minz = torch .floor (min_point_local01 * resolution - 0.5 ).long ()
542
+ maxx , maxy , maxz = torch .ceil (max_point_local01 * resolution - 0.5 ).long ()
543
+
544
+ return CPFactorizedVoxelGridValues (
545
+ vector_components_x = grid_values .vector_components_x [:, :, minx : maxx + 1 ],
546
+ vector_components_y = grid_values .vector_components_y [:, :, miny : maxy + 1 ],
547
+ vector_components_z = grid_values .vector_components_z [:, :, minz : maxz + 1 ],
548
+ basis_matrix = grid_values .basis_matrix ,
549
+ )
550
+
391
551
392
552
@dataclass
393
553
class VMFactorizedVoxelGridValues (VoxelGridValuesBase ):
@@ -585,6 +745,46 @@ def get_shapes(self, epoch: int) -> Dict[str, Tuple]:
585
745
586
746
return shape_dict
587
747
748
+ # pyre-ignore[14]
749
+ def crop_local (
750
+ self ,
751
+ min_point_local : torch .Tensor ,
752
+ max_point_local : torch .Tensor ,
753
+ grid_values : VMFactorizedVoxelGridValues ,
754
+ ) -> VMFactorizedVoxelGridValues :
755
+ assert torch .all (min_point_local < max_point_local )
756
+ min_point_local = torch .clamp (min_point_local , - 1 , 1 )
757
+ max_point_local = torch .clamp (max_point_local , - 1 , 1 )
758
+ _ , _ , width = grid_values .vector_components_x .shape
759
+ _ , _ , height = grid_values .vector_components_y .shape
760
+ _ , _ , depth = grid_values .vector_components_z .shape
761
+ resolution = grid_values .vector_components_x .new_tensor ([width , height , depth ])
762
+ min_point_local01 = (min_point_local + 1 ) / 2
763
+ max_point_local01 = (max_point_local + 1 ) / 2
764
+
765
+ if self .align_corners :
766
+ minx , miny , minz = torch .floor (min_point_local01 * (resolution - 1 )).long ()
767
+ maxx , maxy , maxz = torch .ceil (max_point_local01 * (resolution - 1 )).long ()
768
+ else :
769
+ minx , miny , minz = torch .floor (min_point_local01 * resolution - 0.5 ).long ()
770
+ maxx , maxy , maxz = torch .ceil (max_point_local01 * resolution - 0.5 ).long ()
771
+
772
+ return VMFactorizedVoxelGridValues (
773
+ vector_components_x = grid_values .vector_components_x [:, :, minx : maxx + 1 ],
774
+ vector_components_y = grid_values .vector_components_y [:, :, miny : maxy + 1 ],
775
+ vector_components_z = grid_values .vector_components_z [:, :, minz : maxz + 1 ],
776
+ matrix_components_xy = grid_values .matrix_components_xy [
777
+ :, :, minx : maxx + 1 , miny : maxy + 1
778
+ ],
779
+ matrix_components_yz = grid_values .matrix_components_yz [
780
+ :, :, miny : maxy + 1 , minz : maxz + 1
781
+ ],
782
+ matrix_components_xz = grid_values .matrix_components_xz [
783
+ :, :, minx : maxx + 1 , minz : maxz + 1
784
+ ],
785
+ basis_matrix = grid_values .basis_matrix ,
786
+ )
787
+
588
788
589
789
# pyre-fixme[13]: Attribute `voxel_grid` is never initialized.
590
790
class VoxelGridModule (Configurable , torch .nn .Module ):
@@ -771,6 +971,37 @@ def get_device(self) -> torch.device:
771
971
# pyre-ignore[29]
772
972
return next (val for val in self .params .values () if val is not None ).device
773
973
974
+ def crop_self (self , min_point : torch .Tensor , max_point : torch .Tensor ) -> None :
975
+ """
976
+ Crops self to only represent points between min_point and max_point (inclusive).
977
+
978
+ Args:
979
+ min_point: torch.Tensor of shape (3,). Has x, y and z coordinates
980
+ smaller or equal to all other occupied points.
981
+ max_point: torch.Tensor of shape (3,). Has x, y and z coordinates
982
+ bigger or equal to all other occupied points.
983
+ Returns:
984
+ nothing
985
+ """
986
+ locator = self ._get_volume_locator ()
987
+ # pyre-fixme[29]: `Union[torch._tensor.Tensor,
988
+ # torch.nn.modules.module.Module]` is not a function.
989
+ old_grid_values = self .voxel_grid .values_type (** self .params )
990
+ new_grid_values = self .voxel_grid .crop_world (
991
+ min_point , max_point , old_grid_values , locator
992
+ )
993
+ grid_values , _ = self .voxel_grid .change_resolution (
994
+ new_grid_values , grid_values_with_wanted_resolution = old_grid_values
995
+ )
996
+ # pyre-ignore [16]
997
+ self .params = torch .nn .ParameterDict (
998
+ {k : v for k , v in vars (grid_values ).items ()}
999
+ )
1000
+ # New center of voxel grid is the middle point between max and min points.
1001
+ self .translation = tuple ((max_point + min_point ) / 2 )
1002
+ # new extents of voxel grid are distances between min and max points
1003
+ self .extents = tuple (max_point - min_point )
1004
+
774
1005
def _get_volume_locator (self ) -> VolumeLocator :
775
1006
"""
776
1007
Returns VolumeLocator calculated from `extents` and `translation` members.
0 commit comments