@@ -175,6 +175,28 @@ def conv2d_rewrite(ref_call, new_args, ctx):
175
175
return QAnnotateExpr (expr , QAnnotateKind .ACTIVATION )
176
176
177
177
178
+ @register_annotate_function ("nn.conv1d" )
179
+ def conv1d_rewrite (ref_call , new_args , ctx ):
180
+ """Rewrite function for conv1d. Lhs of conv will be quantized to
181
+ input field, and rhs of conv will be quantized to weight field.
182
+ Output would be in activation field"""
183
+ if quantize_context ().check_to_skip (ref_call ):
184
+ return None
185
+
186
+ lhs_expr , lhs_kind = _get_expr_kind (new_args [0 ])
187
+ rhs_expr , rhs_kind = _get_expr_kind (new_args [1 ])
188
+
189
+ if lhs_kind is None or lhs_kind == QAnnotateKind .ACTIVATION :
190
+ lhs_expr = attach_simulated_quantize (lhs_expr , QAnnotateKind .INPUT )
191
+
192
+ assert rhs_kind is None
193
+ rhs_expr = attach_simulated_quantize (rhs_expr , QAnnotateKind .WEIGHT )
194
+
195
+ expr = _forward_op (ref_call , [lhs_expr , rhs_expr ])
196
+
197
+ return QAnnotateExpr (expr , QAnnotateKind .ACTIVATION )
198
+
199
+
178
200
@register_annotate_function ("nn.dense" )
179
201
def dense_rewrite (ref_call , new_args , ctx ):
180
202
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
@@ -289,6 +311,8 @@ def identity_rewrite(ref_call, new_args, ctx):
289
311
register_annotate_function ("nn.relu" , identity_rewrite )
290
312
register_annotate_function ("strided_slice" , identity_rewrite )
291
313
register_annotate_function ("nn.avg_pool2d" , identity_rewrite )
314
+ register_annotate_function ("nn.batch_flatten" , identity_rewrite )
315
+ register_annotate_function ("transpose" , identity_rewrite )
292
316
register_annotate_function ("annotation.stop_fusion" , identity_rewrite )
293
317
294
318
@@ -311,6 +335,25 @@ def pool2d_rewrite(ref_call, new_args, ctx):
311
335
register_annotate_function ("nn.max_pool2d" , pool2d_rewrite )
312
336
313
337
338
+ def pool1d_rewrite (ref_call , new_args , ctx ):
339
+ """Rewrite function for max pool1d"""
340
+ if quantize_context ().check_to_skip (ref_call ):
341
+ return None
342
+
343
+ expr , x_kind = _get_expr_kind (new_args [0 ])
344
+
345
+ if x_kind is None :
346
+ return None
347
+ if x_kind == QAnnotateKind .ACTIVATION :
348
+ expr = attach_simulated_quantize (expr , QAnnotateKind .INPUT )
349
+
350
+ expr = _forward_op (ref_call , [expr ])
351
+ return QAnnotateExpr (expr , QAnnotateKind .INPUT )
352
+
353
+
354
+ register_annotate_function ("nn.max_pool1d" , pool1d_rewrite )
355
+
356
+
314
357
@register_annotate_function ("annotation.cast_hint" )
315
358
def cast_hint_rewrite (ref_call , new_args , ctx ):
316
359
"""Rewrite function to force cast"""
0 commit comments