
Commit 2be444f

shoubhik authored and zhiics committed
Improve the lowering of Qnn Dense (#4213)
* [QNN] Improving Dense lowering.
* Moving get_shape method to util; finalizing the test cases and the code structure for optimized dense computation.
* Fixing cpplint.
* Addressing review comments.
* Renaming the variables correctly.
1 parent 50e4aa0 · commit 2be444f

File tree

6 files changed, +99 -56 lines

include/tvm/relay/qnn/attrs.h

Lines changed: 1 addition & 1 deletion

@@ -213,7 +213,7 @@ struct QnnDenseAttrs : public tvm::AttrsNode<QnnDenseAttrs> {
   int32_t input_zero_point;
   int32_t kernel_zero_point;

-  TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.qnn.QnnDenseAttrs") {
+  TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.QnnDenseAttrs") {
     TVM_ATTR_FIELD(units)
         .describe("Number of hidden units of the dense transformation.");
     TVM_ATTR_FIELD(out_dtype)

python/tvm/relay/qnn/op/op_attrs.py

Lines changed: 4 additions & 0 deletions

@@ -22,3 +22,7 @@
 @register_relay_attr_node
 class QnnConv2DAttrs(Attrs):
     """Attributes for qnn.conv2d"""
+
+@register_relay_attr_node
+class QnnDenseAttrs(Attrs):
+    """Attributes for qnn.dense"""

src/relay/qnn/op/convolution.cc

Lines changed: 0 additions & 7 deletions

@@ -70,13 +70,6 @@ using WorkloadType = std::tuple<int, int, int, int, int>;
  */
 WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv2DAttrs* param) {
   // Get conv parameters.
-  auto get_shape = [](const Type& type) {
-    auto input_tt = type.as<TensorTypeNode>();
-    CHECK(input_tt != nullptr) << "Type information missing."
-                               << " Please run infer_type pass.";
-    return input_tt->shape;
-  };
-
   const auto in_shape = get_shape(arg_types[0]);
   int batch_size, in_channels;
   if (param->data_layout == "NCHW") {

src/relay/qnn/op/dense.cc

Lines changed: 87 additions & 39 deletions

@@ -29,6 +29,7 @@
 #include <tvm/relay/qnn/attrs.h>
 #include "../../op/nn/nn.h"
 #include "../../pass/pattern_util.h"
+#include "../util.h"

 namespace tvm {
 namespace relay {
@@ -37,33 +38,27 @@ namespace qnn {
 // relay.op.qnn.dense
 TVM_REGISTER_NODE_TYPE(QnnDenseAttrs);

-bool QnnDenseRel(const Array<Type>& types,
-                 int num_inputs,
-                 const Attrs& attrs,
+bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 3);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
   if (data == nullptr || weight == nullptr) return false;
   const auto* param = attrs.as<QnnDenseAttrs>();
-  CHECK(param != nullptr) << "QnnConv2DAttrs cannot be nullptr.";
+  CHECK(param != nullptr) << "QnnDenseAttrs cannot be nullptr.";
   CHECK(data->dtype == Int(8) || data->dtype == UInt(8))
-    << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
+      << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
   CHECK(weight->dtype == Int(8) || weight->dtype == UInt(8))
-    << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype;
+      << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype;
   CHECK(param->out_dtype == Int(32))
-    << "Expected quantized dense type(int32) for output but was " << param->out_dtype;
+      << "Expected quantized dense type(int32) for output but was " << param->out_dtype;
   CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
   return DenseRel<QnnDenseAttrs>(types, num_inputs, attrs, reporter);
 }

 // Positional relay function to create quantized dense operator used by frontend FFI.
-Expr MakeQuantizedDense(Expr data,
-                        Expr weight,
-                        IndexExpr units,
-                        int32_t input_zero_point,
-                        int32_t kernel_zero_point,
-                        DataType out_dtype) {
+Expr MakeQuantizedDense(Expr data, Expr weight, IndexExpr units, int32_t input_zero_point,
+                        int32_t kernel_zero_point, DataType out_dtype) {
   auto attrs = make_node<QnnDenseAttrs>();
   attrs->units = std::move(units);
   attrs->out_dtype = out_dtype;
@@ -73,40 +68,93 @@ Expr MakeQuantizedDense(Expr data,
   return CallNode::make(op, {data, weight}, Attrs(attrs), {});
 }

-/**
- * \brief Lowers Qnn convolution in terms of core operators in relay.
- *  Mathematically it is equals to -
- * Dense((quantized_input - input_zero_point;int32), (quantized_kernel - kernel_zero_point; int32))
- *
- * \param attrs QnnDenseAttrs for Qnn Dense layer.
+Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel,
+                    const QnnDenseAttrs* attrs) {
+  return Dense(quantized_data, quantized_kernel, attrs->units, attrs->out_dtype);
+}
+
+Expr DenseSecondTerm(const Expr& quantized_data, const Expr& zp_kernel) {
+  Array<Integer> axes = {1};
+  return Multiply(zp_kernel, Sum(Cast(quantized_data, Int(32)), axes, true, false));
+}
+
+Expr DenseThirdTerm(const Expr& quantized_kernel, const Expr& zp_data) {
+  Array<Integer> axes = {1};
+  return Multiply(zp_data, Sum(Cast(quantized_kernel, Int(32)), axes, false, false));
+}
+
+Expr DenseFourthTerm(const QnnDenseAttrs* attrs, int reduction_dim_size) {
+  int32_t scalar_term = attrs->input_zero_point * attrs->kernel_zero_point * reduction_dim_size;
+  return MakeConstantScalar(Int(32), scalar_term);
+}
+
+/*
+ * \brief Forward rewrite the qnn dense op.
+ * \param attrs The QNN dense attrs.
  * \param new_args The new mutated args to the call node.
- * \param arg_types The data types of input and output.
- * \reutrn The sequence of Relay ops for qnn cov2d op.
+ * \param arg_types The types of input and output.
+ * \return The sequence of Relay ops for the qnn dense op.
+ * \note Lowering of the qnn.dense operator.
+ *       A quantized tensor is represented in the following manner:
+ *          A = scale_a x (QA - zp_A)
+ *       where QA is the quantized tensor, and scale_a and zp_A are the
+ *       quantization params.
+ *
+ *       Quantized dense multiplies two quantized tensors and returns a
+ *       quantized tensor with a default dtype of int32, a scale equal to the
+ *       product of the input scales, and a zero point of zero.
+ *
+ *       The lowering for asymmetric quantized dense looks as follows. More details at
+ *       https://discuss.tvm.ai/t/tf-lite-quantized-conv2d-operator-conversion/2651/8
+ *       The computation gets unrolled into the following 4 terms:
+ *          C(m, n) = Sigma(k) (A(m, k) * W(n, k))
+ *
+ *          The RHS becomes
+ *            Sigma(k) ([QA(m, k) - zp_a] * [QW(n, k) - zp_w])
+ *
+ *          Unrolling leads to the following sequence:
+ *            Sigma(k) QA(m, k) * QW(n, k)   // Term1
+ *          - Sigma(k) zp_w * QA(m, k)       // Term2
+ *          - Sigma(k) zp_a * QW(n, k)       // Term3
+ *          + Sigma(k) zp_a * zp_w           // Term4
+ *
+ *       Term3 and Term4 can be computed at compile time.
 */
-Expr QnnDenseCanonicalize(const Attrs& attrs,
-                          const Array<Expr>& new_args,
+Expr QnnDenseCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                           const Array<tvm::relay::Type>& arg_types) {
   CHECK_EQ(new_args.size(), 2);
   Expr quantized_data = new_args[0];
   Expr quantized_kernel = new_args[1];
+
+  const auto in_shape = get_shape(arg_types[0]);
+  const int reduction_dim_size = get_const_int(in_shape[1]);
+
   const auto* qnn_dense_attrs = attrs.as<QnnDenseAttrs>();
-  Expr quantized_data_int32 = Cast(quantized_data, Int(32));
-  if (qnn_dense_attrs->input_zero_point != 0) {
-    quantized_data_int32 = Subtract(quantized_data_int32,
-                                    MakeConstantScalar(Int(32),
-                                                       qnn_dense_attrs->input_zero_point));
-  }
-  Expr quantized_kernel_int32 = Cast(quantized_kernel, Int(32));
-  if (qnn_dense_attrs->kernel_zero_point != 0) {
-    quantized_kernel_int32 = Subtract(quantized_kernel_int32,
-                                      MakeConstantScalar(Int(32),
-                                                         qnn_dense_attrs->kernel_zero_point));
+  auto zp_kernel = MakeConstantScalar(Int(32), qnn_dense_attrs->kernel_zero_point);
+  auto zp_data = MakeConstantScalar(Int(32), qnn_dense_attrs->input_zero_point);
+
+  // Get all the terms as described in the comments.
+  auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs);
+  auto term2 = DenseSecondTerm(quantized_data, zp_kernel);
+  auto term3 = DenseThirdTerm(quantized_kernel, zp_data);
+  auto term4 = DenseFourthTerm(qnn_dense_attrs, reduction_dim_size);
+
+  // Combine those 4 terms depending on the zero points to get the best lowering.
+  if (qnn_dense_attrs->input_zero_point == 0 && qnn_dense_attrs->kernel_zero_point == 0) {
+    // Terms 2, 3 and 4 become zero.
+    return term1;
+  } else if (qnn_dense_attrs->input_zero_point == 0 && qnn_dense_attrs->kernel_zero_point != 0) {
+    // Terms 3 and 4 become zero.
+    return Subtract(term1, term2);
+  } else if (qnn_dense_attrs->input_zero_point != 0 && qnn_dense_attrs->kernel_zero_point == 0) {
+    // Terms 2 and 4 become zero.
+    return Subtract(term1, term3);
+  } else {
+    auto data_term = Subtract(term1, term2);
+    // Put the constant terms together so that constant folding can fold them.
+    auto const_term = Subtract(term4, term3);
+    return Add(data_term, const_term);
   }
-  Expr int32_dense = Dense(quantized_data_int32,
-                           quantized_kernel_int32,
-                           qnn_dense_attrs->units,
-                           qnn_dense_attrs->out_dtype);
-  return int32_dense;
 }

 RELAY_REGISTER_OP("qnn.dense")
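
For intuition, the four-term identity above can be checked numerically. Below is a minimal NumPy sketch (illustration only, not part of the commit); the term grouping mirrors the else branch of QnnDenseCanonicalize:

import numpy as np

# Check that the four-term unrolling equals the naive lowering
# Dense(QA - zp_a, QW - zp_w) that this commit replaces.
m, k, n = 2, 4, 3                      # batch, reduction dim, units
zp_a, zp_w = 3, 5                      # input and kernel zero points
qa = np.random.randint(0, 256, (m, k)).astype(np.int32)  # quantized data QA
qw = np.random.randint(0, 256, (n, k)).astype(np.int32)  # quantized kernel QW

naive = (qa - zp_a) @ (qw - zp_w).T    # subtract zero points, then dense

term1 = qa @ qw.T                                 # Sigma(k) QA * QW
term2 = zp_w * qa.sum(axis=1, keepdims=True)      # zp_w * Sigma(k) QA
term3 = zp_a * qw.sum(axis=1)                     # zp_a * Sigma(k) QW
term4 = zp_a * zp_w * k                           # compile-time scalar
unrolled = (term1 - term2) + (term4 - term3)      # grouping used in the commit

assert np.array_equal(naive, unrolled)

When either zero point is zero, the corresponding terms vanish, which is exactly why the canonicalization branches on the zero points instead of always emitting all four terms.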

src/relay/qnn/util.h

Lines changed: 7 additions & 0 deletions

@@ -36,6 +36,13 @@ namespace tvm {
 namespace relay {
 namespace qnn {

+static inline Array<IndexExpr> get_shape(const Type& type) {
+  auto input_tt = type.as<TensorTypeNode>();
+  CHECK(input_tt != nullptr) << "Type information missing."
+                             << " Please run infer_type pass.";
+  return input_tt->shape;
+}
+
 static inline const int32_t GetQmin(const DataType& dtype) {
   CHECK_LE(dtype.bits(), 32)
       << "QNN ops support int32 or lower precision";

tests/python/relay/test_qnn_dense.py renamed to tests/python/relay/test_op_qnn_dense.py

Lines changed: 0 additions & 9 deletions

@@ -193,29 +193,20 @@ def qnn_dense_driver(test_configuration):


 def test_qnn_dense_without_bias():
-    uint32_output_without_bias_paramas = \
-        make_uint_configuration(use_bias=False)
     int32_output_without_bias_params = \
         make_int_configuration(use_bias=False)
-    qnn_dense_driver(uint32_output_without_bias_paramas)
     qnn_dense_driver(int32_output_without_bias_params)


 def test_qnn_dense_with_bias():
-    uint32_output_with_bias_params = \
-        make_uint_configuration(use_bias=True)
     int32_output_with_bias_params = \
         make_int_configuration(use_bias=True)
-    qnn_dense_driver(uint32_output_with_bias_params)
     qnn_dense_driver(int32_output_with_bias_params)


 def test_qnn_dense_with_requantized_output():
-    uint8_requantized_output_with_bias_params = \
-        make_uint_configuration(use_bias=True, requantize_output=True)
     int8_requantized_output_with_bias_params = \
         make_int_configuration(use_bias=True, requantize_output=True)
-    qnn_dense_driver(uint8_requantized_output_with_bias_params)
     qnn_dense_driver(int8_requantized_output_with_bias_params)
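For an end-to-end view, here is a hedged sketch of driving this lowering from Python, assuming the relay.qnn API as it stood around this commit (qnn.dense taking plain integer zero points matching QnnDenseAttrs, and qnn.transform.CanonicalizeOps invoking QnnDenseCanonicalize); treat the exact signatures as assumptions rather than a quote of the test suite:

from tvm import relay

# Hedged sketch: build a qnn.dense call and canonicalize it into the
# four-term lowering. Shapes and zero points are arbitrary examples.
data = relay.var("data", shape=(2, 4), dtype="uint8")
weight = relay.var("weight", shape=(3, 4), dtype="uint8")
out = relay.qnn.op.dense(data, weight,
                         input_zero_point=3,   # zp_a
                         kernel_zero_point=5,  # zp_w
                         units=3,
                         out_dtype="int32")
func = relay.Function([data, weight], out)
mod = relay.Module.from_expr(func)
# CanonicalizeOps rewrites qnn.dense into nn.dense plus the zero-point terms.
mod = relay.qnn.transform.CanonicalizeOps()(mod)
print(mod)  # the lowered body should contain nn.dense, sum, multiply, subtract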