Skip to content

Commit 4467a9c

Browse files
authored
RelayTextPrinter is now non-recursive. ExpandDataflow refactored (#7817)
* RelayTextPrinter is now non-recursive. ExpandDataflow refactored RelayTextPrinter is now non-recursive to allow printing larger graphs. ExpandDataflow is generalised to have separate node expander. Change-Id: Id5a3a470fbc8b90822502fbc8d24d534df1ea355 * requested changes Change-Id: Iac69766428d5b9783279cb02a57064fd82842001 * unit test added Change-Id: Id20ae72f9f5f8dd92d4d182360b28156c035e667
1 parent 0b24cbf commit 4467a9c

File tree

5 files changed

+169
-53
lines changed

5 files changed

+169
-53
lines changed

include/tvm/relay/expr_functor.h

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@
3232
#include <tvm/relay/function.h>
3333
#include <tvm/relay/op.h>
3434

35-
#include <stack>
35+
#include <deque>
3636
#include <string>
3737
#include <unordered_map>
3838
#include <utility>
39-
39+
#include <vector>
4040
namespace tvm {
4141
namespace relay {
4242

@@ -276,7 +276,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
276276
*/
277277
class MixedModeMutator : public ::tvm::relay::ExprMutator {
278278
public:
279+
MixedModeMutator(bool pre = false) : pre_{pre} {};
279280
Expr VisitExpr(const Expr& expr) final;
281+
280282
virtual Expr DispatchVisitExpr(const Expr& expr);
281283
Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
282284
Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
@@ -294,6 +296,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator {
294296
virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; }
295297

296298
protected:
299+
bool pre_;
297300
/*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
298301
* changed inputs.
299302
*/
@@ -410,72 +413,82 @@ Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
410413
*/
411414
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
412415

416+
/*!
417+
* \brief A struct to keep info of traversed expr in ExpandDataflow function
418+
*/
419+
struct v_info {
420+
explicit v_info(Expr node_) : node{node_} {}
421+
v_info(Expr node_, bool children_expanded_)
422+
: node{node_}, children_expanded{children_expanded_} {};
423+
Expr node{};
424+
bool children_expanded{false};
425+
};
426+
413427
/*!
414428
* \brief A function to iteratively traverse dataflow regions of a graph
415429
*
416430
* ExpandDataflow manually manages a stack and performs DFS to determine the processing
417431
* order of nodes in an input graph.
418432
*
419-
* If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
420-
* need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
421-
* and continues iteratively to process the top of the stack. When it finds a node that doesn't
422-
* match the dataflow types, or a node who's inputs have all been processed, it visits the current
423-
* leaf via fvisit_leaf.
433+
* By default fexpand_expr implemented in a way that if it finds a dataflow node (Call, Tuple,
434+
* TupleGetItem), it checks if the arguments to that node need to be processed via fcheck_visited.
435+
* If so, the function pushes those arguments to the stack and continues iteratively to process
436+
* the top of the stack. When it finds a node that doesn't match the dataflow types, or a node who's
437+
* inputs have all been processed, it visits the current leaf via fvisit_leaf.
424438
*
425439
* This function should be used internally to other classes to implement mixed-mode traversals. The
426440
* expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
427441
* hits a non-dataflow node.
428442
*
429-
* fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
443+
* fcheck_visited, fvisit_leaf and fexpand_expr are templated to encourage reusing.
430444
*/
431-
template <typename FCheckVisited, typename FVisitLeaf>
432-
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
433-
std::stack<std::pair<Expr, bool>> stack;
445+
template <typename FCheckVisited, typename FVisitLeaf, typename FExpandExpr>
446+
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf,
447+
FExpandExpr fexpand_expr) {
448+
std::deque<v_info> stack;
434449
auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
435-
// The second state of the stack indicate whether the child has been
436-
// expanded in the pre-order.
437-
// NOTE: function will be inlined.
438450
if (!fcheck_visited(expr)) {
439-
stack.push({expr, false});
451+
stack.emplace_front(v_info(expr));
440452
}
441453
};
454+
442455
fpush_to_stack(expr);
443456
while (stack.size() > 0) {
444-
auto node = stack.top().first;
445-
if (fcheck_visited(node)) {
446-
// if this node was visited through another path
447-
// after being added to the stack ignore it.
448-
stack.pop();
449-
} else if (stack.top().second) {
450-
// all the children have already been expanded.
451-
// we can just run post order visit on it.
452-
fvisit_leaf(node);
453-
stack.pop();
454-
} else if (const CallNode* op = node.as<CallNode>()) {
455-
// mark expanded = true
456-
stack.top().second = true;
457-
// push the children to the stack in reverse order
458-
// to match recursive processing order
457+
v_info* front = &stack.front();
458+
if (fcheck_visited(front->node)) {
459+
stack.pop_front();
460+
} else if (front->children_expanded) {
461+
fvisit_leaf(front->node);
462+
// TODO(d-smirnov): this is for compatibility with current implementation of MixedModeVisitor
463+
stack.pop_front();
464+
} else {
465+
front->children_expanded = true;
466+
for (auto e : fexpand_expr(front->node)) {
467+
fpush_to_stack(e);
468+
}
469+
}
470+
}
471+
}
472+
473+
template <typename FCheckVisited, typename FVisitLeaf>
474+
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
475+
auto fexpand_expr = [](const Expr& expr) {
476+
std::vector<Expr> result;
477+
if (const CallNode* op = expr.as<CallNode>()) {
459478
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
460-
fpush_to_stack(*it);
479+
result.push_back(*it);
461480
}
462-
fpush_to_stack(op->op);
463-
} else if (const TupleNode* op = node.as<TupleNode>()) {
464-
stack.top().second = true;
465-
// push the children to the stack in reverse order
466-
// to match recursive processing order
481+
result.push_back(op->op);
482+
} else if (const TupleNode* op = expr.as<TupleNode>()) {
467483
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
468-
fpush_to_stack(*it);
484+
result.push_back(*it);
469485
}
470-
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
471-
stack.top().second = true;
472-
fpush_to_stack(op->tuple);
473-
} else {
474-
// No need to expand the children directly run visit.
475-
fvisit_leaf(node);
476-
stack.pop();
486+
} else if (const TupleGetItemNode* op = expr.as<TupleGetItemNode>()) {
487+
result.push_back(op->tuple);
477488
}
478-
}
489+
return result;
490+
};
491+
ExpandDataflow(expr, fcheck_visited, fvisit_leaf, fexpand_expr);
479492
}
480493

