from torch.nn.init import kaiming_normal_, kaiming_uniform_
from torch.nn.modules.utils import _reverse_repeat_tuple

+
def zero_init_(x, a=None):
    x.zero_()

+
def build_repnr_arch_from_base(base_arch, **repnr_kwargs):
    base_arch = deepcopy(base_arch)
-    dont_convert_module = [] if 'dont_convert_module' not in repnr_kwargs \
-        else repnr_kwargs.pop('dont_convert_module')
+    dont_convert_module = (
+        []
+        if "dont_convert_module" not in repnr_kwargs
+        else repnr_kwargs.pop("dont_convert_module")
+    )

    def recursive_converter(base_arch, **repnr_kwargs):
        if isinstance(base_arch, nn.Conv2d):
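For orientation (not part of the diff): `build_repnr_arch_from_base` deep-copies a base architecture and lets `recursive_converter` replace each eligible `nn.Conv2d` with a `RepNRConv2d`, forwarding the remaining keyword arguments. A hypothetical usage sketch — the option values are made up, and the kwargs mirror `RepNRConv2d.__init__` shown further down:

```python
import torch.nn as nn

base = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 3, 3, padding=1),
)
repnr_arch = build_repnr_arch_from_base(
    base,
    branch_num=5,                            # one alignment per branch, plus one extra
    align_opts=dict(init_weight=1.0),        # identity alignment at init
    aux_conv_opts=dict(bias=True),           # optional auxiliary conv branch
    dont_convert_module=["head"],            # presumably names of submodules to skip
)
```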
@@ -30,14 +35,16 @@ def recursive_converter(base_arch, **repnr_kwargs):


class RepNRConv2d(nn.Module):
-    def _init_conv(self, init_type='kaiming_uniform_'):
+    def _init_conv(self, init_type="kaiming_uniform_"):
        weight = torch.zeros_like(self.main_weight)
        init_func = eval(init_type)
        init_func(weight, a=math.sqrt(5))
        fan_in, _ = init._calculate_fan_in_and_fan_out(weight)
-        bias = torch.zeros((weight.size(0), ),
-                           dtype=self.main_weight.dtype,
-                           device=self.main_weight.device)
+        bias = torch.zeros(
+            (weight.size(0),),
+            dtype=self.main_weight.dtype,
+            device=self.main_weight.device,
+        )
        if fan_in != 0:
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(bias, -bound, bound)
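`_init_conv` reproduces PyTorch's default `Conv2d` parameter initialization: `kaiming_uniform_` with `a=math.sqrt(5)` for the weight, and a uniform bias in `±1/sqrt(fan_in)`. The `eval(init_type)` call resolves the string to one of the initializers imported at the top of the file (`kaiming_normal_` or `kaiming_uniform_`); a plain dict mapping names to functions would avoid `eval`, but the behavior is the same for these two names.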
@@ -49,56 +56,84 @@ def _init_from_conv2d(self, conv2d: nn.Conv2d):
        self.kernel_size = conv2d.kernel_size
        self.stride = conv2d.stride
        self.padding = _reverse_repeat_tuple(conv2d.padding, 2)
-        self.padding_mode = 'constant' if conv2d.padding_mode == 'zeros' else conv2d.padding_mode
+        self.padding_mode = (
+            "constant" if conv2d.padding_mode == "zeros" else conv2d.padding_mode
+        )
        self.dilation = conv2d.dilation
        self.groups = conv2d.groups
        self.bias = conv2d.bias is not None

        main_weight = conv2d.weight.data.clone()
        main_bias = conv2d.bias.data.clone() if self.bias else None
        self.main_weight = nn.Parameter(main_weight, requires_grad=True)
-        self.main_bias = nn.Parameter(main_bias, requires_grad=True) if main_bias is not None else None
+        self.main_bias = (
+            nn.Parameter(main_bias, requires_grad=True)
+            if main_bias is not None
+            else None
+        )

    def _init_alignments(self, align_opts):
-        align_init_weight = align_opts.get('init_weight', 1.0)
-        if 'init_bias' in align_opts:
-            align_init_bias = align_opts['init_bias']
+        align_init_weight = align_opts.get("init_weight", 1.0)
+        if "init_bias" in align_opts:
+            align_init_bias = align_opts["init_bias"]
            align_bias = True
        else:
            align_bias = False
        align_channles = self.in_channels

        self.align_weights = nn.ParameterList(
            [
-                nn.Parameter(torch.ones((1, align_channles, 1, 1),
-                                        dtype=self.main_weight.dtype,
-                                        device=self.main_weight.device) * align_init_weight,
-                             requires_grad=True)
+                nn.Parameter(
+                    torch.ones(
+                        (1, align_channles, 1, 1),
+                        dtype=self.main_weight.dtype,
+                        device=self.main_weight.device,
+                    )
+                    * align_init_weight,
+                    requires_grad=True,
+                )
                for _ in range(self.branch_num + 1)
            ]
        )
-        self.align_biases = nn.ParameterList(
-            [
-                nn.Parameter(torch.ones((1, align_channles, 1, 1),
-                                        dtype=self.main_weight.dtype,
-                                        device=self.main_weight.device) * align_init_bias,
-                             requires_grad=True)
-                for _ in range(self.branch_num + 1)
-            ]
-        ) if align_bias else None
+        self.align_biases = (
+            nn.ParameterList(
+                [
+                    nn.Parameter(
+                        torch.ones(
+                            (1, align_channles, 1, 1),
+                            dtype=self.main_weight.dtype,
+                            device=self.main_weight.device,
+                        )
+                        * align_init_bias,
+                        requires_grad=True,
+                    )
+                    for _ in range(self.branch_num + 1)
+                ]
+            )
+            if align_bias
+            else None
+        )
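Each alignment is a channel-wise affine transform of the input, `aligned_x = x * w + b`, with `w` and `b` of shape `(1, in_channels, 1, 1)`. The module keeps `branch_num + 1` of them: apparently one per branch plus one extra, presumably the generalized alignment that `generalize_align_conv` (further down) operates on. With the default `init_weight=1.0` and no `init_bias`, every alignment starts as the identity.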

    def _init_aux_conv(self, aux_conv_opts):
-        if 'init' not in aux_conv_opts:
-            aux_conv_opts['init'] = 'kaiming_normal_'
+        if "init" not in aux_conv_opts:
+            aux_conv_opts["init"] = "kaiming_normal_"

-        aux_weight, aux_bias = self._init_conv(init_type=aux_conv_opts['init'])
+        aux_weight, aux_bias = self._init_conv(init_type=aux_conv_opts["init"])
        self.aux_weight = nn.Parameter(aux_weight, requires_grad=True)
-        self.aux_bias = nn.Parameter(torch.zeros_like(aux_bias), requires_grad=True) \
-            if aux_conv_opts.get('bias', False) else torch.zeros_like(aux_bias)
+        self.aux_bias = (
+            nn.Parameter(torch.zeros_like(aux_bias), requires_grad=True)
+            if aux_conv_opts.get("bias", False)
+            else torch.zeros_like(aux_bias)
+        )

-    def __init__(self, conv2d, branch_num,
-                 align_opts, aux_conv_opts=None,
-                 forward_type='reparameterize'):
+    def __init__(
+        self,
+        conv2d,
+        branch_num,
+        align_opts,
+        aux_conv_opts=None,
+        forward_type="reparameterize",
+    ):
        super().__init__()
        self.branch_num = branch_num
        self.forward_type = forward_type
@@ -111,33 +146,36 @@ def __init__(self, conv2d, branch_num,
            self._init_aux_conv(aux_conv_opts)

        self.cur_branch = -1
-        self.forward = self._reparameterize_forward if forward_type == 'reparameterize' \
+        self.forward = (
+            self._reparameterize_forward
+            if forward_type == "reparameterize"
            else self._trivial_forward
+        )

    def switch_forward_type(self, *, trivial=False, reparameterize=False):
        assert not (trivial and reparameterize)
        if trivial:
            self.forward = self._trivial_forward
-            self.forward_type = 'trivial'
+            self.forward_type = "trivial"
        elif reparameterize:
            self.forward = self._reparameterize_forward
-            self.forward_type = 'reparameterize'
+            self.forward_type = "reparameterize"

    @staticmethod
    def _sequential_reparamterize(k1, b1, k2, b2):
        # k1, b1 is the weight and bias of alignment
        def depthwise_to_normal(k, padding):
            k = k.reshape(-1)
            k = torch.diag(k).unsqueeze(-1).unsqueeze(-1)
-            k = F.pad(k, _reverse_repeat_tuple(padding, 2), mode='constant', value=0.0)
+            k = F.pad(k, _reverse_repeat_tuple(padding, 2), mode="constant", value=0.0)
            return k

        def bias_pad(b, padding):
-            return F.pad(b, _reverse_repeat_tuple(padding, 2), mode='replicate')
+            return F.pad(b, _reverse_repeat_tuple(padding, 2), mode="replicate")

        padding = (k2.shape[-2] - 1) // 2, (k2.shape[-1] - 1) // 2
        k1 = depthwise_to_normal(k1, padding)
-        k = F.conv2d(k2, k1, stride=1, padding='same')
+        k = F.conv2d(k2, k1, stride=1, padding="same")
        b = F.conv2d(bias_pad(b1, padding), k2, bias=b2, stride=1).reshape(-1)
        return k, b

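The sequential merge folds the channel-wise alignment (a 1×1 depthwise scale and shift) into the K×K convolution that follows it: `depthwise_to_normal` turns the per-channel scale into a diagonal dense kernel that is convolved into `k2`, and the shift is pushed through `k2` (via a replicate-padded constant map) into the bias. A quick self-contained sanity check, not part of the diff — shapes follow `_init_alignments` above, and the comparison uses an unpadded convolution so the fused bias is exact everywhere:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8)
k1 = torch.rand(1, 4, 1, 1)   # alignment weight, shape (1, C_in, 1, 1)
b1 = torch.rand(1, 4, 1, 1)   # alignment bias
k2 = torch.randn(6, 4, 3, 3)  # main conv weight
b2 = torch.randn(6)           # main conv bias

k, b = RepNRConv2d._sequential_reparamterize(k1, b1, k2, b2)
unfused = F.conv2d(x * k1 + b1, k2, b2)  # align, then convolve
fused = F.conv2d(x, k, b)                # single fused convolution
assert torch.allclose(unfused, fused, atol=1e-5)
```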
@@ -149,15 +187,20 @@ def _parallel_reparamterize(k1, b1, k2, b2):
    def _weight_and_bias(self):
        index = self.cur_branch
        align_weight = self.align_weights[index]
-        align_bias = self.align_biases[index] if self.align_biases is not None \
+        align_bias = (
+            self.align_biases[index]
+            if self.align_biases is not None
            else torch.zeros_like(align_weight)
+        )

        main_weight, main_bias = self._sequential_reparamterize(
-            align_weight, align_bias, self.main_weight, self.main_bias)
+            align_weight, align_bias, self.main_weight, self.main_bias
+        )

-        if hasattr(self, 'aux_weight'):
+        if hasattr(self, "aux_weight"):
            main_weight, main_bias = self._parallel_reparamterize(
-                main_weight, main_bias, self.aux_weight, self.aux_bias)
+                main_weight, main_bias, self.aux_weight, self.aux_bias
+            )

        return main_weight, main_bias

@@ -170,16 +213,23 @@ def _reparameterize_forward(self, x):
    def _trivial_forward(self, x):
        index = self.cur_branch
        align_weight = self.align_weights[index]
-        align_bias = self.align_biases[index] if self.align_biases is not None \
+        align_bias = (
+            self.align_biases[index]
+            if self.align_biases is not None
            else torch.zeros_like(align_weight)
+        )

        aligned_x = x * align_weight + align_bias
        padded_aligned_x = F.pad(aligned_x, self.padding, self.padding_mode, value=0.0)
-        main_x = F.conv2d(padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride)
+        main_x = F.conv2d(
+            padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride
+        )

-        if hasattr(self, 'aux_weight'):
+        if hasattr(self, "aux_weight"):
            padded_x = F.pad(x, self.padding, self.padding_mode, value=0.0)
-            aux_x = F.conv2d(padded_x, self.aux_weight, self.aux_bias, stride=self.stride)
+            aux_x = F.conv2d(
+                padded_x, self.aux_weight, self.aux_bias, stride=self.stride
+            )
            main_x = main_x + aux_x

        return main_x
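`_trivial_forward` runs the alignment, the main convolution, and the optional aux branch as separate ops, while `_reparameterize_forward` (elided above) applies the single fused kernel and bias from `_weight_and_bias`. With identity alignments the two paths should agree exactly; with a non-zero alignment bias they agree away from the zero-padded borders, where the fused bias assumes the shift is also present. A hypothetical check (setting `cur_branch` directly for the demo):

```python
conv = nn.Conv2d(4, 6, 3, padding=1)
m = RepNRConv2d(conv, branch_num=3, align_opts=dict(init_weight=1.0))
m.cur_branch = 0
x = torch.randn(1, 4, 8, 8)

y_fused = m(x)                      # default forward_type="reparameterize"
m.switch_forward_type(trivial=True)
y_trivial = m(x)
assert torch.allclose(y_fused, y_trivial, atol=1e-5)
```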
@@ -188,24 +238,29 @@ def _trivial_forward_with_intermediate_features(self, x):
        index = self.cur_branch
        features = {}
        align_weight = self.align_weights[index]
-        align_bias = self.align_biases[index] if self.align_biases is not None \
+        align_bias = (
+            self.align_biases[index]
+            if self.align_biases is not None
            else torch.zeros_like(align_weight)
+        )

-        features['in_feat'] = x
+        features["in_feat"] = x
        aligned_x = x * align_weight + align_bias
-        features['aligned_feat'] = aligned_x
+        features["aligned_feat"] = aligned_x

        padded_aligned_x = F.pad(aligned_x, self.padding, self.padding_mode, value=0.0)
-        main_x = F.conv2d(padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride)
-        features['main_feat'] = main_x
+        main_x = F.conv2d(
+            padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride
+        )
+        features["main_feat"] = main_x

-        if hasattr(self, 'aux_weight'):
+        if hasattr(self, "aux_weight"):
            padded_x = F.pad(x, self.padding, self.padding_mode, value=0.0)
            aux_x = F.conv2d(padded_x, self.aux_weight, self.aux_bias, stride=self.stride)  # stride must match main_x for the sum below
-            features['aux_feat'] = aux_x
+            features["aux_feat"] = aux_x
            main_x = main_x + aux_x

-        features['out_feat'] = aux_x
+        features["out_feat"] = main_x  # the layer output; aux_x is undefined when there is no aux branch
        self.current_intermediate_features = features
        return main_x
@@ -215,30 +270,30 @@ def __repr__(self):
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
-            extra_lines = extra_repr.split('\n')
+            extra_lines = extra_repr.split("\n")
        child_lines = []
        lines = extra_lines + child_lines

-        main_str = self._get_name() + '('
+        main_str = self._get_name() + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
-                main_str += '\n  ' + '\n  '.join(lines) + '\n'
+                main_str += "\n  " + "\n  ".join(lines) + "\n"

-        main_str += ')'
+        main_str += ")"
        return main_str

    def extra_repr(self) -> str:
        s = (
-            '{in_channels}, {out_channels}, kernel_size={kernel_size}, '
-            'stride={stride}, padding={padding}, padding_mode={padding_mode},\n'
-            'branch_num={branch_num}, align_opts={align_opts}\n'
-            'forward_type={forward_type}'
+            "{in_channels}, {out_channels}, kernel_size={kernel_size}, "
+            "stride={stride}, padding={padding}, padding_mode={padding_mode},\n"
+            "branch_num={branch_num}, align_opts={align_opts}\n"
+            "forward_type={forward_type}"
        )
-        if hasattr(self, 'aux_weight'):
-            s += ', aux_conv_opts={aux_conv_opts}'
+        if hasattr(self, "aux_weight"):
+            s += ", aux_conv_opts={aux_conv_opts}"
        return s.format(**self.__dict__)

@@ -287,7 +342,9 @@ def generalize_align_conv(align_conv: RepNRConv2d):
            generalize_align_conv(m)

    @staticmethod
-    def _set_requires_grad(align_conv: RepNRConv2d, *, pretrain=True, finetune=False, aux=True):
+    def _set_requires_grad(
+        align_conv: RepNRConv2d, *, pretrain=True, finetune=False, aux=True
+    ):
        assert not (pretrain and finetune)
        for i in range(align_conv.branch_num):
            align_conv.align_weights[i].requires_grad_(pretrain)
@@ -303,7 +360,7 @@ def _set_requires_grad(align_conv: RepNRConv2d, *, pretrain=True, finetune=False
        align_conv.main_bias.requires_grad_(pretrain)

        # aux weight and bias
-        if hasattr(align_conv, 'aux_weight') and aux:
+        if hasattr(align_conv, "aux_weight") and aux:
            align_conv.aux_weight.requires_grad_(finetune)
            if isinstance(align_conv.aux_bias, nn.Parameter):
                align_conv.aux_bias.requires_grad_(finetune)
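`_set_requires_grad` encodes the two training phases: with `pretrain=True` the per-branch alignments and the main convolution are trainable while the finetune-specific parameters are frozen, and with `finetune=True` the flags flip, so the aux convolution (and, presumably, the extra generalized alignment handled in the elided lines) trains while the pretrained weights stay fixed. The `finetune` method below applies this to every `RepNRConv2d` in the wrapped module.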
@@ -319,7 +376,7 @@ def finetune(self, *, aux=False):
            self._set_requires_grad(m, pretrain=False, finetune=True, aux=aux)

    def __repr__(self):
-        return f'{self._get_name()}: {str(self.repnr_module)}'
+        return f"{self._get_name()}: {str(self.repnr_module)}"

    @torch.no_grad()
    def deploy(self):
@@ -333,7 +390,8 @@ def deploy(m_bak, m):
            if isinstance(m, RepNRConv2d):
                weight, bias = m._weight_and_bias
                m_bak.weight.data = weight
-                m_bak.bias.data = bias
+                if m_bak.bias is not None:
+                    m_bak.bias.data = bias
                return
            if not has_repnr_conv(m):
                m_bak.load_state_dict(m.state_dict())
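Putting the pieces together, the intended lifecycle appears to be: convert, pretrain, finetune, deploy. A heavily hedged sketch — it assumes the converted architecture is wrapped in the container class whose `finetune`/`deploy` methods appear above, and that `deploy` returns a plain-conv copy of the base arch with the fused weights copied in (consistent with the `m_bak` logic in this hunk):

```python
repnr_arch = build_repnr_arch_from_base(base, branch_num=5, align_opts=dict(init_weight=1.0))
# ... pretrain: alignments + main convs trainable ...
repnr_arch.finetune(aux=True)     # freeze pretrained weights, unfreeze the aux branch
# ... finetune on the target camera/scene ...
plain_arch = repnr_arch.deploy()  # fold everything back into ordinary nn.Conv2d weights
```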