import logging

+ from enum import Enum
from typing import Optional, Tuple

import torch

logger = logging.getLogger(__name__)


+ class DecoderActivation(Enum):
+     RELU = "relu"
+     SOFTPLUS = "softplus"
+     SIGMOID = "sigmoid"
+     IDENTITY = "identity"
+
+
class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
    """
    Decoding function is a torch.nn.Module which takes the embedding of a location in
@@ -71,11 +79,16 @@ class ElementwiseDecoder(DecoderFunctionBase):

    scale: float = 1
    shift: float = 0
-     operation: str = "identity"
+     operation: DecoderActivation = DecoderActivation.IDENTITY

    def __post_init__(self):
        super().__post_init__()
-         if self.operation not in ["relu", "softplus", "sigmoid", "identity"]:
+         if self.operation not in [
+             DecoderActivation.RELU,
+             DecoderActivation.SOFTPLUS,
+             DecoderActivation.SIGMOID,
+             DecoderActivation.IDENTITY,
+         ]:
            raise ValueError(
                "`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
            )
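
Since `operation` is now compared against enum members rather than raw strings, a plain string such as "softplus" no longer passes this membership check. A minimal standalone sketch of that behavior (it mirrors the enum added in the diff; it is not part of the commit itself):

from enum import Enum

class DecoderActivation(Enum):
    RELU = "relu"
    SOFTPLUS = "softplus"
    SIGMOID = "sigmoid"
    IDENTITY = "identity"

# An enum member can still be recovered from its string value, but a raw
# string is not equal to the member, so it would fail the membership check.
assert DecoderActivation("softplus") is DecoderActivation.SOFTPLUS
assert "softplus" not in {
    DecoderActivation.RELU,
    DecoderActivation.SOFTPLUS,
    DecoderActivation.SIGMOID,
    DecoderActivation.IDENTITY,
}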
@@ -84,11 +97,11 @@ def forward(
        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        transfomed_input = features * self.scale + self.shift
-         if self.operation == "softplus":
+         if self.operation == DecoderActivation.SOFTPLUS:
            return torch.nn.functional.softplus(transfomed_input)
-         if self.operation == "relu":
+         if self.operation == DecoderActivation.RELU:
            return torch.nn.functional.relu(transfomed_input)
-         if self.operation == "sigmoid":
+         if self.operation == DecoderActivation.SIGMOID:
            return torch.nn.functional.sigmoid(transfomed_input)
        return transfomed_input
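
A minimal usage sketch of the enum-typed `operation` field (not part of the commit). It assumes the usual Implicitron pattern of expanding a `Configurable`/`ReplaceableBase` class with `expand_args_fields` before constructing it directly, and it assumes the import path of the file being edited; adjust both to your setup.

import torch

from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
    DecoderActivation,
    ElementwiseDecoder,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Materialize the dataclass fields so the class can be built directly.
expand_args_fields(ElementwiseDecoder)

decoder = ElementwiseDecoder(
    scale=2.0,
    shift=-1.0,
    operation=DecoderActivation.SOFTPLUS,
)

features = torch.randn(4, 3)
out = decoder(features)  # softplus(features * 2.0 + (-1.0))
print(out.shape)  # torch.Size([4, 3])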
@@ -104,7 +117,15 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
    appends a skip tensor `z` to the output of the preceding layer.

    Note that this follows the architecture described in the Supplementary
-     Material (Fig. 7) of [1].
+     Material (Fig. 7) of [1]. To reproduce it, keep the defaults:
+         - `last_layer_bias_init` set to None
+         - `last_activation` set to "relu"
+         - `use_xavier_init` set to True
+
+     To use this as part of the color prediction in the TensoRF model, set:
+         - `last_layer_bias_init` to 0
+         - `last_activation` to "sigmoid"
+         - `use_xavier_init` to False

    References:
        [1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
@@ -121,6 +142,12 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
        hidden_dim: The number of hidden units of the MLP.
        input_skips: The list of layer indices at which we append the skip
            tensor `z`.
+         last_layer_bias_init: If set, all biases in the last layer are
+             initialized to this value.
+         last_activation: Which activation to use in the last layer. Options are:
+             "relu", "softplus", "sigmoid" and "identity". Default is "relu".
+         use_xavier_init: If True, uses Xavier init for all linear layer weights.
+             Otherwise the default PyTorch initialization is used. Default is True.
    """

    n_layers: int = 8
@@ -130,10 +157,30 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
    hidden_dim: int = 256
    input_skips: Tuple[int, ...] = (5,)
    skip_affine_trans: bool = False
-     no_last_relu = False
+     last_layer_bias_init: Optional[float] = None
+     last_activation: DecoderActivation = DecoderActivation.RELU
+     use_xavier_init: bool = True

    def __post_init__(self):
        super().__init__()
+
+         if self.last_activation not in [
+             DecoderActivation.RELU,
+             DecoderActivation.SOFTPLUS,
+             DecoderActivation.SIGMOID,
+             DecoderActivation.IDENTITY,
+         ]:
+             raise ValueError(
+                 "`last_activation` can only be `relu`,"
+                 " `softplus`, `sigmoid` or identity."
+             )
+         last_activation = {
+             DecoderActivation.RELU: torch.nn.ReLU(True),
+             DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
+             DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
+             DecoderActivation.IDENTITY: torch.nn.Identity(),
+         }[self.last_activation]
+
        layers = []
        skip_affine_layers = []
        for layeri in range(self.n_layers):
@@ -149,11 +196,14 @@ def __post_init__(self):
                dimin = self.hidden_dim + self.skip_dim

            linear = torch.nn.Linear(dimin, dimout)
-             _xavier_init(linear)
+             if self.use_xavier_init:
+                 _xavier_init(linear)
+             if layeri == self.n_layers - 1 and self.last_layer_bias_init is not None:
+                 torch.nn.init.constant_(linear.bias, self.last_layer_bias_init)
            layers.append(
                torch.nn.Sequential(linear, torch.nn.ReLU(True))
-                 if not self.no_last_relu or layeri + 1 < self.n_layers
-                 else linear
+                 if layeri + 1 < self.n_layers
+                 else torch.nn.Sequential(linear, last_activation)
            )
        self.mlp = torch.nn.ModuleList(layers)
        if self.skip_affine_trans:
@@ -164,8 +214,9 @@ def __post_init__(self):
    def _make_affine_layer(self, input_dim, hidden_dim):
        l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
        l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
-         _xavier_init(l1)
-         _xavier_init(l2)
+         if self.use_xavier_init:
+             _xavier_init(l1)
+             _xavier_init(l2)
        return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)

    def _apply_affine_layer(self, layer, x, z):
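
To round this off, a hedged sketch of the TensoRF-style configuration described in the new docstring (again not part of the commit). Beyond the `expand_args_fields` pattern and import path assumed above, it also assumes the class exposes `input_dim` and `output_dim` fields and that `forward` accepts the input tensor plus an optional skip tensor `z`; none of these are visible in this excerpt.

import torch

from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
    DecoderActivation,
    MLPWithInputSkips,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(MLPWithInputSkips)

# A color head in the spirit of the TensoRF settings from the docstring:
# sigmoid on the last layer, zero-initialized last-layer bias, default
# PyTorch (non-Xavier) weight init. `input_dim` / `output_dim` are assumed
# field names that are not shown in this diff.
color_mlp = MLPWithInputSkips(
    n_layers=2,
    input_dim=27,
    output_dim=3,
    hidden_dim=128,
    input_skips=(),  # no skip connections for this small head
    last_layer_bias_init=0.0,
    last_activation=DecoderActivation.SIGMOID,
    use_xavier_init=False,
)

embedding = torch.randn(1024, 27)
rgb = color_mlp(embedding)  # expected in [0, 1] thanks to the final sigmoid
print(rgb.shape)  # torch.Size([1024, 3])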