Skip to content

Commit c2b3c4d

Browse files
AndrewZhaoLuo and pfk-beta
authored and committed
[QNN] Register a bunch of unary elementwise ops (apache#10086)
* initial commit * register a bunch of ops * unary ops * add a bunch of tests * refactor tests * add tests to qnn * comments on macros * add back in log to pattern utils * update floating point func description * proper creating of calls to quantize and dequantize * fix lowering process for using dequantize and quantize ops
1 parent b5e13a8 commit c2b3c4d

File tree

9 files changed

+553
-192
lines changed

9 files changed

+553
-192
lines changed

python/tvm/relay/qnn/op/qnn.py

Lines changed: 179 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@
2222
import tvm
2323
import tvm.ir
2424
from tvm import relay
25-
from tvm.runtime import Object
2625
from tvm.relay.expr import Tuple, TupleWrapper
2726
from tvm.relay.op.nn.utils import get_pad_tuple2d
28-
from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE
27+
from tvm.runtime import Object
2928
from tvm.target import Target
29+
from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE
3030
from tvm.topi.x86.utils import target_has_sse41
31+
3132
from ... import op as reg
3233
from ...op import OpPattern
33-
from . import _make
34-
from . import _requantize
34+
from . import _make, _requantize
3535

3636

3737
@tvm._ffi.register_object("relay.qnn.op.RequantizeConfig")
@@ -750,6 +750,111 @@ def mul(
750750
)
751751

752752

753+
def tanh(x, scale, zero_point, output_scale, output_zero_point):
    """Quantized hyperbolic tangent.

    Builds a ``qnn.tanh`` call; the input is interpreted through the given
    quantization parameters and the result is quantized with the output ones.

    Parameters
    ----------
    x : relay.Expr
        The quantized input tensor.

    scale: relay.Expr
        The scale of the quantized expr.

    zero_point: relay.Expr
        The zero point of quantized expr.

    output_scale: relay.Expr
        The scale of the output quantized expr.

    output_zero_point: relay.Expr
        The zero point of output quantized expr.

    Returns
    -------
    result : relay.Expr
        The computed result.

    """
    return _make.tanh(x, scale, zero_point, output_scale, output_zero_point)
786+
787+
788+
def exp(x, scale, zero_point, output_scale, output_zero_point):
    """Quantized exponential function.

    Builds a ``qnn.exp`` call; the input is interpreted through the given
    quantization parameters and the result is quantized with the output ones.

    Parameters
    ----------
    x : relay.Expr
        The quantized input tensor.

    scale: relay.Expr
        The scale of the quantized expr.

    zero_point: relay.Expr
        The zero point of quantized expr.

    output_scale: relay.Expr
        The scale of the output quantized expr.

    output_zero_point: relay.Expr
        The zero point of output quantized expr.

    Returns
    -------
    result : relay.Expr
        The computed result.

    """
    return _make.exp(x, scale, zero_point, output_scale, output_zero_point)
821+
822+
823+
def sqrt(x, scale, zero_point, output_scale, output_zero_point):
    """Quantized square root.

    Builds a ``qnn.sqrt`` call; the input is interpreted through the given
    quantization parameters and the result is quantized with the output ones.

    Parameters
    ----------
    x : relay.Expr
        The quantized input tensor.

    scale: relay.Expr
        The scale of the quantized expr.

    zero_point: relay.Expr
        The zero point of quantized expr.

    output_scale: relay.Expr
        The scale of the output quantized expr.

    output_zero_point: relay.Expr
        The zero point of output quantized expr.

    Returns
    -------
    result : relay.Expr
        The computed result.

    """
    return _make.sqrt(x, scale, zero_point, output_scale, output_zero_point)
856+
857+
753858
def rsqrt(x, scale, zero_point, output_scale, output_zero_point):
754859
"""Quantized reciprocal square root.
755860
@@ -785,6 +890,76 @@ def rsqrt(x, scale, zero_point, output_scale, output_zero_point):
785890
)
786891

787892

893+
def erf(x, scale, zero_point, output_scale, output_zero_point):
    """Quantized error function.

    Builds a ``qnn.erf`` call; the input is interpreted through the given
    quantization parameters and the result is quantized with the output ones.

    Parameters
    ----------
    x : relay.Expr
        The quantized input tensor.

    scale: relay.Expr
        The scale of the quantized expr.

    zero_point: relay.Expr
        The zero point of quantized expr.

    output_scale: relay.Expr
        The scale of the output quantized expr.

    output_zero_point: relay.Expr
        The zero point of output quantized expr.

    Returns
    -------
    result : relay.Expr
        The computed result.

    """
    return _make.erf(x, scale, zero_point, output_scale, output_zero_point)
926+
927+
928+
def sigmoid(x, scale, zero_point, output_scale, output_zero_point):
    """Quantized sigmoid.

    Builds a ``qnn.sigmoid`` call; the input is interpreted through the given
    quantization parameters and the result is quantized with the output ones.

    Parameters
    ----------
    x : relay.Expr
        The quantized input tensor.

    scale: relay.Expr
        The scale of the quantized expr.

    zero_point: relay.Expr
        The zero point of quantized expr.

    output_scale: relay.Expr
        The scale of the output quantized expr.

    output_zero_point: relay.Expr
        The zero point of output quantized expr.

    Returns
    -------
    result : relay.Expr
        The computed result.

    """
    return _make.sigmoid(x, scale, zero_point, output_scale, output_zero_point)
961+
962+
788963
def subtract(
789964
lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
790965
):

src/relay/qnn/op/op_common.h

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/relay/op.h>
2929
#include <tvm/relay/op_attr_types.h>
3030
#include <tvm/relay/qnn/attrs.h>
31+
#include <tvm/relay/qnn/transform.h>
3132

3233
#include <vector>
3334

@@ -289,6 +290,98 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
289290
.set_attr<TNonComputational>("TNonComputational", true) \
290291
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout)
291292

293+
static inline bool QnnElementwiseUnaryFuncRel(const Array<Type>& types, int num_inputs,
294+
const Attrs& attrs, const TypeReporter& reporter) {
295+
// Expected Types: data, scale, zero_point, output_scale, output_zero_point
296+
ICHECK_EQ(types.size(), 6);
297+
const auto* x = types[0].as<TensorTypeNode>();
298+
if (x == nullptr) return false;
299+
ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8))
300+
<< "Expected quantized type(int8, uint8) for input but was " << x->dtype;
301+
302+
// Check the types of scale and zero points.
303+
for (size_t i = 1; i < 5; ++i) {
304+
if (types[i].as<IncompleteTypeNode>()) {
305+
return false;
306+
}
307+
}
308+
ICHECK(IsScalarType(types[1], DataType::Float(32))); // scale
309+
ICHECK(IsScalarType(types[2], DataType::Int(32))); // zero_point
310+
ICHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale
311+
ICHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point
312+
313+
// Assign types for scale and zero points.
314+
reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // scale
315+
reporter->Assign(types[2], TensorType({}, DataType::Int(32))); // zero_point
316+
reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // output_scale
317+
reporter->Assign(types[4], TensorType({}, DataType::Int(32))); // output_zero_point
318+
319+
// Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
320+
// IdentityRel infer type function.
321+
Array<Type> tensor_types = {types[0], types[5]};
322+
return IdentityRel(tensor_types, 2, attrs, reporter);
323+
}
324+
325+
// Run the Legalize pass over a standalone expression. Canonicalizations must
// not contain qnn ops, so lowerings that emit helpers such as qnn.dequantize /
// qnn.quantize use this to lower those ops away before returning.
static inline Expr LegalizeExpr(const Expr& expr) {
  auto mod = transform::Legalize()(IRModule::FromExpr(expr));
  auto legalized_main = mod->Lookup("main");
  // FromExpr wraps a bare expression in a "main" function; unwrap the body in
  // that case, otherwise hand back the legalized function itself.
  if (!expr.as<FunctionNode>()) {
    return legalized_main.as<FunctionNode>()->body;
  }
  return legalized_main;
}
337+
338+
/*! Quick helper macro for unary elementwise operators that also carry QParams.
 * - Exposes a positional `_make` function that constructs the call node.
 * - Registers the op (5 inputs: data plus input/output scale and zero_point)
 *   with the op registry.
 *
 * \param OpName the name of registry.
 */
#define QNN_CREATE_UNARY_ELEMENTWISE_OP(OpName)                                                 \
  TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName)                                             \
      .set_body_typed(                                                                          \
          [](Expr x, Expr scale, Expr zero_point, Expr output_scale, Expr output_zero_point) {  \
            return Call(Op::Get("qnn." OpName),                                                 \
                        {x, scale, zero_point, output_scale, output_zero_point}, Attrs(), {});  \
          });                                                                                   \
                                                                                                \
  RELAY_REGISTER_OP("qnn." OpName)                                                              \
      .describe("Elementwise " OpName " for quantized tensors.")                                \
      .set_num_inputs(5)                                                                        \
      .add_argument("data", "Quantized Tensor", "The input data.")                              \
      .add_argument("scale", "Tensor", "The quantization scale of the input tensor.")           \
      .add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.") \
      .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")   \
      .add_argument("output_zero_point", "Tensor",                                              \
                    "The quantization zero_point of the output tensor.")                        \
      .set_support_level(11)                                                                    \
      .add_type_rel("qnn." OpName, QnnElementwiseUnaryFuncRel)                                  \
      .set_attr<TNonComputational>("TNonComputational", true)
366+
367+
/*! Quick helper macro producing the default canonicalization for a QNN unary
 * operator: dequantize the input, apply the given floating point function,
 * then quantize the result back to the input's dtype.
 *
 * FloatingPointFunc is usually a handle from "src/relay/transforms/pattern_utils.h"
 *
 * \param FloatingPointFunc the floating point function with function signature `Expr Erf(Expr e)`
 */
#define QNN_UNARY_OP_DEFAULT_CANONICALIZATION(FloatingPointFunc)                                  \
  [](const Attrs& attrs, const Array<Expr>& new_args, const Array<tvm::relay::Type>& arg_types) { \
    QnnUnaryOpArguments args(new_args);                                                           \
    QnnUnaryOpTensorType input_type(arg_types, 0);                                                \
    /* axis = -1 selects per-tensor (scalar) quantization params. */                              \
    Expr dequantized_arg = MakeDequantize(args.x, args.scale, args.zero_point, -1);               \
    Expr output = FloatingPointFunc(dequantized_arg);                                             \
    Expr result =                                                                                 \
        MakeQuantize(output, args.output_scale, args.output_zero_point, -1, input_type.dtype);    \
    /* Legalize so the emitted qnn.(de)quantize ops are lowered away. */                          \
    return LegalizeExpr(result);                                                                  \
  }
292385
} // namespace qnn
293386
} // namespace relay
294387
} // namespace tvm

0 commit comments

Comments
 (0)