Skip to content

Commit f55d37f

Browse files
Darijan Gudeljfacebook-github-bot
authored andcommitted
volume cropping
Summary: TensoRF at step 2000 does volume croping and resizing. At those steps it calculates part of the voxel grid which has density big enough to have objects and resizes the grid to fit that object. Change is done on 3 levels: - implicit function subscribes to epochs and at specific epochs finds the bounding box of the object and calls resizing of the color and density voxel grids to fit it - VoxelGrid module calls cropping of the underlaying voxel grid and resizing to fit previous size it also adjusts its extends and translation to match wanted size - Each voxel grid has its own way of cropping the underlaying data Reviewed By: kjchalup Differential Revision: D39854548 fbshipit-source-id: 5435b6e599aef1eaab980f5421d3369ee4829c50
1 parent 0b5def5 commit f55d37f

File tree

2 files changed

+377
-11
lines changed

2 files changed

+377
-11
lines changed

pytorch3d/implicitron/models/implicit_function/voxel_grid.py

Lines changed: 240 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -166,21 +166,25 @@ def get_output_dim(args: DictConfig) -> int:
166166
def change_resolution(
167167
self,
168168
grid_values: VoxelGridValuesBase,
169-
epoch: int,
170169
*,
170+
epoch: Optional[int] = None,
171+
grid_values_with_wanted_resolution: Optional[VoxelGridValuesBase] = None,
171172
mode: str = "linear",
172173
align_corners: bool = True,
173174
antialias: bool = False,
174175
) -> Tuple[VoxelGridValuesBase, bool]:
175176
"""
176-
Changes resolution of tensors in `grid_values` to match the `wanted_resolution`.
177+
Changes resolution of tensors in `grid_values` to match the
178+
`grid_values_with_wanted_resolution` or resolution on wanted epoch.
177179
178180
Args:
179181
epoch: current training epoch, used to see if the grid needs regridding
180182
grid_values: instance of self.values_type which contains
181183
the voxel grid which will be interpolated to create the new grid
182184
epoch: epoch which is used to get the resolution of the new
183185
`grid_values` using `self.resolution_changes`.
186+
grid_values_with_wanted_resolution: `VoxelGridValuesBase` to whose resolution
187+
to interpolate grid_values
184188
align_corners: as for torch.nn.functional.interpolate
185189
mode: as for torch.nn.functional.interpolate
186190
'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'.
@@ -195,8 +199,12 @@ def change_resolution(
195199
- new voxel grid_values of desired resolution, of type self.values_type
196200
- True if regridding has happened.
197201
"""
198-
if epoch not in self.resolution_changes:
199-
return grid_values, False
202+
203+
if (epoch is None) == (grid_values_with_wanted_resolution is None):
204+
raise ValueError(
205+
"Exactly one of `epoch` or "
206+
"`grid_values_with_wanted_resolution` has to be defined."
207+
)
200208

201209
if mode not in ("nearest", "bicubic", "linear", "area", "nearest-exact"):
202210
raise ValueError(
@@ -219,11 +227,28 @@ def change_individual_resolution(tensor, wanted_resolution):
219227
recompute_scale_factor=False,
220228
)
221229

222-
wanted_shapes = self.get_shapes(epoch=epoch)
223-
params = {
224-
name: change_individual_resolution(getattr(grid_values, name), shape[1:])
225-
for name, shape in wanted_shapes.items()
226-
}
230+
if epoch is not None:
231+
if epoch not in self.resolution_changes:
232+
return grid_values, False
233+
234+
wanted_shapes = self.get_shapes(epoch=epoch)
235+
params = {
236+
name: change_individual_resolution(
237+
getattr(grid_values, name), shape[1:]
238+
)
239+
for name, shape in wanted_shapes.items()
240+
}
241+
else:
242+
params = {
243+
name: (
244+
change_individual_resolution(
245+
getattr(grid_values, name), tensor.shape[2:]
246+
)
247+
if tensor is not None
248+
else None
249+
)
250+
for name, tensor in vars(grid_values_with_wanted_resolution).items()
251+
}
227252
# pyre-ignore[29]
228253
return self.values_type(**params), True
229254

