Skip to content

Commit f711c4b

Browse files
shapovalovfacebook-github-bot
authored andcommitted
Fix parameters not wrapped with nn.Parameter, antialiasing compatibility
Summary: Some things fail if a parameter is not wraped; in particular, it prevented other tensors moving to GPU. Reviewed By: bottler Differential Revision: D40819932 fbshipit-source-id: a23b38ceacd7f0dc131cb0355fef1178e3e2f7fd
1 parent 88620b6 commit f711c4b

File tree

3 files changed

+47
-28
lines changed

3 files changed

+47
-28
lines changed

pytorch3d/implicitron/models/implicit_function/decoding_functions.py

+13-17
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class ElementwiseDecoder(DecoderFunctionBase):
7575
shift: a scalar which is added to the scaled input before performing
7676
the operation. Defaults to 0.
7777
operation: which operation to perform on the transformed input. Options are:
78-
`relu`, `softplus`, `sigmoid` and `identity`. Defaults to `identity`.
78+
`RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`. Defaults to `IDENTITY`.
7979
"""
8080

8181
scale: float = 1
@@ -91,7 +91,7 @@ def __post_init__(self):
9191
DecoderActivation.IDENTITY,
9292
]:
9393
raise ValueError(
94-
"`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
94+
"`operation` can only be `RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
9595
)
9696

9797
def forward(
@@ -165,22 +165,18 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
165165
def __post_init__(self):
166166
super().__init__()
167167

168-
if self.last_activation not in [
169-
DecoderActivation.RELU,
170-
DecoderActivation.SOFTPLUS,
171-
DecoderActivation.SIGMOID,
172-
DecoderActivation.IDENTITY,
173-
]:
168+
try:
169+
last_activation = {
170+
DecoderActivation.RELU: torch.nn.ReLU(True),
171+
DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
172+
DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
173+
DecoderActivation.IDENTITY: torch.nn.Identity(),
174+
}[self.last_activation]
175+
except KeyError as e:
174176
raise ValueError(
175-
"`last_activation` can only be `relu`,"
176-
" `softplus`, `sigmoid` or identity."
177-
)
178-
last_activation = {
179-
DecoderActivation.RELU: torch.nn.ReLU(True),
180-
DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
181-
DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
182-
DecoderActivation.IDENTITY: torch.nn.Identity(),
183-
}[self.last_activation]
177+
"`last_activation` can only be `RELU`,"
178+
" `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
179+
) from e
184180

185181
layers = []
186182
skip_affine_layers = []

pytorch3d/implicitron/models/implicit_function/voxel_grid.py

+28-5
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
1616
"""
1717

18+
import warnings
1819
from collections.abc import Mapping
1920
from dataclasses import dataclass, field
20-
from typing import Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
21+
22+
from distutils.version import LooseVersion
23+
from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
2124

2225
import torch
2326
from omegaconf import DictConfig
@@ -67,7 +70,9 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
6770
padding: str = "zeros"
6871
mode: str = "bilinear"
6972
n_features: int = 1
70-
resolution_changes: Dict[int, List[int]] = field(
73+
# return the line below once we drop OmegaConf 2.1 support
74+
# resolution_changes: Dict[int, List[int]] = field(
75+
resolution_changes: Dict[int, Any] = field(
7176
default_factory=lambda: {0: [128, 128, 128]}
7277
)
7378

@@ -212,6 +217,13 @@ def change_resolution(
212217
+ "| 'bicubic' | 'linear' | 'area' | 'nearest-exact'"
213218
)
214219

220+
interpolate_has_antialias = LooseVersion(torch.__version__) >= "1.11"
221+
222+
if antialias and not interpolate_has_antialias:
223+
warnings.warn("Antialiased interpolation requires PyTorch 1.11+; ignoring")
224+
225+
interp_kwargs = {"antialias": antialias} if interpolate_has_antialias else {}
226+
215227
def change_individual_resolution(tensor, wanted_resolution):
216228
if mode == "linear":
217229
n_dim = len(wanted_resolution)
@@ -223,8 +235,8 @@ def change_individual_resolution(tensor, wanted_resolution):
223235
size=wanted_resolution,
224236
mode=new_mode,
225237
align_corners=align_corners,
226-
antialias=antialias,
227238
recompute_scale_factor=False,
239+
**interp_kwargs,
228240
)
229241

230242
if epoch is not None:
@@ -880,7 +892,14 @@ def set_voxel_grid_parameters(self, params: VoxelGridValuesBase) -> None:
880892
"""
881893
if self.hold_voxel_grid_as_parameters:
882894
# pyre-ignore [16]
883-
self.params = torch.nn.ParameterDict(vars(params))
895+
# Nones are converted to empty tensors by Parameter()
896+
self.params = torch.nn.ParameterDict(
897+
{
898+
k: torch.nn.Parameter(val)
899+
for k, val in vars(params).items()
900+
if val is not None
901+
}
902+
)
884903
else:
885904
# Torch Module to hold parameters since they can only be registered
886905
# at object level.
@@ -1011,7 +1030,11 @@ def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
10111030
)
10121031
# pyre-ignore [16]
10131032
self.params = torch.nn.ParameterDict(
1014-
{k: v for k, v in vars(grid_values).items()}
1033+
{
1034+
k: torch.nn.Parameter(val)
1035+
for k, val in vars(grid_values).items()
1036+
if val is not None
1037+
}
10151038
)
10161039
# New center of voxel grid is the middle point between max and min points.
10171040
self.translation = tuple((max_point + min_point) / 2)

pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -527,10 +527,10 @@ def _get_scaffold(self, epoch: int) -> bool:
527527
return False
528528

529529
@classmethod
530-
def decoder_density_tweak_args(cls, type, args: DictConfig) -> None:
530+
def decoder_density_tweak_args(cls, type_, args: DictConfig) -> None:
531531
args.pop("input_dim", None)
532532

533-
def create_decoder_density_impl(self, type, args: DictConfig) -> None:
533+
def create_decoder_density_impl(self, type_, args: DictConfig) -> None:
534534
"""
535535
Decoding functions come after harmonic embedding and voxel grid. In order to not
536536
calculate the input dimension of the decoder in the config file this function
@@ -548,18 +548,18 @@ def create_decoder_density_impl(self, type, args: DictConfig) -> None:
548548
embedder_args["append_input"],
549549
)
550550

551-
cls = registry.get(DecoderFunctionBase, type)
551+
cls = registry.get(DecoderFunctionBase, type_)
552552
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
553553
if need_input_dim:
554554
self.decoder_density = cls(input_dim=input_dim, **args)
555555
else:
556556
self.decoder_density = cls(**args)
557557

558558
@classmethod
559-
def decoder_color_tweak_args(cls, type, args: DictConfig) -> None:
559+
def decoder_color_tweak_args(cls, type_, args: DictConfig) -> None:
560560
args.pop("input_dim", None)
561561

562-
def create_decoder_color_impl(self, type, args: DictConfig) -> None:
562+
def create_decoder_color_impl(self, type_, args: DictConfig) -> None:
563563
"""
564564
Decoding functions come after harmonic embedding and voxel grid. In order to not
565565
calculate the input dimension of the decoder in the config file this function
@@ -587,7 +587,7 @@ def create_decoder_color_impl(self, type, args: DictConfig) -> None:
587587

588588
input_dim = input_dim0 + input_dim1
589589

590-
cls = registry.get(DecoderFunctionBase, type)
590+
cls = registry.get(DecoderFunctionBase, type_)
591591
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
592592
if need_input_dim:
593593
self.decoder_color = cls(input_dim=input_dim, **args)

0 commit comments

Comments
 (0)