Commit a819ecb

bottler authored and facebook-github-bot committed
MLP last layer config
Summary: Added initialization configuration for the last layer of the MLP decoding function. You can now set:
- the last activation function (TensoRF uses sigmoid)
- the last bias init (TensoRF uses 0, because of the sigmoid)
- whether to use Xavier initialization (we use ReLU, so this should not need to be set)

Reviewed By: davnov134

Differential Revision: D40304981

fbshipit-source-id: ec398eb2235164ae85cb7c09b9660e843490ea04
1 parent a2659e1 commit a819ecb
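
For reference, a minimal sketch (not part of this commit) of how the new options map onto the TensoRF settings described in the docstring below. It assumes the classes can be imported from the changed module and uses expand_args_fields from pytorch3d.implicitron.tools.config, the usual way to make an Implicitron Configurable directly constructible; all other fields are left at their defaults.

    import torch

    from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
        DecoderActivation,
        MLPWithInputSkips,
    )
    from pytorch3d.implicitron.tools.config import expand_args_fields

    # Configurable classes are expanded into dataclasses before direct construction.
    expand_args_fields(MLPWithInputSkips)

    mlp = MLPWithInputSkips(
        n_layers=8,
        hidden_dim=256,
        input_skips=(5,),
        # The three options added by this commit, set as for TensoRF color prediction:
        last_layer_bias_init=0.0,                   # last-layer biases start at 0
        last_activation=DecoderActivation.SIGMOID,  # sigmoid after the last layer
        use_xavier_init=False,                      # default PyTorch init instead of Xavier
    )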

File tree

1 file changed (+63 −12 lines)

pytorch3d/implicitron/models/implicit_function/decoding_functions.py

Lines changed: 63 additions & 12 deletions
@@ -14,6 +14,7 @@
 
 import logging
 
+from enum import Enum
 from typing import Optional, Tuple
 
 import torch
@@ -30,6 +31,13 @@
 logger = logging.getLogger(__name__)
 
 
+class DecoderActivation(Enum):
+    RELU = "relu"
+    SOFTPLUS = "softplus"
+    SIGMOID = "sigmoid"
+    IDENTITY = "identity"
+
+
 class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
     """
     Decoding function is a torch.nn.Module which takes the embedding of a location in
@@ -71,11 +79,16 @@ class ElementwiseDecoder(DecoderFunctionBase):
 
     scale: float = 1
     shift: float = 0
-    operation: str = "identity"
+    operation: DecoderActivation = DecoderActivation.IDENTITY
 
     def __post_init__(self):
         super().__post_init__()
-        if self.operation not in ["relu", "softplus", "sigmoid", "identity"]:
+        if self.operation not in [
+            DecoderActivation.RELU,
+            DecoderActivation.SOFTPLUS,
+            DecoderActivation.SIGMOID,
+            DecoderActivation.IDENTITY,
+        ]:
             raise ValueError(
                 "`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
             )
@@ -84,11 +97,11 @@ def forward(
         self, features: torch.Tensor, z: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         transfomed_input = features * self.scale + self.shift
-        if self.operation == "softplus":
+        if self.operation == DecoderActivation.SOFTPLUS:
             return torch.nn.functional.softplus(transfomed_input)
-        if self.operation == "relu":
+        if self.operation == DecoderActivation.RELU:
             return torch.nn.functional.relu(transfomed_input)
-        if self.operation == "sigmoid":
+        if self.operation == DecoderActivation.SIGMOID:
             return torch.nn.functional.sigmoid(transfomed_input)
         return transfomed_input
 
@@ -104,7 +117,15 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     appends a skip tensor `z` to the output of the preceding layer.
 
     Note that this follows the architecture described in the Supplementary
-    Material (Fig. 7) of [1].
+    Material (Fig. 7) of [1], for which keep the defaults for:
+        - `last_layer_bias_init` to None
+        - `last_activation` to "relu"
+        - `use_xavier_init` to `true`
+
+    If you want to use this as a part of the color prediction in TensoRF model set:
+        - `last_layer_bias_init` to 0
+        - `last_activation` to "sigmoid"
+        - `use_xavier_init` to `False`
 
     References:
         [1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
@@ -121,6 +142,12 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
         hidden_dim: The number of hidden units of the MLP.
         input_skips: The list of layer indices at which we append the skip
            tensor `z`.
+        last_layer_bias_init: If set then all the biases in the last layer
+            are initialized to that value.
+        last_activation: Which activation to use in the last layer. Options are:
+            "relu", "softplus", "sigmoid" and "identity". Default is "relu".
+        use_xavier_init: If True uses xavier init for all linear layer weights.
+            Otherwise the default PyTorch initialization is used. Default True.
     """
 
     n_layers: int = 8
@@ -130,10 +157,30 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     hidden_dim: int = 256
     input_skips: Tuple[int, ...] = (5,)
     skip_affine_trans: bool = False
-    no_last_relu = False
+    last_layer_bias_init: Optional[float] = None
+    last_activation: DecoderActivation = DecoderActivation.RELU
+    use_xavier_init: bool = True
 
     def __post_init__(self):
         super().__init__()
+
+        if self.last_activation not in [
+            DecoderActivation.RELU,
+            DecoderActivation.SOFTPLUS,
+            DecoderActivation.SIGMOID,
+            DecoderActivation.IDENTITY,
+        ]:
+            raise ValueError(
+                "`last_activation` can only be `relu`,"
+                " `softplus`, `sigmoid` or identity."
+            )
+        last_activation = {
+            DecoderActivation.RELU: torch.nn.ReLU(True),
+            DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
+            DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
+            DecoderActivation.IDENTITY: torch.nn.Identity(),
+        }[self.last_activation]
+
         layers = []
         skip_affine_layers = []
         for layeri in range(self.n_layers):
@@ -149,11 +196,14 @@ def __post_init__(self):
                 dimin = self.hidden_dim + self.skip_dim
 
             linear = torch.nn.Linear(dimin, dimout)
-            _xavier_init(linear)
+            if self.use_xavier_init:
+                _xavier_init(linear)
+            if layeri == self.n_layers - 1 and self.last_layer_bias_init is not None:
+                torch.nn.init.constant_(linear.bias, self.last_layer_bias_init)
             layers.append(
                 torch.nn.Sequential(linear, torch.nn.ReLU(True))
-                if not self.no_last_relu or layeri + 1 < self.n_layers
-                else linear
+                if not layeri + 1 < self.n_layers
+                else torch.nn.Sequential(linear, last_activation)
             )
         self.mlp = torch.nn.ModuleList(layers)
         if self.skip_affine_trans:
@@ -164,8 +214,9 @@ def __post_init__(self):
     def _make_affine_layer(self, input_dim, hidden_dim):
         l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
         l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
-        _xavier_init(l1)
-        _xavier_init(l2)
+        if self.use_xavier_init:
+            _xavier_init(l1)
+            _xavier_init(l2)
         return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)
 
     def _apply_affine_layer(self, layer, x, z):
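
As a closing usage note (not from the diff): ElementwiseDecoder's operation field is now the DecoderActivation enum rather than a string. A minimal sketch, under the same assumption that expand_args_fields from pytorch3d.implicitron.tools.config is used to make the class constructible:

    import torch

    from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
        DecoderActivation,
        ElementwiseDecoder,
    )
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(ElementwiseDecoder)

    # Computes sigmoid(features * scale + shift) elementwise.
    decoder = ElementwiseDecoder(scale=2.0, shift=-1.0, operation=DecoderActivation.SIGMOID)
    out = decoder(torch.randn(4, 3))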
