
Commit 4a0dcd4

Lollikit authored and committed
Fix the bug related to 'repnr' during the finetuning process
1 parent 75f1081 commit 4a0dcd4

File tree

1 file changed: +126 -68 lines changed

led/archs/repnr_utils.py

Lines changed: 126 additions & 68 deletions
@@ -7,13 +7,18 @@
 from torch.nn.init import kaiming_normal_, kaiming_uniform_
 from torch.nn.modules.utils import _reverse_repeat_tuple
 
+
 def zero_init_(x, a=None):
     x.zero_()
 
+
 def build_repnr_arch_from_base(base_arch, **repnr_kwargs):
     base_arch = deepcopy(base_arch)
-    dont_convert_module = [] if 'dont_convert_module' not in repnr_kwargs \
-        else repnr_kwargs.pop('dont_convert_module')
+    dont_convert_module = (
+        []
+        if "dont_convert_module" not in repnr_kwargs
+        else repnr_kwargs.pop("dont_convert_module")
+    )
 
     def recursive_converter(base_arch, **repnr_kwargs):
         if isinstance(base_arch, nn.Conv2d):
@@ -30,14 +35,16 @@ def recursive_converter(base_arch, **repnr_kwargs):
 
 
 class RepNRConv2d(nn.Module):
-    def _init_conv(self, init_type='kaiming_uniform_'):
+    def _init_conv(self, init_type="kaiming_uniform_"):
         weight = torch.zeros_like(self.main_weight)
         init_func = eval(init_type)
         init_func(weight, a=math.sqrt(5))
         fan_in, _ = init._calculate_fan_in_and_fan_out(weight)
-        bias = torch.zeros((weight.size(0), ),
-                           dtype=self.main_weight.dtype,
-                           device=self.main_weight.device)
+        bias = torch.zeros(
+            (weight.size(0),),
+            dtype=self.main_weight.dtype,
+            device=self.main_weight.device,
+        )
         if fan_in != 0:
             bound = 1 / math.sqrt(fan_in)
             init.uniform_(bias, -bound, bound)
@@ -49,56 +56,84 @@ def _init_from_conv2d(self, conv2d: nn.Conv2d):
         self.kernel_size = conv2d.kernel_size
         self.stride = conv2d.stride
         self.padding = _reverse_repeat_tuple(conv2d.padding, 2)
-        self.padding_mode = 'constant' if conv2d.padding_mode == 'zeros' else conv2d.padding_mode
+        self.padding_mode = (
+            "constant" if conv2d.padding_mode == "zeros" else conv2d.padding_mode
+        )
         self.dilation = conv2d.dilation
         self.groups = conv2d.groups
         self.bias = conv2d.bias is not None
 
         main_weight = conv2d.weight.data.clone()
         main_bias = conv2d.bias.data.clone() if self.bias else None
         self.main_weight = nn.Parameter(main_weight, requires_grad=True)
-        self.main_bias = nn.Parameter(main_bias, requires_grad=True) if main_bias is not None else None
+        self.main_bias = (
+            nn.Parameter(main_bias, requires_grad=True)
+            if main_bias is not None
+            else None
+        )
 
     def _init_alignments(self, align_opts):
-        align_init_weight = align_opts.get('init_weight', 1.0)
-        if 'init_bias' in align_opts:
-            align_init_bias = align_opts['init_bias']
+        align_init_weight = align_opts.get("init_weight", 1.0)
+        if "init_bias" in align_opts:
+            align_init_bias = align_opts["init_bias"]
             align_bias = True
         else:
             align_bias = False
         align_channles = self.in_channels
 
         self.align_weights = nn.ParameterList(
             [
-                nn.Parameter(torch.ones((1, align_channles, 1, 1),
-                                        dtype=self.main_weight.dtype,
-                                        device=self.main_weight.device) * align_init_weight,
-                             requires_grad=True)
+                nn.Parameter(
+                    torch.ones(
+                        (1, align_channles, 1, 1),
+                        dtype=self.main_weight.dtype,
+                        device=self.main_weight.device,
+                    )
+                    * align_init_weight,
+                    requires_grad=True,
+                )
                 for _ in range(self.branch_num + 1)
             ]
         )
-        self.align_biases = nn.ParameterList(
-            [
-                nn.Parameter(torch.ones((1, align_channles, 1, 1),
-                                        dtype=self.main_weight.dtype,
-                                        device=self.main_weight.device) * align_init_bias,
-                             requires_grad=True)
-                for _ in range(self.branch_num + 1)
-            ]
-        ) if align_bias else None
+        self.align_biases = (
+            nn.ParameterList(
+                [
+                    nn.Parameter(
+                        torch.ones(
+                            (1, align_channles, 1, 1),
+                            dtype=self.main_weight.dtype,
+                            device=self.main_weight.device,
+                        )
+                        * align_init_bias,
+                        requires_grad=True,
+                    )
+                    for _ in range(self.branch_num + 1)
+                ]
+            )
+            if align_bias
+            else None
+        )
 
     def _init_aux_conv(self, aux_conv_opts):
-        if 'init' not in aux_conv_opts:
-            aux_conv_opts['init'] = 'kaiming_normal_'
+        if "init" not in aux_conv_opts:
+            aux_conv_opts["init"] = "kaiming_normal_"
 
-        aux_weight, aux_bias = self._init_conv(init_type=aux_conv_opts['init'])
+        aux_weight, aux_bias = self._init_conv(init_type=aux_conv_opts["init"])
         self.aux_weight = nn.Parameter(aux_weight, requires_grad=True)
-        self.aux_bias = nn.Parameter(torch.zeros_like(aux_bias), requires_grad=True) \
-            if aux_conv_opts.get('bias', False) else torch.zeros_like(aux_bias)
+        self.aux_bias = (
+            nn.Parameter(torch.zeros_like(aux_bias), requires_grad=True)
+            if aux_conv_opts.get("bias", False)
+            else torch.zeros_like(aux_bias)
+        )
 
-    def __init__(self, conv2d, branch_num,
-                 align_opts, aux_conv_opts=None,
-                 forward_type='reparameterize'):
+    def __init__(
+        self,
+        conv2d,
+        branch_num,
+        align_opts,
+        aux_conv_opts=None,
+        forward_type="reparameterize",
+    ):
         super().__init__()
         self.branch_num = branch_num
         self.forward_type = forward_type
@@ -111,33 +146,36 @@ def __init__(self, conv2d, branch_num,
             self._init_aux_conv(aux_conv_opts)
 
         self.cur_branch = -1
-        self.forward = self._reparameterize_forward if forward_type == 'reparameterize' \
+        self.forward = (
+            self._reparameterize_forward
+            if forward_type == "reparameterize"
             else self._trivial_forward
+        )
 
     def switch_forward_type(self, *, trivial=False, reparameterize=False):
         assert not (trivial and reparameterize)
         if trivial:
             self.forward = self._trivial_forward
-            self.forward_type = 'trivial'
+            self.forward_type = "trivial"
         elif reparameterize:
             self.forward = self._reparameterize_forward
-            self.forward_type = 'reparameterize'
+            self.forward_type = "reparameterize"
 
     @staticmethod
     def _sequential_reparamterize(k1, b1, k2, b2):
         # k1, b1 is the weight and bias of alignment
         def depthwise_to_normal(k, padding):
             k = k.reshape(-1)
             k = torch.diag(k).unsqueeze(-1).unsqueeze(-1)
-            k = F.pad(k, _reverse_repeat_tuple(padding, 2), mode='constant', value=0.0)
+            k = F.pad(k, _reverse_repeat_tuple(padding, 2), mode="constant", value=0.0)
             return k
 
         def bias_pad(b, padding):
-            return F.pad(b, _reverse_repeat_tuple(padding, 2), mode='replicate')
+            return F.pad(b, _reverse_repeat_tuple(padding, 2), mode="replicate")
 
         padding = (k2.shape[-2] - 1) // 2, (k2.shape[-1] - 1) // 2
         k1 = depthwise_to_normal(k1, padding)
-        k = F.conv2d(k2, k1, stride=1, padding='same')
+        k = F.conv2d(k2, k1, stride=1, padding="same")
         b = F.conv2d(bias_pad(b1, padding), k2, bias=b2, stride=1).reshape(-1)
         return k, b
 
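As a side note, not part of the commit: the folding that _sequential_reparamterize performs can be sanity-checked on a simplified case. The sketch below ignores padding and the depthwise-to-normal kernel expansion used in the repository; it only illustrates why a per-input-channel affine alignment can be absorbed into the following convolution's weight and bias. All tensor names here are made up for the example.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 4, 8, 8)          # input
K = torch.randn(6, 4, 3, 3)          # main conv weight
c = torch.randn(6)                   # main conv bias
w = torch.randn(4)                   # per-input-channel alignment weight
b = torch.randn(4)                   # per-input-channel alignment bias

# "trivial" path: align the input, then convolve (no padding, so no border effects)
y_trivial = F.conv2d(x * w.view(1, 4, 1, 1) + b.view(1, 4, 1, 1), K, bias=c)

# "reparameterized" path: fold the alignment into the conv weight and bias
K_fused = K * w.view(1, 4, 1, 1)                             # scale each input channel of the kernel
c_fused = c + (K * b.view(1, 4, 1, 1)).sum(dim=(1, 2, 3))    # constant contribution of the alignment bias
y_fused = F.conv2d(x, K_fused, bias=c_fused)

print(torch.allclose(y_trivial, y_fused, atol=1e-5))         # True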
@@ -149,15 +187,20 @@ def _parallel_reparamterize(k1, b1, k2, b2):
     def _weight_and_bias(self):
         index = self.cur_branch
         align_weight = self.align_weights[index]
-        align_bias = self.align_biases[index] if self.align_biases is not None \
+        align_bias = (
+            self.align_biases[index]
+            if self.align_biases is not None
             else torch.zeros_like(align_weight)
+        )
 
         main_weight, main_bias = self._sequential_reparamterize(
-            align_weight, align_bias, self.main_weight, self.main_bias)
+            align_weight, align_bias, self.main_weight, self.main_bias
+        )
 
-        if hasattr(self, 'aux_weight'):
+        if hasattr(self, "aux_weight"):
             main_weight, main_bias = self._parallel_reparamterize(
-                main_weight, main_bias, self.aux_weight, self.aux_bias)
+                main_weight, main_bias, self.aux_weight, self.aux_bias
+            )
 
         return main_weight, main_bias
 
@@ -170,16 +213,23 @@ def _reparameterize_forward(self, x):
     def _trivial_forward(self, x):
         index = self.cur_branch
         align_weight = self.align_weights[index]
-        align_bias = self.align_biases[index] if self.align_biases is not None \
+        align_bias = (
+            self.align_biases[index]
+            if self.align_biases is not None
             else torch.zeros_like(align_weight)
+        )
 
         aligned_x = x * align_weight + align_bias
         padded_aligned_x = F.pad(aligned_x, self.padding, self.padding_mode, value=0.0)
-        main_x = F.conv2d(padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride)
+        main_x = F.conv2d(
+            padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride
+        )
 
-        if hasattr(self, 'aux_weight'):
+        if hasattr(self, "aux_weight"):
             padded_x = F.pad(x, self.padding, self.padding_mode, value=0.0)
-            aux_x = F.conv2d(padded_x, self.aux_weight, self.aux_bias, stride=self.stride)
+            aux_x = F.conv2d(
+                padded_x, self.aux_weight, self.aux_bias, stride=self.stride
+            )
             main_x = main_x + aux_x
 
         return main_x
@@ -188,24 +238,29 @@ def _trivial_forward_with_intermediate_features(self, x):
         index = self.cur_branch
         features = {}
         align_weight = self.align_weights[index]
-        align_bias = self.align_biases[index] if self.align_biases is not None \
+        align_bias = (
+            self.align_biases[index]
+            if self.align_biases is not None
             else torch.zeros_like(align_weight)
+        )
 
-        features['in_feat'] = x
+        features["in_feat"] = x
         aligned_x = x * align_weight + align_bias
-        features['aligned_feat'] = aligned_x
+        features["aligned_feat"] = aligned_x
 
         padded_aligned_x = F.pad(aligned_x, self.padding, self.padding_mode, value=0.0)
-        main_x = F.conv2d(padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride)
-        features['main_feat'] = main_x
+        main_x = F.conv2d(
+            padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride
+        )
+        features["main_feat"] = main_x
 
-        if hasattr(self, 'aux_weight'):
+        if hasattr(self, "aux_weight"):
             padded_x = F.pad(x, self.padding, self.padding_mode, value=0.0)
             aux_x = F.conv2d(padded_x, self.aux_weight, self.aux_bias, stride=1)
-            features['aux_feat'] = aux_x
+            features["aux_feat"] = aux_x
             main_x = main_x + aux_x
 
-        features['out_feat'] = aux_x
+        features["out_feat"] = aux_x
         self.current_intermediate_features = features
         return main_x
 
@@ -215,30 +270,30 @@ def __repr__(self):
         extra_repr = self.extra_repr()
         # empty string will be split into list ['']
         if extra_repr:
-            extra_lines = extra_repr.split('\n')
+            extra_lines = extra_repr.split("\n")
         child_lines = []
         lines = extra_lines + child_lines
 
-        main_str = self._get_name() + '('
+        main_str = self._get_name() + "("
         if lines:
             # simple one-liner info, which most builtin Modules will use
             if len(extra_lines) == 1 and not child_lines:
                 main_str += extra_lines[0]
             else:
-                main_str += '\n ' + '\n '.join(lines) + '\n'
+                main_str += "\n " + "\n ".join(lines) + "\n"
 
-        main_str += ')'
+        main_str += ")"
         return main_str
 
     def extra_repr(self) -> str:
         s = (
-            '{in_channels}, {out_channels}, kernel_size={kernel_size}, '
-            'stride={stride}, padding={padding}, padding_mode={padding_mode},\n'
-            'branch_num={branch_num}, align_opts={align_opts}\n'
-            'forward_type={forward_type}'
+            "{in_channels}, {out_channels}, kernel_size={kernel_size}, "
+            "stride={stride}, padding={padding}, padding_mode={padding_mode},\n"
+            "branch_num={branch_num}, align_opts={align_opts}\n"
+            "forward_type={forward_type}"
         )
-        if hasattr(self, 'aux_weight'):
-            s += ', aux_conv_opts={aux_conv_opts}'
+        if hasattr(self, "aux_weight"):
+            s += ", aux_conv_opts={aux_conv_opts}"
         return s.format(**self.__dict__)
 
 
@@ -287,7 +342,9 @@ def generalize_align_conv(align_conv: RepNRConv2d):
                 generalize_align_conv(m)
 
     @staticmethod
-    def _set_requires_grad(align_conv: RepNRConv2d, *, pretrain=True, finetune=False, aux=True):
+    def _set_requires_grad(
+        align_conv: RepNRConv2d, *, pretrain=True, finetune=False, aux=True
+    ):
         assert not (pretrain and finetune)
         for i in range(align_conv.branch_num):
             align_conv.align_weights[i].requires_grad_(pretrain)
@@ -303,7 +360,7 @@ def _set_requires_grad(align_conv: RepNRConv2d, *, pretrain=True, finetune=False
             align_conv.main_bias.requires_grad_(pretrain)
 
         # aux weight and bias
-        if hasattr(align_conv, 'aux_weight') and aux:
+        if hasattr(align_conv, "aux_weight") and aux:
             align_conv.aux_weight.requires_grad_(finetune)
             if isinstance(align_conv.aux_bias, nn.Parameter):
                 align_conv.aux_bias.requires_grad_(finetune)
@@ -319,7 +376,7 @@ def finetune(self, *, aux=False):
                 self._set_requires_grad(m, pretrain=False, finetune=True, aux=aux)
 
     def __repr__(self):
-        return f'{self._get_name()}: {str(self.repnr_module)}'
+        return f"{self._get_name()}: {str(self.repnr_module)}"
 
     @torch.no_grad()
     def deploy(self):
@@ -333,7 +390,8 @@ def deploy(m_bak, m):
             if isinstance(m, RepNRConv2d):
                 weight, bias = m._weight_and_bias
                 m_bak.weight.data = weight
-                m_bak.bias.data = bias
+                if m_bak.bias is not None:
+                    m_bak.bias.data = bias
                 return
             if not has_repnr_conv(m):
                 m_bak.load_state_dict(m.state_dict())
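The last hunk is the behavioral fix: during deploy, the backup conv may have been created without a bias, in which case m_bak.bias is None and the old unconditional assignment fails. A minimal standalone sketch of that failure mode follows; the target conv here is hypothetical and not taken from the repository.

import torch
import torch.nn as nn

# Hypothetical deploy target: a Conv2d built without a bias term.
target = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)

fused_weight = torch.randn_like(target.weight)  # stand-in for the reparameterized weight
fused_bias = torch.zeros(8)                     # stand-in for the reparameterized bias

target.weight.data = fused_weight
# Before this commit the bias assignment was unconditional; with target.bias being None
# it raises AttributeError. The added guard skips it instead:
if target.bias is not None:
    target.bias.data = fused_bias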
