Commit e1fd80a

1 parent 4c86a86 commit e1fd80a

5 files changed, +68 -67 lines changed

src/relay/qnn/op/requantize.cc

Lines changed: 8 additions & 13 deletions
@@ -155,22 +155,17 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
     if (!IsEqualScalar(input_scale, output_scale)) {
       int32_t fixed_point_multiplier, shift;
       std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
+
       const bool is_upward_rounding = (param->rounding == "UPWARD");
 
-      if (is_upward_rounding && fixed_point_multiplier == (1 << 30)) {
-        // Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of 2,
-        // fixed point multiplier will represent a float value of 0.5. In fixed point, this is
-        // represented by 1 << 30.
-        scaled_int32_t = PowerOfTwoMultiply(scaled_int32_t, shift - 1);
-      } else {
-        // When using upward rounding (i.e., x.5 rounded to x+1), leverage
-        // the FixedPointMultiply operator
-        scaled_int32_t =
-            (is_upward_rounding
-                 ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
-                 : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
-      }
+      // When using upward rounding (i.e., x.5 rounded to x+1), leverage
+      // the FixedPointMultiply operator
+      scaled_int32_t =
+          (is_upward_rounding
+               ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
+               : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
     }
+
   } else {
     // This is per-channel (per=axis) quantization.
     std::vector<double> double_multipliers;
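For context on the `fixed_point_multiplier == (1 << 30)` special case this hunk removes: a requantization scale is decomposed into a Q31 fixed-point significand and a power-of-two shift, so a multiplier of exactly 0.5 shows up as 1 << 30. Below is a minimal standalone C++ sketch of that decomposition; it is not TVM's GetFixedPointMultiplierShift, and the name DecomposeMultiplier plus the overflow handling are illustrative only.

#include <cmath>
#include <cstdint>
#include <iostream>
#include <utility>

// Sketch: split a double multiplier into a Q31 fixed-point multiplier and a shift.
std::pair<int32_t, int32_t> DecomposeMultiplier(double multiplier) {
  int exponent;
  // frexp returns a significand in [0.5, 1) and the matching power-of-two exponent.
  double significand = std::frexp(multiplier, &exponent);
  // Scale the significand into 32-bit fixed point (Q31).
  auto fixed = static_cast<int64_t>(std::round(significand * (1ll << 31)));
  if (fixed == (1ll << 31)) {  // rounding overflowed past Q31: renormalize
    fixed /= 2;
    ++exponent;
  }
  return {static_cast<int32_t>(fixed), exponent};
}

int main() {
  // A multiplier of 0.5 has significand 0.5, so the Q31 value is 1 << 30 with shift 0.
  auto [m, s] = DecomposeMultiplier(0.5);
  std::cout << m << " " << s << "\n";  // prints: 1073741824 0
  return 0;
}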

src/relay/qnn/utils.cc

Lines changed: 0 additions & 15 deletions
@@ -56,21 +56,6 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier,
   return std::make_pair(significand, exponent);
 }
 
-Expr PowerOfTwoMultiply(Expr tensor, int32_t exp) {
-  Expr out;
-  if (exp > 0) {
-    // power of 2 is greater than 0, apply left shift.
-    out = LeftShift(tensor, MakeConstantScalar(DataType::Int(32), exp));
-  } else {
-    // power of 2 is less than 0, round and then apply right shift.
-    exp = -exp;
-    auto rounding_factor = 1 << (exp - 1);
-    auto rounded_t = Add(tensor, MakeConstantScalar(DataType::Int(32), rounding_factor));
-    out = RightShift(rounded_t, MakeConstantScalar(DataType::Int(32), exp));
-  }
-  return out;
-}
-
 Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
                                  const Array<IndexExpr>& input_shape) {
   // Choose high precision datatype to be int64. This is for avoiding overflow
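The removed helper expressed the power-of-two multiply through Relay LeftShift/RightShift/Add ops. The scalar sketch below is an illustration only (the helper name is hypothetical, the exp == 0 guard is added here for safety, and an arithmetic right shift is assumed for negative inputs); it shows the arithmetic the helper encoded: multiply by 2^exp, rounding the halfway case upward when exp is negative.

#include <cstdint>
#include <iostream>

// Illustrative scalar version of multiplying an int32 value by 2^exp.
int32_t PowerOfTwoMultiplyScalar(int32_t x, int32_t exp) {
  if (exp >= 0) {
    return x << exp;  // non-negative exponent: plain left shift (exp == 0 is a no-op)
  }
  exp = -exp;
  int32_t rounding_factor = 1 << (exp - 1);  // contributes +0.5 before the truncating shift
  return (x + rounding_factor) >> exp;       // round, then arithmetic right shift
}

int main() {
  std::cout << PowerOfTwoMultiplyScalar(5, -1) << "\n";   // 2.5 rounds up to 3
  std::cout << PowerOfTwoMultiplyScalar(-5, -1) << "\n";  // -2.5 rounds up to -2
  return 0;
}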

src/relay/qnn/utils.h

Lines changed: 0 additions & 7 deletions
@@ -136,13 +136,6 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) {
  */
 Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
                                  const Array<IndexExpr>& input_shape);
-/*
- * \brief Mutiply an integer datatype tensor by a power of two.
- * \param tensor The quantized input tensor of dtype int32.
- * \param exp The exp or the power of 2 representing the number to be multiplied.
- * \return The sequence of Relay ops for power of two multiplication.
- */
-Expr PowerOfTwoMultiply(Expr tensor, int32_t exp);
 
 /*
  * \brief Fixed point multiplication between integer tensor with floating point

src/target/intrin_rule.cc

Lines changed: 59 additions & 31 deletions
@@ -128,37 +128,65 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift")
       PrimExpr q = call->args[2];
       PrimExpr s = call->args[3];
 
-      // Only int32 types are supported (any number of lanes is allowed)
-      ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
-      ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32);
-
-      DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
-      DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
-
-      // 1) Calculating the integer multiplier and integer shift
-      PrimExpr zero = make_const(s.dtype(), 0);
-      PrimExpr left_shift = tir::Select(s > zero, s, zero);
-      PrimExpr right_shift = tir::Select(s > zero, zero, -s);
-
-      // 2) Cast and Multiply the integer multiplier
-      PrimExpr one = make_const(hp_dtype, 1);
-      x = cast(hp_dtype, x);
-      y = cast(hp_dtype, y);
-      x = tir::Select(left_shift != zero, x << left_shift, x);
-
-      // 3) Perform the multiplication in higher precision.
-      x = x * y;
-
-      // 4) Find the rounding scalar
-      PrimExpr total_right_shift = right_shift + q;
-      PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
-      x = x + pos_rounding_value;
-
-      // 5) Simply right shift the result to get the final output.
-      x = x >> total_right_shift;
-
-      // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
-      *rv = cast(lp_dtype, x);
+      // Lambda function to extract the int value from PrimExpr
+      auto get_int_value = [](const PrimExpr node) {
+        auto broadcast_node = node.as<BroadcastNode>();
+        CHECK(broadcast_node != nullptr);
+        auto int_node = broadcast_node->value.as<IntImmNode>();
+        CHECK(int_node != nullptr);
+        return int_node->value;
+      };
+      // Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of 2,
+      // fixed point multiplier will represent a float value of 0.5. In fixed point, this is
+      // represented by 1 << 30.
+      if (get_int_value(y) == (1 << 30)) {
+        PrimExpr exp = s - 1;
+        int exp_val = get_int_value(s) - 1;
+        if (exp_val > 0) {
+          // power of 2 is greater than 0, apply left shift.
+          *rv = x << exp;
+        } else {
+          // power of 2 is less than 0, round and then apply right shift.
+          DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
+          PrimExpr one = make_const(lp_dtype, 1);
+          exp = -exp;
+          PrimExpr rounding_factor = one << (exp - 1);
+          PrimExpr rounded_t = x + rounding_factor;
+          *rv = rounded_t >> exp;
+        }
+      } else {
+        // Only int32 types are supported (any number of lanes is allowed)
+        ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
+        ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32);
+
+        DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
+        DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
+
+        // 1) Calculating the integer multiplier and integer shift
+        PrimExpr zero = make_const(s.dtype(), 0);
+        PrimExpr left_shift = tir::Select(s > zero, s, zero);
+        PrimExpr right_shift = tir::Select(s > zero, zero, -s);
+
+        // 2) Cast and Multiply the integer multiplier
+        PrimExpr one = make_const(hp_dtype, 1);
+        x = cast(hp_dtype, x);
+        y = cast(hp_dtype, y);
+        x = tir::Select(left_shift != zero, x << left_shift, x);
+
+        // 3) Perform the multiplication in higher precision.
+        x = x * y;
+
+        // 4) Find the rounding scalar
+        PrimExpr total_right_shift = right_shift + q;
+        PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
+        x = x + pos_rounding_value;
+
+        // 5) Simply right shift the result to get the final output.
+        x = x >> total_right_shift;
+
+        // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
+        *rv = cast(lp_dtype, x);
+      }
     });
 
 } // namespace intrin
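As a reading aid for the general q_multiply_shift path kept in the else branch above, here is a single-lane scalar walk-through of the same six steps. It is illustrative only: the helper name is made up, there is no vectorization or saturation handling, and the example assumes q = 31 fractional bits.

#include <cstdint>
#include <iostream>

// Scalar sketch: out = round(x * y * 2^(s - q)), where y is a fixed-point
// multiplier with q fractional bits and s is the extra shift.
int32_t QMultiplyShiftScalar(int32_t x, int32_t y, int32_t q, int32_t s) {
  int64_t left_shift = s > 0 ? s : 0;           // split the shift into left/right parts
  int64_t right_shift = s > 0 ? 0 : -s;
  int64_t v = static_cast<int64_t>(x) << left_shift;  // widen and pre-shift
  v *= y;                                             // multiply in 64-bit precision
  int64_t total_right_shift = right_shift + q;
  v += int64_t{1} << (total_right_shift - 1);         // rounding bias
  v >>= total_right_shift;                            // final right shift
  return static_cast<int32_t>(v);                     // back to int32
}

int main() {
  // Multiply 1000 by ~0.75: y = 0.75 * 2^31, q = 31, s = 0 -> prints 750.
  int32_t y = static_cast<int32_t>(0.75 * (int64_t{1} << 31));
  std::cout << QMultiplyShiftScalar(1000, y, 31, 0) << "\n";
  return 0;
}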

tests/python/relay/test_op_qnn_requantize.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 from tvm import relay
 from tvm.contrib import graph_runtime
 
-roundings = ["UPWARD", "TONEAREST"]
+roundings = ["UPWARD"]
 
 
 def verify(mod, goldens):
