Skip to content

Commit c595a4d

Browse files
committed
Add a conversion of individual operations in FQ2I pass.
1 parent 55849e6 commit c595a4d

File tree

6 files changed

+528
-8
lines changed

6 files changed

+528
-8
lines changed

include/tvm/relay/dataflow_matcher.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,17 @@ Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod
106106
*/
107107
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<String, ObjectRef> attrs, PackedFunc check);
108108

109+
/*!
110+
* \brief Infer the type of an expression.
111+
*
112+
* \param expr The expression to rewrite
113+
*
114+
* \return Return An Expr with unambiguous type information filled in, as well as it's
115+
* checked type field populated with the result type.
116+
*
117+
*/
118+
Expr InferType(const Expr& expr);
119+
109120
} // namespace relay
110121
} // namespace tvm
111122

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relay/executor.h
22+
* \brief Relay dequantize.
23+
*/
24+
#ifndef TVM_RELAY_QNN_OP_DEQUANTIZE_H_
25+
#define TVM_RELAY_QNN_OP_DEQUANTIZE_H_
26+
27+
#include <tvm/relay/expr.h>
28+
29+
namespace tvm {
30+
namespace relay {
31+
namespace qnn {
32+
33+
Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis);
34+
35+
} // namespace qnn
36+
} // namespace relay
37+
} // namespace tvm
38+
39+
#endif // TVM_RELAY_QNN_OP_DEQUANTIZE_H_

python/tvm/relay/transform/fake_quantization_to_integer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,16 @@ def global_avgpool2d(expr, type_map):
130130
return [out, t]
131131

132132

133+
@register_fake_quantization_to_integer("broadcast_to")
134+
def broadcast_to(expr, type_map):
135+
"""Rewrite a broadcast_to op"""
136+
arg = expr.args[0]
137+
t = type_map[arg]
138+
shape = expr.attrs.shape
139+
out = relay.op.broadcast_to(arg, shape)
140+
return [out, t]
141+
142+
133143
@register_fake_quantization_to_integer("rsqrt")
134144
def rsqrt(expr, type_map):
135145
"""Rewrite a rsqrt op"""

src/relay/transforms/fake_quantization_to_integer.cc

Lines changed: 251 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,13 @@
2525

2626
#include "fake_quantization_to_integer.h"
2727

28+
#include <tvm/ir/affine_type.h>
29+
#include <tvm/relay/attrs/nn.h>
30+
#include <tvm/relay/dataflow_matcher.h>
2831
#include <tvm/relay/expr.h>
2932
#include <tvm/relay/expr_functor.h>
3033
#include <tvm/relay/qnn/attrs.h>
34+
#include <tvm/relay/qnn/op/dequantize.h>
3135
#include <tvm/relay/transform.h>
3236

