2525
2626#include " fake_quantization_to_integer.h"
2727
28+ #include < tvm/ir/affine_type.h>
29+ #include < tvm/relay/attrs/nn.h>
30+ #include < tvm/relay/dataflow_matcher.h>
2831#include < tvm/relay/expr.h>
2932#include < tvm/relay/expr_functor.h>
3033#include < tvm/relay/qnn/attrs.h>
34+ #include < tvm/relay/qnn/op/dequantize.h>
3135#include < tvm/relay/transform.h>
3236
3337#include < unordered_map>
@@ -37,7 +41,8 @@ namespace relay {
3741
3842/* Description of FakeQuantizationToInteger
3943 *
40- * The purpose of this pass is to find regions of the graph that follow
44+ * This pass consists of two parts, a basic one and an optional one.
45+ * The purpose of the basic part is to find regions of the graph that follow
4146 * the general pattern:
4247 *
4348 * x w
@@ -52,7 +57,7 @@ namespace relay {
5257 *
5358 * and convert them into subgraphs with actual integer operations on x and w
5459 *
55- * The pass does this via a multi-pass approach:
60+ * The basic part does this via a multi-pass approach:
5661 *
5762 * The main pass is a MixedModeMutator that traverses the full graph searching for
5863 * quantize operations
@@ -69,6 +74,24 @@ namespace relay {
6974 *
7075 * After the second and third passes run, the first pass replaces the quantize with the
7176 * rewritten subgraph and the processing continues
77+ *
 * The main idea of the optional part is to find operations with dequantized inputs and transform
 * them one at a time. Only operations on the allow-list are considered. For example, if op2 in
 * the general pattern above is not registered with the FTVMFakeQuantizationToInteger
 * attribute, op1 can still be converted. The converted pattern is shown below:
82+ *
83+ * x w
84+ * | |
85+ * \ /
86+ * op1
87+ * |
88+ * dq
89+ * |
90+ * op2
91+ * |
92+ * q
93+ *
94+ * The optional part works in the same multi-pass approach.
7295 */
7396
7497using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
@@ -270,8 +293,233 @@ class FakeQuantizationRewriter : public MixedModeMutator {
270293 const bool hard_fail_;
271294};
272295
296+ bool is_op_enabled_for_optional_fq2i (const CallNode* call_node) {
297+ const Op op = Downcast<Op>(call_node->op );
298+ static auto fqfq = Op::GetAttrMap<FTVMFakeQuantizationToInteger>(" FTVMFakeQuantizationToInteger" );
299+ static std::unordered_set<Op, tvm::ObjectHash, tvm::ObjectEqual> ops = {
300+ Op::Get (" reshape" ),
301+ Op::Get (" squeeze" ),
302+ Op::Get (" strided_slice" ),
303+ Op::Get (" transpose" ),
304+ Op::Get (" expand_dims" ),
305+ Op::Get (" nn.max_pool2d" ),
306+ Op::Get (" nn.batch_flatten" ),
307+ Op::Get (" nn.depth_to_space" ),
308+ Op::Get (" max" ),
309+ Op::Get (" min" ),
310+ Op::Get (" nn.avg_pool2d" ),
311+ Op::Get (" nn.global_avg_pool2d" ),
312+ Op::Get (" nn.bias_add" ),
313+ Op::Get (" nn.conv2d" ),
314+ Op::Get (" nn.conv2d_transpose" ),
315+ Op::Get (" nn.dense" ),
316+ Op::Get (" nn.batch_matmul" ),
317+ Op::Get (" split" ),
318+ Op::Get (" clip" ),
319+ Op::Get (" nn.relu" ),
320+ Op::Get (" nn.pad" ),
321+ Op::Get (" broadcast_to" ),
322+ Op::Get (" minimum" ),
323+ Op::Get (" maximum" )};
324+
325+ auto is_enabled = [&](const auto i) { return i == call_node->op ; };
326+ auto result = std::find_if (std::begin (ops), std::end (ops), is_enabled);
327+ return result != ops.end () && fqfq.count (Downcast<Op>(op));
328+ }
329+
330+ class OptionalSubgraphExtractor : public ExprVisitor {
331+ public:
332+ const ExprSet GetSubgraph (const Expr& expr) {
333+ expr_call_node_ = expr.as <CallNode>();
334+ ICHECK (expr_call_node_ != nullptr );
335+ ICHECK (is_op_enabled_for_optional_fq2i (expr_call_node_));
336+
337+ VisitExpr (expr);
338+
339+ ExprSet subgraph;
340+ if (is_fake_quantized_) {
341+ for (auto kv : this ->visit_counter_ ) {
342+ if (auto call_node = GetRef<ObjectRef>(kv.first ).as <CallNode>()) {
343+ if (call_node != expr_call_node_) {
344+ subgraph.insert (Downcast<Expr>(GetRef<ObjectRef>(kv.first )));
345+ }
346+ }
347+ }
348+ }
349+ return subgraph;
350+ }
351+ const AffineTypeMap GetAffineTypes () { return affine_types_; }
352+ void VisitExpr (const Expr& expr) override {
353+ // When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
354+ // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
355+ // abort the rewrite.
356+ if (expr.as <CallNode>() == nullptr && expr.as <OpNode>() == nullptr &&
357+ expr.as <TupleNode>() == nullptr && expr.as <TupleGetItemNode>() == nullptr &&
358+ expr.as <ConstantNode>() == nullptr ) {
359+ DLOG (INFO) << " FakeQuantizationToInteger found a non - dataflow op inside a fake quantize "
360+ " region, aborting this rewrite" ;
361+ is_fake_quantized_ = false ;
362+ } else {
363+ ExprVisitor::VisitExpr (expr);
364+ }
365+ }
366+
367+ protected:
368+ void VisitExpr_ (const CallNode* call_node) override {
369+ if (call_node->op == dequantize_op_) {
370+ const auto * attrs = call_node->attrs .as <qnn::DequantizeAttrs>();
371+ ICHECK (attrs != nullptr );
372+
373+ affine_types_.Set (
374+ GetRef<Expr>(call_node),
375+ TensorAffineType (
376+ call_node->args [1 ], call_node->args [2 ],
377+ tvm::relay::transform::InferTypeLocal (call_node->args [0 ]).as <TensorTypeNode>()->dtype ,
378+ attrs->axis ));
379+ } else if (call_node == expr_call_node_) {
380+ for (auto arg : call_node->args ) {
381+ VisitExpr (arg);
382+ }
383+ } else {
384+ // run normally on everything else.
385+ ExprVisitor::VisitExpr_ (call_node);
386+ }
387+ }
388+
389+ const Op dequantize_op_ = Op::Get(" qnn.dequantize" );
390+ bool is_fake_quantized_ = true ;
391+ AffineTypeMap affine_types_;
392+ const CallNode* expr_call_node_ = nullptr ;
393+ };
394+
395+ class OptionalSubgraphMutator : public ExprMutator {
396+ public:
397+ OptionalSubgraphMutator (ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
398+ : subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {}
399+
400+ Expr MutateSubgraph (const Expr& expr) {
401+ if (subgraph_.size () == 0 ) {
402+ return expr;
403+ }
404+
405+ quantize_node_ = expr.as <CallNode>();
406+ ICHECK (quantize_node_);
407+ ICHECK (is_op_enabled_for_optional_fq2i (quantize_node_));
408+
409+ for (auto node : subgraph_) {
410+ const Op op = Downcast<Op>(node.as <CallNode>()->op );
411+
412+ if (node.as <CallNode>()->op != dequantize_op_) {
413+ // Only modify the subgraph if we have translation
414+ // rules for every op
415+ if (hard_fail_) {
416+ LOG (FATAL) << " Found no rewrite rule for " << AsText (op, false ) << std::endl;
417+ } else {
418+ DLOG (INFO) << " Found no rewrite rule for " << AsText (op, false ) << std::endl;
419+ return expr;
420+ }
421+ }
422+ }
423+ try {
424+ return Mutate (expr);
425+ } catch (std::exception& e) {
426+ if (hard_fail_) {
427+ throw e;
428+ } else {
429+ DLOG (INFO) << " Ran into an error rewriting a subgraph, skipping" << expr << std::endl;
430+ return expr;
431+ }
432+ }
433+ }
434+
435+ protected:
436+ Expr VisitExpr_ (const CallNode* call_node) {
437+ Expr out;
438+ static auto fqfq =
439+ Op::GetAttrMap<FTVMFakeQuantizationToInteger>(" FTVMFakeQuantizationToInteger" );
440+
441+ Op op = Downcast<Op>(call_node->op );
442+ if (fqfq.count (op)) {
443+ Expr expr;
444+ if (op == dequantize_op_) {
445+ expr = GetRef<Expr>(call_node);
446+ } else {
447+ expr = ExprMutator::VisitExpr_ (call_node);
448+ }
449+ // Call the rewrite
450+ Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
451+ // Save the outputs of the rewrite
452+ ICHECK (vals.size () == 2 )
453+ << " got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
454+ << AsText (op, false );
455+ out = Downcast<Expr>(vals[0 ]);
456+
457+ affine_types_.Set (out, Downcast<AffineType>(vals[1 ]));
458+
459+ if (call_node == quantize_node_) {
460+ out = qnn::MakeDequantize (out, vals[1 ].as <TensorAffineTypeNode>()->scale ,
461+ vals[1 ].as <TensorAffineTypeNode>()->zero_point ,
462+ vals[1 ].as <TensorAffineTypeNode>()->axis );
463+ }
464+ } else {
465+ ICHECK (false ) << " When rewriting a fake quantized graph, found an invalid node "
466+ << AsText (GetRef<Expr>(call_node), false );
467+ }
468+ return out;
469+ }
470+
471+ Expr VisitExpr_ (const TupleNode* node) {
472+ Expr expr = ExprMutator::VisitExpr_ (node);
473+ auto new_node = expr.as <TupleNode>();
474+ Array<TensorAffineType> types;
475+ for (Expr field : new_node->fields ) {
476+ ICHECK (affine_types_[field].as <TensorAffineTypeNode>());
477+ types.push_back (Downcast<TensorAffineType>(affine_types_[field]));
478+ }
479+ affine_types_.Set (expr, TupleAffineType (types));
480+ return expr;
481+ }
482+
483+ Expr VisitExpr_ (const TupleGetItemNode* node) {
484+ Expr expr = ExprMutator::VisitExpr_ (node);
485+ auto tuple_type = affine_types_[expr.as <TupleGetItemNode>()->tuple ].as <TupleAffineTypeNode>();
486+ affine_types_.Set (expr, tuple_type->types [node->index ]);
487+ return expr;
488+ }
489+
490+ ExprSet subgraph_;
491+ AffineTypeMap affine_types_;
492+ const bool hard_fail_;
493+ const Op dequantize_op_ = Op::Get(" qnn.dequantize" );
494+ const CallNode* quantize_node_ = nullptr ;
495+ };
496+
497+ class OptionalFakeQuantizationRewriter : public MixedModeMutator {
498+ public:
499+ explicit OptionalFakeQuantizationRewriter (bool hard_fail) : hard_fail_(hard_fail) {}
500+
501+ protected:
502+ Expr Rewrite_ (const CallNode* pre , const Expr& post ) override {
503+ if (const CallNode* call_node = post .as <CallNode>()) {
504+ const Op op = Downcast<Op>(call_node->op );
505+ if (is_op_enabled_for_optional_fq2i (call_node)) {
506+ OptionalSubgraphExtractor extractor;
507+ ExprSet subgraph = extractor.GetSubgraph (post );
508+ AffineTypeMap affine_types = extractor.GetAffineTypes ();
509+ Expr out = OptionalSubgraphMutator (subgraph, affine_types, hard_fail_).MutateSubgraph (post );
510+ return out;
511+ }
512+ }
513+ return post ;
514+ }
515+ const bool hard_fail_;
516+ };
517+
273518Expr FakeQuantizationToInteger (const Expr& expr, const IRModule& mod, bool hard_fail) {
274- return FakeQuantizationRewriter (hard_fail).Mutate (expr);
519+ auto fq_expr = FakeQuantizationRewriter (hard_fail).Mutate (expr);
520+ auto fq_inferred_expr = tvm::relay::InferType (fq_expr);
521+ auto ofq_expr = OptionalFakeQuantizationRewriter (hard_fail).Mutate (fq_inferred_expr);
522+ return ofq_expr;
275523}
276524
277525namespace transform {
0 commit comments