481494
void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit,

src/printer/relay_text_printer.cc

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,34 @@ bool RelayTextPrinter::AlwaysInline(const Expr& expr) {
236236
expr.as<VarNode>() || expr.as<ConstructorNode>();
237237
}
238238

239+
Doc RelayTextPrinter::VisitLeaf(const Expr& expr) {
240+
if (!CheckVisited(expr)) {
241+
Doc result = ExprFunctor<Doc(const Expr&)>::VisitExpr(expr);
242+
// Add if not added after visiting
243+
if (!CheckVisited(expr)) {
244+
memo_[expr] = result;
245+
} else {
246+
result_memo_[expr] = result;
247+
}
248+
return result;
249+
}
250+
return memo_[expr];
251+
}
252+
253+
bool RelayTextPrinter::CheckVisited(const Expr& expr) { return (memo_.count(expr)); }
254+
255+
Doc RelayTextPrinter::VisitExpr(const Expr& expr) {
256+
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
257+
auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
258+
259+
if (fcheck_visited(expr)) {
260+
return memo_[expr];
261+
} else {
262+
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
263+
return memo_[expr];
264+
}
265+
}
266+
239267
//------------------------------------
240268
// Overload of Expr printing functions
241269
//------------------------------------
@@ -252,9 +280,6 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo
252280
inline_expr |= IsUnique(expr);
253281
}
254282

255-
auto it = memo_.find(expr);
256-
if (it != memo_.end()) return it->second;
257-
258283
Doc printed_expr;
259284