@@ -239,6 +264,82 @@ def get_align_corners(self) -> bool:
239264
"""
240265
return self.align_corners
241266

267+
def crop_world(
268+
self,
269+
min_point_world: torch.Tensor,
270+
max_point_world: torch.Tensor,
271+
grid_values: VoxelGridValuesBase,
272+
volume_locator: VolumeLocator,
273+
) -> VoxelGridValuesBase:
274+
"""
275+
Crops the voxel grid based on minimum and maximum occupied point in
276+
world coordinates. After cropping all 8 corner points are preserved in
277+
the voxel grid. This is achieved by preserving all the voxels needed to
278+
calculate the point.
279+
280+
+--------B
281+
/ /|
282+
/ / |
283+
+--------+ | <==== Bounding box represented by points A and B:
284+
| | | - B has x, y and z coordinates bigger or equal
285+
| | + to all other points of the object
286+
| | / - A has x, y and z coordinates smaller or equal
287+
| |/ to all other points of the object
288+
A--------+
289+
290+
Args:
291+
min_point_world: torch.Tensor of shape (3,). Has x, y and z coordinates
292+
smaller or equal to all other occupied points. Point A from the
293+
picture above.
294+
max_point_world: torch.Tensor of shape (3,). Has x, y and z coordinates
295+
bigger or equal to all other occupied points. Point B from the
296+
picture above.
297+
grid_values: instance of self.values_type which contains
298+
the voxel grid which will be cropped to create the new grid
299+
volume_locator: VolumeLocator object used to convert world to local
300+
cordinates
301+
Returns:
302+
instance of self.values_type which has volume cropped to desired size.
303+
"""
304+
min_point_local = volume_locator.world_to_local_coords(min_point_world[None])[0]
305+
max_point_local = volume_locator.world_to_local_coords(max_point_world[None])[0]
306+
return self.crop_local(min_point_local, max_point_local, grid_values)
307+
308+
def crop_local(
309+
self,
310+
min_point_local: torch.Tensor,
311+
max_point_local: torch.Tensor,
312+
grid_values: VoxelGridValuesBase,
313+
) -> VoxelGridValuesBase:
314+
"""
315+
Crops the voxel grid based on minimum and maximum occupied point in local
316+
coordinates. After cropping both min and max point are preserved in the voxel
317+
grid. This is achieved by preserving all the voxels needed to calculate the point.
318+
319+
+--------B
320+
/ /|
321+
/ / |
322+
+--------+ | <==== Bounding box represented by points A and B:
323+
| | | - B has x, y and z coordinates bigger or equal
324+
| | + to all other points of the object
325+
| | / - A has x, y and z coordinates smaller or equal
326+
| |/ to all other points of the object
327+
A--------+
328+
329+
Args:
330+
min_point_local: torch.Tensor of shape (3,). Has x, y and z coordinates
331+
smaller or equal to all other occupied points. Point A from the
332+
picture above. All elements in [-1, 1].
333+
max_point_local: torch.Tensor of shape (3,). Has x, y and z coordinates
334+
bigger or equal to all other occupied points. Point B from the
335+
picture above. All elements in [-1, 1].
336+
grid_values: instance of self.values_type which contains
337+
the voxel grid which will be cropped to create the new grid
338+
Returns:
339+
instance of self.values_type which has volume cropped to desired size.
340+
"""
341+
raise NotImplementedError()
342+
242343

243344
@dataclass
244345
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
@@ -288,6 +389,34 @@ def get_shapes(self, epoch: int) -> Dict[str, Tuple]:
288389
width, height, depth = self.get_resolution(epoch)
289390
return {"voxel_grid": (self.n_features, width, height, depth)}
290391

392+
# pyre-ignore[14]
393+
def crop_local(
394+
self,
395+
min_point_local: torch.Tensor,
396+
max_point_local: torch.Tensor,
397+
grid_values: FullResolutionVoxelGridValues,
398+
) -> FullResolutionVoxelGridValues:
399+
assert torch.all(min_point_local < max_point_local)
400+
min_point_local = torch.clamp(min_point_local, -1, 1)
401+
max_point_local = torch.clamp(max_point_local, -1, 1)
402+
_, _, width, height, depth = grid_values.voxel_grid.shape
403+
resolution = grid_values.voxel_grid.new_tensor([width, height, depth])
404+
min_point_local01 = (min_point_local + 1) / 2
405+
max_point_local01 = (max_point_local + 1) / 2
406+
407+
if self.align_corners:
408+
minx, miny, minz = torch.floor(min_point_local01 * (resolution - 1)).long()
409+
maxx, maxy, maxz = torch.ceil(max_point_local01 * (resolution - 1)).long()
410+
else:
411+
minx, miny, minz = torch.floor(min_point_local01 * resolution - 0.5).long()
412+
maxx, maxy, maxz = torch.ceil(max_point_local01 * resolution - 0.5).long()
413+
414+
return FullResolutionVoxelGridValues(
415+
voxel_grid=grid_values.voxel_grid[
416+
:, :, minx : maxx + 1, miny : maxy + 1, minz : maxz + 1
417+
]
418+
)
419+
291420

292421
@dataclass
293422
class CPFactorizedVoxelGridValues(VoxelGridValuesBase):
@@ -388,6 +517,37 @@ def get_shapes(self, epoch: int) -> Dict[str, Tuple[int, int]]:
388517
shape_dict["basis_matrix"] = (self.n_components, self.n_features)
389518
return shape_dict
390519

520+
# pyre-ignore[14]
521+
def crop_local(
522+
self,
523+
min_point_local: torch.Tensor,
524+
max_point_local: torch.Tensor,
525+
grid_values: CPFactorizedVoxelGridValues,
526+
) -> CPFactorizedVoxelGridValues:
527+
assert torch.all(min_point_local < max_point_local)
528+
min_point_local = torch.clamp(min_point_local, -1, 1)
529+
max_point_local = torch.clamp(max_point_local, -1, 1)
530+
_, _, width = grid_values.vector_components_x.shape
531+
_, _, height = grid_values.vector_components_y.shape
532+
_, _, depth = grid_values.vector_components_z.shape
533+
resolution = grid_values.vector_components_x.new_tensor([width, height, depth])
534+
min_point_local01 = (min_point_local + 1) / 2
535+
max_point_local01 = (max_point_local + 1) / 2
536+
537+
if self.align_corners:
538+
minx, miny, minz = torch.floor(min_point_local01 * (resolution - 1)).long()
539+
maxx, maxy, maxz = torch.ceil(max_point_local01 * (resolution - 1)).long()
540+
else:
541+
minx, miny, minz = torch.floor(min_point_local01 * resolution - 0.5).long()
542+
maxx, maxy, maxz = torch.ceil(max_point_local01 * resolution - 0.5).long()
543+
544+
return CPFactorizedVoxelGridValues(
545+
vector_components_x=grid_values.vector_components_x[:, :, minx : maxx + 1],
546+
vector_components_y=grid_values.vector_components_y[:, :, miny : maxy + 1],
547+
vector_components_z=grid_values.vector_components_z[:, :, minz : maxz + 1],
548+
basis_matrix=grid_values.basis_matrix,
549+
)
550+
391551

392552
@dataclass
393553
class VMFactorizedVoxelGridValues(VoxelGridValuesBase):
@@ -585,6 +745,46 @@ def get_shapes(self, epoch: int) -> Dict[str, Tuple]:
585745

586746
return shape_dict
587747

748+
# pyre-ignore[14]
749+
def crop_local(
750+
self,
751+
min_point_local: torch.Tensor,
752+
max_point_local: torch.Tensor,
753+
grid_values: VMFactorizedVoxelGridValues,
754+
) -> VMFactorizedVoxelGridValues:
755+
assert torch.all(min_point_local < max_point_local)
756+
min_point_local = torch.clamp(min_point_local, -1, 1)
757+
max_point_local = torch.clamp(max_point_local, -1, 1)
758+
_, _, width = grid_values.vector_components_x.shape
759+
_, _, height = grid_values.vector_components_y.shape
760+
_, _, depth = grid_values.vector_components_z.shape
761+
resolution = grid_values.vector_components_x.new_tensor([width, height, depth])
762+
min_point_local01 = (min_point_local + 1) / 2
763+
max_point_local01 = (max_point_local + 1) / 2
764+
765+
if self.align_corners:
766+
minx, miny, minz = torch.floor(min_point_local01 * (resolution - 1)).long()
767+
maxx, maxy, maxz = torch.ceil(max_point_local01 * (resolution - 1)).long()
768+
else:
769+
minx, miny, minz = torch.floor(min_point_local01 * resolution - 0.5).long()
770+
maxx, maxy, maxz = torch.ceil(max_point_local01 * resolution - 0.5).long()
771+
772+
return VMFactorizedVoxelGridValues(
773+
vector_components_x=grid_values.vector_components_x[:, :, minx : maxx + 1],
774+
vector_components_y=grid_values.vector_components_y[:, :, miny : maxy + 1],
775+
vector_components_z=grid_values.vector_components_z[:, :, minz : maxz + 1],
776+
matrix_components_xy=grid_values.matrix_components_xy[
777+
:, :, minx : maxx + 1, miny : maxy + 1
778+
],
779+
matrix_components_yz=grid_values.matrix_components_yz[
780+
:, :, miny : maxy + 1, minz : maxz + 1
781+
],
782+
matrix_components_xz=grid_values.matrix_components_xz[
783+
:, :, minx : maxx + 1, minz : maxz + 1
784+
],
785+
basis_matrix=grid_values.basis_matrix,
786+
)
787+
588788

589789
# pyre-fixme[13]: Attribute `voxel_grid` is never initialized.
590790
class VoxelGridModule(Configurable, torch.nn.Module):
@@ -771,6 +971,37 @@ def get_device(self) -> torch.device:
771971
# pyre-ignore[29]
772972
return next(val for val in self.params.values() if val is not None).device
773973

974+
def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
975+
"""
976+
Crops self to only represent points between min_point and max_point (inclusive).
977+
978+
Args:
979+
min_point: torch.Tensor of shape (3,). Has x, y and z coordinates
980+
smaller or equal to all other occupied points.
981+
max_point: torch.Tensor of shape (3,). Has x, y and z coordinates
982+
bigger or equal to all other occupied points.
983+
Returns:
984+
nothing
985+
"""
986+
locator = self._get_volume_locator()
987+
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
988+
# torch.nn.modules.module.Module]` is not a function.
989+
old_grid_values = self.voxel_grid.values_type(**self.params)
990+
new_grid_values = self.voxel_grid.crop_world(
991+
min_point, max_point, old_grid_values, locator
992+
)
993+
grid_values, _ = self.voxel_grid.change_resolution(
994+
new_grid_values, grid_values_with_wanted_resolution=old_grid_values
995+
)
996+
# pyre-ignore [16]
997+
self.params = torch.nn.ParameterDict(
998+
{k: v for k, v in vars(grid_values).items()}
999+
)
1000+
# New center of voxel grid is the middle point between max and min points.
1001+
self.translation = tuple((max_point + min_point) / 2)
1002+
# new extents of voxel grid are distances between min and max points
1003+
self.extents = tuple(max_point - min_point)
1004+
7741005
def _get_volume_locator(self) -> VolumeLocator:
7751006
"""
7761007
Returns VolumeLocator calculated from `extents` and `translation` members.

0 commit comments

Comments
 (0)