3337
#include <unordered_map>
@@ -37,7 +41,8 @@ namespace relay {
3741

3842
/* Description of FakeQuantizationToInteger
3943
*
40-
* The purpose of this pass is to find regions of the graph that follow
44+
* This pass consists of two parts, a basic one and an optional one.
45+
* The purpose of the basic part is to find regions of the graph that follow
4146
* the general pattern:
4247
*
4348
* x w
@@ -52,7 +57,7 @@ namespace relay {
5257
*
5358
* and convert them into subgraphs with actual integer operations on x and w
5459
*
55-
* The pass does this via a multi-pass approach:
60+
* The basic part does this via a multi-pass approach:
5661
*
5762
* The main pass is a MixedModeMutator that traverses the full graph searching for
5863
* quantize operations
@@ -69,6 +74,24 @@ namespace relay {
6974
*
7075
* After the second and third passes run, the first pass replaces the quantize with the
7176
* rewritten subgraph and the processing continues
77+
*
78+
* The main idea of the optional part is to find and transform operations with dequantized inputs
79+
* one by one individually. Only operations from the allowed list are allowed. For example, if on
80+
* the above general pattern op2 is not registered with the FTVMFakeQuantizationToInteger
81+
* attribute, op1 operation can still be converted. Converted pattern below:
82+
*
83+
* x w
84+
* | |
85+
* \ /
86+
* op1
87+
* |
88+
* dq
89+
* |
90+
* op2
91+
* |
92+
* q
93+
*
94+
* The optional part works in the same multi-pass approach.
7295
*/
7396

7497
using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
@@ -270,8 +293,233 @@ class FakeQuantizationRewriter : public MixedModeMutator {
270293
const bool hard_fail_;
271294
};
272295

296+
bool is_op_enabled_for_optional_fq2i(const CallNode* call_node) {
297+
const Op op = Downcast<Op>(call_node->op);
298+
static auto fqfq = Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
299+
static std::unordered_set<Op, tvm::ObjectHash, tvm::ObjectEqual> ops = {
300+
Op::Get("reshape"),
301+
Op::Get("squeeze"),
302+
Op::Get("strided_slice"),
303+
Op::Get("transpose"),
304+
Op::Get("expand_dims"),
305+
Op::Get("nn.max_pool2d"),
306+
Op::Get("nn.batch_flatten"),
307+
Op::Get("nn.depth_to_space"),
308+
Op::Get("max"),
309+
Op::Get("min"),
310+
Op::Get("nn.avg_pool2d"),
311+
Op::Get("nn.global_avg_pool2d"),
312+
Op::Get("nn.bias_add"),
313+
Op::Get("nn.conv2d"),
314+
Op::Get("nn.conv2d_transpose"),
315+
Op::Get("nn.dense"),
316+
Op::Get("nn.batch_matmul"),
317+
Op::Get("split"),
318+
Op::Get("clip"),
319+
Op::Get("nn.relu"),
320+
Op::Get("nn.pad"),
321+
Op::Get("broadcast_to"),
322+
Op::Get("minimum"),
323+
Op::Get("maximum")};
324+
325+
auto is_enabled = [&](const auto i) { return i == call_node->op; };
326+
auto result = std::find_if(std::begin(ops), std::end(ops), is_enabled);
327+
return result != ops.end() && fqfq.count(Downcast<Op>(op));
328+
}
329+
330+
class OptionalSubgraphExtractor : public ExprVisitor {
331+
public:
332+
const ExprSet GetSubgraph(const Expr& expr) {
333+
expr_call_node_ = expr.as<CallNode>();
334+
ICHECK(expr_call_node_ != nullptr);
335+
ICHECK(is_op_enabled_for_optional_fq2i(expr_call_node_));
336+
337+
VisitExpr(expr);
338+
339+
ExprSet subgraph;
340+
if (is_fake_quantized_) {
341+
for (auto kv : this->visit_counter_) {
342+
if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
343+
if (call_node != expr_call_node_) {
344+
subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
345+
}
346+
}
347+
}
348+
}
349+
return subgraph;
350+
}
351+
const AffineTypeMap GetAffineTypes() { return affine_types_; }
352+
void VisitExpr(const Expr& expr) override {
353+
// When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
354+
// i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
355+
// abort the rewrite.
356+
if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
357+
expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
358+
expr.as<ConstantNode>() == nullptr) {
359+
DLOG(INFO) << "FakeQuantizationToInteger found a non - dataflow op inside a fake quantize "
360+
"region, aborting this rewrite";
361+
is_fake_quantized_ = false;
362+
} else {
363+
ExprVisitor::VisitExpr(expr);
364+
}
365+
}
366+
367+
protected:
368+
void VisitExpr_(const CallNode* call_node) override {
369+
if (call_node->op == dequantize_op_) {
370+
const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
371+
ICHECK(attrs != nullptr);
372+
373+
affine_types_.Set(
374+
GetRef<Expr>(call_node),
375+
TensorAffineType(
376+
call_node->args[1], call_node->args[2],
377+
tvm::relay::transform::InferTypeLocal(call_node->args[0]).as<TensorTypeNode>()->dtype,
378+
attrs->axis));
379+
} else if (call_node == expr_call_node_) {
380+
for (auto arg : call_node->args) {
381+
VisitExpr(arg);
382+
}
383+
} else {
384+
// run normally on everything else.
385+
ExprVisitor::VisitExpr_(call_node);
386+
}
387+
}
388+
389+
const Op dequantize_op_ = Op::Get("qnn.dequantize");
390+
bool is_fake_quantized_ = true;
391+
AffineTypeMap affine_types_;
392+
const CallNode* expr_call_node_ = nullptr;
393+
};
394+
395+
class OptionalSubgraphMutator : public ExprMutator {
396+
public:
397+
OptionalSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
398+
: subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {}
399+
400+
Expr MutateSubgraph(const Expr& expr) {
401+
if (subgraph_.size() == 0) {
402+
return expr;
403+
}
404+
405+
quantize_node_ = expr.as<CallNode>();
406+
ICHECK(quantize_node_);
407+
ICHECK(is_op_enabled_for_optional_fq2i(quantize_node_));
408+
409+
for (auto node : subgraph_) {
410+
const Op op = Downcast<Op>(node.as<CallNode>()->op);
411+
412+
if (node.as<CallNode>()->op != dequantize_op_) {
413+
// Only modify the subgraph if we have translation
414+
// rules for every op
415+
if (hard_fail_) {
416+
LOG(FATAL) << "Found no rewrite rule for " << AsText(op, false) << std::endl;
417+
} else {
418+
DLOG(INFO) << "Found no rewrite rule for " << AsText(op, false) << std::endl;
419+
return expr;
420+
}
421+
}
422+
}
423+
try {
424+
return Mutate(expr);
425+
} catch (std::exception& e) {
426+
if (hard_fail_) {
427+
throw e;
428+
} else {
429+
DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping" << expr << std::endl;
430+
return expr;
431+
}
432+
}
433+
}
434+
435+
protected:
436+
Expr VisitExpr_(const CallNode* call_node) {
437+
Expr out;
438+
static auto fqfq =
439+
Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
440+
441+
Op op = Downcast<Op>(call_node->op);
442+
if (fqfq.count(op)) {
443+
Expr expr;
444+
if (op == dequantize_op_) {
445+
expr = GetRef<Expr>(call_node);
446+
} else {
447+
expr = ExprMutator::VisitExpr_(call_node);
448+
}
449+
// Call the rewrite
450+
Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
451+
// Save the outputs of the rewrite
452+
ICHECK(vals.size() == 2)
453+
<< "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
454+
<< AsText(op, false);
455+
out = Downcast<Expr>(vals[0]);
456+
457+
affine_types_.Set(out, Downcast<AffineType>(vals[1]));
458+
459+
if (call_node == quantize_node_) {
460+
out = qnn::MakeDequantize(out, vals[1].as<TensorAffineTypeNode>()->scale,
461+
vals[1].as<TensorAffineTypeNode>()->zero_point,
462+
vals[1].as<TensorAffineTypeNode>()->axis);
463+
}
464+
} else {
465+
ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node "
466+
<< AsText(GetRef<Expr>(call_node), false);
467+
}
468+
return out;
469+
}
470+
471+
Expr VisitExpr_(const TupleNode* node) {
472+
Expr expr = ExprMutator::VisitExpr_(node);
473+
auto new_node = expr.as<TupleNode>();
474+
Array<TensorAffineType> types;
475+
for (Expr field : new_node->fields) {
476+
ICHECK(affine_types_[field].as<TensorAffineTypeNode>());
477+
types.push_back(Downcast<TensorAffineType>(affine_types_[field]));
478+
}
479+
affine_types_.Set(expr, TupleAffineType(types));
480+
return expr;
481+
}
482+
483+
Expr VisitExpr_(const TupleGetItemNode* node) {
484+
Expr expr = ExprMutator::VisitExpr_(node);
485+
auto tuple_type = affine_types_[expr.as<TupleGetItemNode>()->tuple].as<TupleAffineTypeNode>();
486+
affine_types_.Set(expr, tuple_type->types[node->index]);
487+
return expr;
488+
}
489+
490+
ExprSet subgraph_;
491+
AffineTypeMap affine_types_;
492+
const bool hard_fail_;
493+
const Op dequantize_op_ = Op::Get("qnn.dequantize");
494+
const CallNode* quantize_node_ = nullptr;
495+
};
496+
497+
class OptionalFakeQuantizationRewriter : public MixedModeMutator {
498+
public:
499+
explicit OptionalFakeQuantizationRewriter(bool hard_fail) : hard_fail_(hard_fail) {}
500+
501+
protected:
502+
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
503+
if (const CallNode* call_node = post.as<CallNode>()) {
504+
const Op op = Downcast<Op>(call_node->op);
505+
if (is_op_enabled_for_optional_fq2i(call_node)) {
506+
OptionalSubgraphExtractor extractor;
507+
ExprSet subgraph = extractor.GetSubgraph(post);
508+
AffineTypeMap affine_types = extractor.GetAffineTypes();
509+
Expr out = OptionalSubgraphMutator(subgraph, affine_types, hard_fail_).MutateSubgraph(post);
510+
return out;
511+
}
512+
}
513+
return post;
514+
}
515+
const bool hard_fail_;
516+
};
517+
273518
Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail) {
274-
return FakeQuantizationRewriter(hard_fail).Mutate(expr);
519+
auto fq_expr = FakeQuantizationRewriter(hard_fail).Mutate(expr);
520+
auto fq_inferred_expr = tvm::relay::InferType(fq_expr);
521+
auto ofq_expr = OptionalFakeQuantizationRewriter(hard_fail).Mutate(fq_inferred_expr);
522+
return ofq_expr;
275523
}
276524

277525
namespace transform {

src/relay/transforms/type_infer.cc

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include <tvm/ir/transform.h>
4242
#include <tvm/ir/type_functor.h>
4343
#include <tvm/relay/analysis.h>
44+
#include <tvm/relay/dataflow_matcher.h>
4445
#include <tvm/relay/expr_functor.h>
4546
#include <tvm/relay/pattern_functor.h>
4647
#include <tvm/relay/transform.h>
@@ -918,11 +919,7 @@ Type InferTypeLocal(const Expr& expr) {
918919
mod = transform::InferType()(mod);
919920

920921
Type result_type;
921-
if (expr.as<FunctionNode>()) {
922-
result_type = mod->Lookup("main")->checked_type();
923-
} else {
924-
result_type = mod->Lookup("main").as<FunctionNode>()->body->checked_type();
925-
}
922+
result_type = relay::InferType(sub_graph)->checked_type();
926923

927924
expr->checked_type_ = result_type;
928925
return result_type;

0 commit comments

Comments
 (0)