260285
if (meta) {
@@ -277,13 +302,19 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo
277302
if (expr.as<VarNode>()) {
278303
// This is our first time visiting the var and we hit the VarNode case
279304
// in the visitor. Thus the variable is free.
280-
doc_stack_.back() << "free_var " << printed_expr << ";" << Doc::NewLine();
305+
if (var_memo_.insert(expr).second && result_memo_.count(expr)) {
306+
doc_stack_.back() << "free_var " << result_memo_[expr] << ";" << Doc::NewLine();
307+
}
281308
// Memoization is done in AllocVar.
282309
return memo_[expr];
283310
} else if (inline_expr) {
284311
memo_[expr] = printed_expr;
285312
return printed_expr;
286313
} else {
314+
// Already exists. Reuse
315+
if (!var_memo_.insert(expr).second) {
316+
return memo_[expr];
317+
}
287318
Doc temp_var = AllocTemp();
288319
memo_[expr] = temp_var;
289320
doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine();

src/printer/text_printer.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
#include <string>
3939
#include <unordered_map>
40+
#include <unordered_set>
4041
#include <vector>
4142

4243
#include "../ir/attr_functor.h"
@@ -60,6 +61,9 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
6061
explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta,
6162
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate)
6263
: show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {}
64+
Doc VisitExpr(const Expr& expr) override;
65+
virtual Doc VisitLeaf(const Expr& expr);
66+
virtual bool CheckVisited(const Expr& expr);
6367

6468
/*!
6569
* \brief Print additional info about expr in comment.
@@ -145,7 +149,7 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
145149
Doc PrintType(const Type& type, bool meta);
146150
Doc VisitTypeDefault_(const Object* node) final;
147151
Doc VisitType_(const TypeVarNode* node) final;
148-
Doc VisitType_(const GlobalTypeVarNode* node);
152+
Doc VisitType_(const GlobalTypeVarNode* node) final;
149153
Doc VisitType_(const TypeCallNode* node) final;
150154
Doc PrintDType(DataType dtype);
151155
Doc VisitType_(const TensorTypeNode* node) final;
@@ -170,6 +174,10 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
170174
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
171175
/*! \brief Stack of docs to implement scoped GNFing. */
172176
std::vector<Doc> doc_stack_{};
177+
/*! \brief Set for introduced vars */
178+
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
179+
/*! \brief Map for result and memo_ diffs for visited expression */
180+
std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> result_memo_;
173181
/*! \brief Map from Expr to Doc */
174182
std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> memo_;
175183
/*! \brief Map from Type to Doc */

src/relay/analysis/dependency_graph.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace tvm {
3232
namespace relay {
3333

3434
// Creator of DependencyGraph
35-
class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
35+
class DependencyGraph::Creator : private MixedModeVisitor {
3636
public:
3737
explicit Creator(support::Arena* arena) : arena_(arena) {}
3838

@@ -73,13 +73,13 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
7373
return ret;
7474
}
7575

76-
void VisitExpr(const Expr& e) final {
76+
void VisitLeaf(const Expr& e) override {
7777
if (visited_.count(e) == 0) {
7878
if (graph_.expr_node.count(e) == 0) {
7979
graph_.expr_node[e] = NewNode(false);
8080
}
8181
visited_.insert(e);
82-
ExprFunctor<void(const Expr&)>::VisitExpr(e);
82+
MixedModeVisitor::VisitLeaf(e);
8383
graph_.post_dfs_order.push_back(graph_.expr_node[e]);
8484
}
8585
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <gtest/gtest.h>
21+
#include <tvm/ir/expr.h>
22+
#include <tvm/ir/type_functor.h>
23+
#include <tvm/node/functor.h>
24+
#include <tvm/node/structural_equal.h>
25+
#include <tvm/relay/adt.h>
26+
#include <tvm/relay/analysis.h>
27+
#include <tvm/relay/expr.h>
28+
#include <tvm/relay/expr_functor.h>
29+
#include <tvm/relay/function.h>
30+
#include <tvm/relay/op.h>
31+
#include <tvm/relay/op_attr_types.h>
32+
#include <tvm/relay/op_strategy.h>
33+
#include <tvm/relay/transform.h>
34+
#include <tvm/relay/type.h>
35+
#include <tvm/runtime/packed_func.h>
36+
#include <tvm/runtime/registry.h>
37+
#include <tvm/te/operation.h>
38+
#include <tvm/topi/broadcast.h>
39+
#include <tvm/topi/generic/injective.h>
40+
41+
using namespace tvm;
42+
using namespace tvm::relay;
43+
44+
TEST(Relay, LargeGraphPrint) {
45+
auto foo = [] {
46+
auto add_op = relay::Op::Get("add");
47+
auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
48+
auto c1 = relay::Constant(c_data);
49+
Call y1 = relay::Call(add_op, {c1, c1});
50+
for (int i = 0; i < 1e6; i++) {
51+
y1 = relay::Call(add_op, {c1, y1});
52+
}
53+
relay::Function func = relay::Function({}, y1, relay::Type(), {});
54+
std::string result = AsText(func);
55+
ASSERT_GT(0, result.size());
56+
};
57+
ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*");
58+
}
59+
60+
int main(int argc, char** argv) {
61+
testing::InitGoogleTest(&argc, argv);
62+
testing::FLAGS_gtest_death_test_style = "threadsafe";
63+
return RUN_ALL_TESTS();
64+
}

0 commit comments

Comments
 (0)