3131#include < tvm/relay/attrs/annotation.h>
3232#include " ./quantize.h"
3333#include " ../pattern_util.h"
34+ #include " ../../qnn/util.h"
3435
3536namespace tvm {
3637namespace relay {
@@ -97,7 +98,9 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
9798
9899
99100/* calculate `data * s1 / s2`, use shift if possible */
100- inline Expr MulAndDiv (Expr data, float s1, float s2, DataType dtype) {
101+ inline Expr MulAndDiv (Expr data, float s1, float s2, DataType dtype,
102+ const Array<IndexExpr> &data_shape) {
103+ const QConfig& cfg = QConfig::Current ();
101104 // here we assume the dtype of data is dtype activation
102105 if (s1 == s2) return data;
103106
@@ -110,9 +113,8 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
110113 } else if (static_cast <int >(factor) == factor) {
111114 return Multiply (data, MakeConstantScalar (dtype, factor));
112115 } else {
113- data = Cast (data, Float (32 ));
114- data = Multiply (data, MakeConstantScalar (Float (32 ), factor));
115- return Cast (Round (data), dtype);
116+ data = qnn::FixedPointMultiply (Cast (data, Int (64 )), factor, data_shape, cfg->rounding );
117+ return Cast (data, dtype);
116118 }
117119}
118120
@@ -164,11 +166,12 @@ Expr QuantizeRealize(const Call& ref_call,
164166 data = Clip (data, clip_min_imm, clip_max_imm);
165167 return QRealizeIntExprNode::make (data, dom_scale, n->dtype );
166168 } else {
167- // float computation
168- data = Cast (data, Float (32 ));
169- Expr scaled_data = Multiply (data, Divide (n->dom_scale , dom_scale));
170- Expr round_data = Clip (Round (scaled_data), clip_min_imm, clip_max_imm);
171- return QRealizeIntExprNode::make (round_data, dom_scale, Float (32 ));
169+ data = Cast (data, Int (64 ));
170+ data = qnn::FixedPointMultiply (data, idom_scale_imm / odom_scale_imm,
171+ ref_call->type_as <TensorTypeNode>()->shape ,
172+ cfg->rounding );
173+ data = Cast (Clip (data, clip_min_imm, clip_max_imm), n->dtype );
174+ return QRealizeIntExprNode::make (data, dom_scale, n->dtype );
172175 }
173176 }
174177
@@ -355,7 +358,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
355358 Expr dom_scale = MakeConstantScalar (Float (32 ), s);
356359 for (size_t i = 0 ; i < ret.size (); ++i) {
357360 float cur_s = GetScalarFromConstant<float >(nptrs[i]->dom_scale );
358- ret.Set (i, MulAndDiv (ret[i], cur_s, s, dtype));
361+ ret.Set (i, MulAndDiv (ret[i], cur_s, s, dtype, ref_args[i]-> type_as <TensorTypeNode>()-> shape ));
359362 }
360363
361364 *dtype_ptr = dtype;
0 commit comments