 import tvm
 from tvm import relay
 from tvm.ir import TensorAffineType, TupleAffineType
+from tvm.tir import bijective_layout
 from ..op import register_fake_quantization_to_integer


 def fold_constant(expr):
     return relay.transform.FoldConstantExpr(expr, tvm.IRModule())


+def get_zeros(scale):
+    return fold_constant(relay.op.cast(relay.op.zeros_like(scale), "int32"))
+
+
+def infer_shape(expr):
+    return relay.transform.InferType()(tvm.IRModule.from_expr(expr))["main"].body.checked_type.shape
+
+
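
The two helpers added above are used by the rewrites below: get_zeros builds an integer zero-point tensor whose shape follows the (possibly per-channel) scale, and infer_shape runs type inference to recover an expression's static shape. A minimal sketch of get_zeros, assuming a hypothetical four-channel scale constant:

    import numpy as np
    from tvm import relay

    # Hypothetical per-channel scale for four output channels.
    scale = relay.const(np.array([0.5, 0.25, 0.125, 1.0], dtype="float32"))
    # zeros_like keeps the (4,) shape, the cast makes it int32, and
    # FoldConstantExpr collapses the result into a single constant of zeros.
    zp = get_zeros(scale)
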
 @register_fake_quantization_to_integer("qnn.dequantize")
 def dequantize(expr, type_map):
     """Remove dequantize op"""
@@ -52,8 +61,13 @@ def quantize(expr, type_map):
         expr.args[1],
         expr.args[2],
         out_dtype=expr.attrs.out_dtype,
+        axis=t.axis,
     )
-    return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype)]
+
+    return [
+        out,
+        TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype, expr.attrs.axis),
+    ]


 def register_unary_identity(op_name):
@@ -94,14 +108,19 @@ def bias_add(expr, type_map):
     b_t = type_map[b]
     in_scale = fold_constant(x_t.scale)
     in_zero_point = fold_constant(x_t.zero_point)
-    if not tvm.ir.structural_equal(x_t, b_t):
+    if not (
+        tvm.ir.structural_equal(x_t.scale, b_t.scale)
+        and tvm.ir.structural_equal(x_t.zero_point, b_t.zero_point)
+        and tvm.ir.structural_equal(x_t.dtype, b_t.dtype)
+    ):
         b = relay.qnn.op.requantize(
             b,
             b_t.scale,
             b_t.zero_point,
             in_scale,
             in_zero_point,
             out_dtype=x_t.dtype,
+            axis=0,
         )
     out = relay.op.nn.bias_add(x, b, **expr.attrs)
     return [out, x_t]
@@ -116,11 +135,13 @@ def conv2d(expr, type_map):
     x_t = type_map[x]
     w_t = type_map[weight]
     conv_scale = fold_constant(x_t.scale * w_t.scale)
-    conv_zp = relay.const(0)
+    conv_zp = get_zeros(conv_scale)
     out = relay.qnn.op.conv2d(
         x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs
     )
-    return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype)]
+    out_layout = attrs["out_layout"] if attrs["out_layout"] != "" else attrs["data_layout"]
+    out_axis = bijective_layout(out_layout, "NCHW").backward_index(list(range(4)))[1]
+    return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype, out_axis.value)]


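
A side note on the channel-axis computation above: bijective_layout(out_layout, "NCHW").backward_index(list(range(4))) maps the identity index through the layout conversion, and element [1] is then recorded as the output-channel axis of the affine type. A minimal sanity-check sketch for the default NCHW case, where the conversion is the identity and the channel axis is 1:

    from tvm.tir import bijective_layout

    # With an NCHW output layout the conversion is the identity permutation,
    # so backward_index returns [0, 1, 2, 3] and element [1] is 1.
    axis = bijective_layout("NCHW", "NCHW").backward_index(list(range(4)))[1]
    print(axis)  # expected: 1
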
 @register_fake_quantization_to_integer("nn.dense")
@@ -132,11 +153,11 @@ def dense(expr, type_map):
     x_t = type_map[x]
     w_t = type_map[weight]
     dense_scale = fold_constant(x_t.scale * w_t.scale)
-    dense_zp = relay.const(0)
+    dense_zp = get_zeros(dense_scale)
     out = relay.qnn.op.dense(
         x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs
     )
-    return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype)]
+    return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype, 1)]


 @register_fake_quantization_to_integer("nn.batch_matmul")
@@ -148,7 +169,7 @@ def batch_matmul(expr, type_map):
     matmul_scale = fold_constant(x_t.scale * y_t.scale)
     matmul_zp = relay.const(0)
     out = relay.qnn.op.batch_matmul(x, y, x_t.zero_point, y_t.zero_point, x_t.scale, y_t.scale)
-    return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype)]
+    return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype, x_t.axis)]


 @register_fake_quantization_to_integer("concatenate")
@@ -198,19 +219,52 @@ def clip(expr, type_map):
     amax = expr.attrs.a_max
     scale = fold_constant(t.scale)
     z_p = fold_constant(t.zero_point)
-    if isinstance(scale, relay.expr.Constant) and isinstance(z_p, relay.expr.Constant):
+    if (
+        isinstance(scale, relay.expr.Constant)
+        and scale.data.numpy().size == 1
+        and isinstance(z_p, relay.expr.Constant)
+        and z_p.data.numpy().size == 1
+    ):
         scale = scale.data.numpy().item()
         z_p = z_p.data.numpy().item()
         new_min = int(amin / scale + z_p)
         new_max = int(amax / scale + z_p)
         out = relay.op.clip(arg, new_min, new_max)
     else:
-        amin = relay.op.round(relay.op.const(amin) / scale + z_p)
-        amax = relay.op.round(relay.op.const(amax) / scale + z_p)
-        out = relay.op.minimum(relay.op.maximum(arg, amin), amax)
+        if not isinstance(amin, relay.expr.Constant):
+            amin = relay.op.const(amin)
+        if not isinstance(amax, relay.expr.Constant):
+            amax = relay.op.const(amax)
+
+        scale_shape = infer_shape(scale)
+        if len(scale_shape) > 0 and scale_shape[0] > 1:
+            b_shape = [1] * len(infer_shape(arg))
+            b_shape[t.axis] = -1
+            amin = relay.op.reshape(relay.op.broadcast_to(amin, scale_shape), b_shape)
+            amax = relay.op.reshape(relay.op.broadcast_to(amax, scale_shape), b_shape)
+        amin = relay.qnn.op.quantize(amin, scale, z_p, t.axis, t.dtype)
+        amax = relay.qnn.op.quantize(amax, scale, z_p, t.axis, t.dtype)
+        out = relay.op.minimum(relay.op.maximum(arg, fold_constant(amin)), fold_constant(amax))
+
     return [out, t]


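
In the scalar branch above, the clip bounds are simply moved into the quantized domain as bound / scale + zero_point. A quick worked example, assuming a hypothetical tensor with scale 0.5 and zero point 10 feeding a ReLU6-style clip (a_min=0, a_max=6):

    scale, z_p = 0.5, 10
    amin, amax = 0.0, 6.0
    new_min = int(amin / scale + z_p)  # 0 / 0.5 + 10 -> 10
    new_max = int(amax / scale + z_p)  # 6 / 0.5 + 10 -> 22
    # relay.op.clip(arg, 10, 22) then clips directly on the integer tensor.
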
+@register_fake_quantization_to_integer("nn.relu")
+def relu(expr, type_map):
+    """Rewrite a relu op"""
+    arg = expr.args[0]
+    t = type_map[arg]
+    scale_shape = infer_shape(t.scale)
+    z_p = t.zero_point
+    assert len(scale_shape) <= 1
+    if len(scale_shape) == 1 and scale_shape[0] > 1:
+        b_shape = [1] * len(infer_shape(arg))
+        b_shape[t.axis] = -1
+        z_p = relay.op.reshape(relay.op.broadcast_to(z_p, scale_shape), b_shape)
+    zero = relay.op.cast(z_p, t.dtype)
+    return [relay.op.maximum(arg, fold_constant(zero)), t]
+
+
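
The relu rewrite above relies on the real value 0.0 quantizing to the zero point, so taking the element-wise maximum with the (per-channel broadcast) zero point reproduces ReLU in integer space. A small arithmetic sketch with hypothetical uint8 parameters:

    scale, z_p = 0.25, 128                    # hypothetical quantization parameters
    q_zero = int(round(0.0 / scale) + z_p)    # the real 0.0 maps to 128
    # relay.op.maximum(arg, zero_point) therefore clamps exactly where ReLU would.
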
 @register_fake_quantization_to_integer("nn.pad")
 def pad(expr, type_map):
     """Rewrite an nn.pad op"""
@@ -231,6 +285,7 @@ def pad(expr, type_map):
                 t.scale,
                 t.zero_point,
                 out_dtype=t.dtype,
+                axis=pad_t.axis,
             )
     else:
         ## If the pad-value is a constant, we need to quantize it
@@ -319,6 +374,7 @@ def binary(expr, type_map):
                 out_t.scale,
                 out_t.zero_point,
                 out_dtype=out_t.dtype,
+                axis=left_t.axis,
             )

         if right_t != out_t:
@@ -329,6 +385,7 @@ def binary(expr, type_map):
                 out_t.scale,
                 out_t.zero_point,
                 out_dtype=out_t.dtype,
+                axis=right_t.axis,
             )
         out = op(left, right)
         return [out, out_t]