Skip to content

Commit ec78ab9

Browse files
committed
[QNN] Optimize requantize for power of 2 and fix bug in dequantize
1 parent f73a1f6 commit ec78ab9

File tree

6 files changed

+93
-10
lines changed

6 files changed

+93
-10
lines changed

src/relay/qnn/op/dequantize.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
9696
expanded_input_zero_point = ExpandBiasToMatchAxis(input_zero_point, n_dim, {axis});
9797
}
9898

99-
auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point);
100-
auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale);
99+
auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point);
100+
auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale);
101101
return scaled_output;
102102
}
103103

src/relay/qnn/op/requantize.cc

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,17 +155,20 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
155155
if (!IsEqualScalar(input_scale, output_scale)) {
156156
int32_t fixed_point_multiplier, shift;
157157
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
158-
159158
const bool is_upward_rounding = (param->rounding == "UPWARD");
160159

161-
// When using upward rounding (i.e., x.5 rounded to x+1), leverage
162-
// the FixedPointMultiply operator
163-
scaled_int32_t =
164-
(is_upward_rounding
165-
? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
166-
: FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
160+
if (is_upward_rounding && fixed_point_multiplier == (1 << 30)) {
161+
// Power of 2
162+
scaled_int32_t = PowerOfTwoMultiply(scaled_int32_t, shift - 1);
163+
} else {
164+
// When using upward rounding (i.e., x.5 rounded to x+1), leverage
165+
// the FixedPointMultiply operator
166+
scaled_int32_t =
167+
(is_upward_rounding
168+
? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
169+
: FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
170+
}
167171
}
168-
169172
} else {
170173
// This is per-channel (per=axis) quantization.
171174
std::vector<double> double_multipliers;

src/relay/qnn/util.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,21 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
5656
return std::make_pair(significand, exponent);
5757
}
5858

59+
Expr PowerOfTwoMultiply(Expr tensor, int32_t exp) {
60+
Expr out;
61+
if (exp > 0) {
62+
// power of 2 is greater than 0, apply left shift.
63+
out = LeftShift(tensor, MakeConstantScalar(DataType::Int(32), exp));
64+
} else {
65+
// power of 2 is less than 0, round and then apply right shift.
66+
exp = -exp;
67+
auto rounding_factor = 1 << (exp - 1);
68+
auto rounded_t = Add(tensor, MakeConstantScalar(DataType::Int(32), rounding_factor));
69+
out = RightShift(rounded_t, MakeConstantScalar(DataType::Int(32), exp));
70+
}
71+
return out;
72+
}
73+
5974
Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
6075
const Array<IndexExpr>& input_shape) {
6176
// Choose high precision datatype to be int64. This is for avoiding overflow

src/relay/qnn/util.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,14 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) {
136136
*/
137137
Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
138138
const Array<IndexExpr>& input_shape);
139+
/*
140+
* \brief Multiply an integer datatype tensor by a power of two.
141+
* \param tensor The quantized input tensor of dtype int32.
142+
* \param exp The exponent of the power of 2 by which the tensor is multiplied (negative values denote a right shift with rounding).
143+
* \return The sequence of Relay ops for power of two multiplication.
144+
145+
*/
146+
Expr PowerOfTwoMultiply(Expr tensor, int32_t exp);
139147

140148
/*
141149
* \brief Fixed point multiplication between integer tensor with floating point

tests/python/relay/test_op_qnn_dequantize.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,22 @@ def test_channelwise_axis_1():
101101
)
102102

103103

104+
def test_channelwise_axis_0():
105+
data = np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]).astype("uint8").reshape((2, 5))
106+
output = np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]).astype("float32").reshape((2, 5))
107+
quant_args = {
108+
"in_zero_point": np.array([127, 123]).astype("int32"),
109+
"in_scale": np.array([0.5, 0.25]).astype("float32"),
110+
}
111+
112+
dequantize_test_driver(
113+
in_dtype="uint8", quant_args=quant_args, in_data=data, verify_output_data=output, axis=0
114+
)
115+
116+
104117
if __name__ == "__main__":
105118
test_uint8_to_float32()
106119
test_int8_to_float32()
107120
test_int32_to_float32()
108121
test_channelwise_axis_1()
122+
test_channelwise_axis_0()

tests/python/relay/test_op_qnn_requantize.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,48 @@ def test_upscale():
204204
verify(mod, (golden_data, golden_output))
205205

206206

207+
def test_non_power_of_two():
208+
for rounding in roundings:
209+
mod = get_mod(
210+
data_shape=(32,),
211+
data_dtype="int32",
212+
out_dtype="int8",
213+
input_scale=1,
214+
output_scale=3,
215+
rounding=rounding,
216+
)
217+
218+
# Try positive values
219+
golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3)
220+
golden_output = np.arange(0, 32, 1)
221+
verify(mod, (golden_data, golden_output))
222+
223+
# Try negative values
224+
golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3)
225+
golden_output = np.arange(0, -32, -1)
226+
verify(mod, (golden_data, golden_output))
227+
228+
# Try a different scale
229+
mod = get_mod(
230+
data_shape=(32,),
231+
data_dtype="int32",
232+
out_dtype="int8",
233+
input_scale=3,
234+
output_scale=1,
235+
rounding=rounding,
236+
)
237+
238+
# Try positive values
239+
golden_data = np.arange(0, 32, 1).astype("int32")
240+
golden_output = np.multiply(golden_data, 3)
241+
verify(mod, (golden_data, golden_output))
242+
243+
# Try negative values
244+
golden_data = np.arange(0, -32, -1).astype("int32")
245+
golden_output = np.multiply(golden_data, 3)
246+
verify(mod, (golden_data, golden_output))
247+
248+
207249
def test_saturation():
208250
for rounding in roundings:
209251
mod = get_mod(
@@ -397,6 +439,7 @@ def test_per_channel_different_scale():
397439
test_same_scale()
398440
test_downscale()
399441
test_upscale()
442+
test_non_power_of_two()
400443
test_saturation()
401444
test_zero_point()
402445
test_per_channel_same_scale()

0 commit comments

Comments
 (0)