@@ -156,6 +156,13 @@ def conv2d_rewrite(ref_call, new_args, ctx):
156156 if cnt < current_qconfig ().skip_k_conv :
157157 _set_conv_counter (cnt + 1 )
158158 return None
159+
160+ if current_qconfig ().skip_conv_layers is not None :
161+ leave_alone_indices = [int (x ) for x in current_qconfig ().skip_conv_layers ]
162+ if cnt in leave_alone_indices :
163+ _set_conv_counter (cnt + 1 )
164+ return None
165+
159166 _set_conv_counter (cnt + 1 )
160167
161168 lhs_expr , lhs_kind = _get_expr_kind (new_args [0 ])
@@ -168,6 +175,7 @@ def conv2d_rewrite(ref_call, new_args, ctx):
168175 rhs_expr = attach_simulated_quantize (rhs_expr , QAnnotateKind .WEIGHT )
169176
170177 expr = _forward_op (ref_call , [lhs_expr , rhs_expr ])
178+
171179 return QAnnotateExpr (expr , QAnnotateKind .ACTIVATION )
172180
173181
@@ -178,6 +186,11 @@ def dense_rewrite(ref_call, new_args, ctx):
178186 cnt = _conv_counter ()
179187 if cnt < current_qconfig ().skip_k_conv :
180188 return None
189+ if current_qconfig ().skip_conv_layers is not None :
190+ leave_alone_indices = [int (x ) for x in current_qconfig ().skip_conv_layers ]
191+ if cnt - 1 in leave_alone_indices :
192+ return None
193+
181194 lhs_expr , lhs_kind = _get_expr_kind (new_args [0 ])
182195 rhs_expr , rhs_kind = _get_expr_kind (new_args [1 ])
183196
@@ -194,8 +207,13 @@ def dense_rewrite(ref_call, new_args, ctx):
194207@register_annotate_function ("multiply" )
195208def multiply_rewrite (ref_call , new_args , ctx ):
196209 """Rewrite function for multiply."""
197- if _conv_counter () <= current_qconfig ().skip_k_conv :
210+ cnt = _conv_counter ()
211+ if cnt <= current_qconfig ().skip_k_conv :
198212 return None
213+ if current_qconfig ().skip_conv_layers is not None :
214+ leave_alone_indices = [int (x ) for x in current_qconfig ().skip_conv_layers ]
215+ if cnt - 1 in leave_alone_indices :
216+ return None
199217
200218 lhs_expr , lhs_kind = _get_expr_kind (new_args [0 ])
201219 rhs_expr , rhs_kind = _get_expr_kind (new_args [1 ])
@@ -216,8 +234,13 @@ def multiply_rewrite(ref_call, new_args, ctx):
216234@register_annotate_function ("add" )
217235def add_rewrite (ref_call , new_args , ctx ):
218236 """Rewrite function for add."""
219- if _conv_counter () <= current_qconfig ().skip_k_conv :
237+ cnt = _conv_counter ()
238+ if cnt <= current_qconfig ().skip_k_conv :
220239 return None
240+ if current_qconfig ().skip_conv_layers is not None :
241+ leave_alone_indices = [int (x ) for x in current_qconfig ().skip_conv_layers ]
242+ if cnt - 1 in leave_alone_indices :
243+ return None
221244
222245 lhs_expr , lhs_kind = _get_expr_kind (new_args [0 ])
223246 rhs_expr , rhs_kind = _get_expr_kind (new_args [1 ])
@@ -244,8 +267,13 @@ def add_rewrite(ref_call, new_args, ctx):
244267
245268def identity_rewrite (ref_call , new_args , ctx ):
246269 """Simply forward the original operation"""
247- if _conv_counter () <= current_qconfig ().skip_k_conv :
270+ cnt = _conv_counter ()
271+ if cnt <= current_qconfig ().skip_k_conv :
248272 return None
273+ if current_qconfig ().skip_conv_layers is not None :
274+ leave_alone_indices = [int (x ) for x in current_qconfig ().skip_conv_layers ]
275+ if cnt - 1 in leave_alone_indices :
276+ return None
249277
250278 x_expr , x_kind = _get_expr_kind (new_args [0 ])
251279 if x_kind is None :
@@ -262,8 +290,14 @@ def identity_rewrite(ref_call, new_args, ctx):
262290
263291def pool2d_rewrite (ref_call , new_args , ctx ):
264292 """Rewrite function for max pool2d"""
265- if _conv_counter () <= current_qconfig ().skip_k_conv :
293+ cnt = _conv_counter ()
294+ if cnt <= current_qconfig ().skip_k_conv :
266295 return None
296+ if current_qconfig ().skip_conv_layers is not None :
297+ leave_alone_indices = [int (x ) for x in current_qconfig ().skip_conv_layers ]
298+ if cnt - 1 in leave_alone_indices :
299+ return None
300+
267301 expr , x_kind = _get_expr_kind (new_args [0 ])
268302
269303 if x_kind is None :
@@ -280,8 +314,13 @@ def pool2d_rewrite(ref_call, new_args, ctx):
280314@register_annotate_function ("concatenate" )
281315def concatenate_rewrite (ref_call , new_args , ctx ):
282316 """Rewrite function for concatenate"""
283- if _conv_counter () <= current_qconfig ().skip_k_conv :
317+ cnt = _conv_counter ()
318+ if cnt <= current_qconfig ().skip_k_conv :
284319 return None
320+ if current_qconfig ().skip_conv_layers is not None :
321+ leave_alone_indices = [int (x ) for x in current_qconfig ().skip_conv_layers ]
322+ if cnt - 1 in leave_alone_indices :
323+ return None
285324
286325 input_tuple = new_args [0 ]
287326 expr_list = [_get_expr_kind (x )[0 ] for x in input_tuple ]
0 commit comments