Skip to content

Commit 8c5f8e4

Browse files
committed
RelayTextPrinter is now non-recursive. ExpandDataflow refactored
RelayTextPrinter is now non-recursive to allow printing larger graphs. ExpandDataflow is generalised to have separate node expander. Change-Id: Id5a3a470fbc8b90822502fbc8d24d534df1ea355
1 parent 461d06e commit 8c5f8e4

File tree

4 files changed

+105
-52
lines changed

4 files changed

+105
-52
lines changed

include/tvm/relay/expr_functor.h

Lines changed: 58 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@
3232
#include <tvm/relay/function.h>
3333
#include <tvm/relay/op.h>
3434

35+
#include <deque>
3536
#include <stack>
3637
#include <string>
3738
#include <unordered_map>
3839
#include <utility>
39-
40+
#include <vector>
4041
namespace tvm {
4142
namespace relay {
4243

@@ -276,7 +277,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
276277
*/
277278
class MixedModeMutator : public ::tvm::relay::ExprMutator {
278279
public:
280+
MixedModeMutator(bool pre = false) : pre_{pre} {};
279281
Expr VisitExpr(const Expr& expr) final;
282+
280283
virtual Expr DispatchVisitExpr(const Expr& expr);
281284
Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
282285
Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
@@ -294,6 +297,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator {
294297
virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; }
295298

296299
protected:
300+
bool pre_;
297301
/*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
298302
* changed inputs.
299303
*/
@@ -410,72 +414,82 @@ Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
410414
*/
411415
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
412416

417+
/*!
418+
* \brief A struct to keep info of traversed expr in ExpandDataflow function
419+
*/
420+
struct v_info {
421+
explicit v_info(Expr node_) : node{node_} {}
422+
v_info(Expr node_, bool children_expanded_)
423+
: node{node_}, children_expanded{children_expanded_} {};
424+
Expr node{};
425+
bool children_expanded{false};
426+
};
427+
413428
/*!
414429
* \brief A function to iteratively traverse dataflow regions of a graph
415430
*
416431
* ExpandDataflow manually manages a stack and performs DFS to determine the processing
417432
* order of nodes in an input graph.
418433
*
419-
* If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
420-
* need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
421-
* and continues iteratively to process the top of the stack. When it finds a node that doesn't
422-
* match the dataflow types, or a node who's inputs have all been processed, it visits the current
423-
* leaf via fvisit_leaf.
434+
* By default fexpand_expr implemented in a way that if it finds a dataflow node (Call, Tuple,
435+
* TupleGetItem), it checks if the arguments to that node need to be processed via fcheck_visited.
436+
* If so, the function pushes those arguments to the stack and continues iteratively to process
437+
* the top of the stack. When it finds a node that doesn't match the dataflow types, or a node who's
438+
* inputs have all been processed, it visits the current leaf via fvisit_leaf.
424439
*
425440
* This function should be used internally to other classes to implement mixed-mode traversals. The
426441
* expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
427442
* hits a non-dataflow node.
428443
*
429-
* fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
444+
* fcheck_visited, fvisit_leaf and fexpand_expr are templated to encourage reusing.
430445
*/
431-
template <typename FCheckVisited, typename FVisitLeaf>
432-
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
433-
std::stack<std::pair<Expr, bool>> stack;
446+
template <typename FCheckVisited, typename FVisitLeaf, typename FExpandExpr>
447+
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf,
448+
FExpandExpr fexpand_expr) {
449+
std::deque<v_info> stack;
434450
auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
435-
// The second state of the stack indicate whether the child has been
436-
// expanded in the pre-order.
437-
// NOTE: function will be inlined.
438451
if (!fcheck_visited(expr)) {
439-
stack.push({expr, false});
452+
stack.push_front(std::move(v_info(expr)));
440453
}
441454
};
455+
442456
fpush_to_stack(expr);
443457
while (stack.size() > 0) {
444-
auto node = stack.top().first;
445-
if (fcheck_visited(node)) {
446-
// if this node was visited through another path
447-
// after being added to the stack ignore it.
448-
stack.pop();
449-
} else if (stack.top().second) {
450-
// all the children have already been expanded.
451-
// we can just run post order visit on it.
452-
fvisit_leaf(node);
453-
stack.pop();
454-
} else if (const CallNode* op = node.as<CallNode>()) {
455-
// mark expanded = true
456-
stack.top().second = true;
457-
// push the children to the stack in reverse order
458-
// to match recursive processing order
458+
v_info* front = &stack.front();
459+
if (fcheck_visited(front->node)) {
460+
stack.pop_front();
461+
} else if (front->children_expanded) {
462+
fvisit_leaf(front->node);
463+
// TODO(d-smirnov): this is for compatibility with current implementation of MixedModeVisitor
464+
stack.pop_front();
465+
} else {
466+
front->children_expanded = true;
467+
for (auto e : fexpand_expr(front->node)) {
468+
fpush_to_stack(e);
469+
}
470+
}
471+
}
472+
}
473+
474+
template <typename FCheckVisited, typename FVisitLeaf>
475+
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
476+
auto fexpand_expr = [](const Expr& expr) {
477+
std::vector<Expr> result;
478+
if (const CallNode* op = expr.as<CallNode>()) {
459479
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
460-
fpush_to_stack(*it);
480+
result.push_back(*it);
461481
}
462-
fpush_to_stack(op->op);
463-
} else if (const TupleNode* op = node.as<TupleNode>()) {
464-
stack.top().second = true;
465-
// push the children to the stack in reverse order
466-
// to match recursive processing order
482+
result.push_back(op->op);
483+
} else if (const TupleNode* op = expr.as<TupleNode>()) {
467484
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
468-
fpush_to_stack(*it);
485+
result.push_back(*it);
469486
}
470-
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
471-
stack.top().second = true;
472-
fpush_to_stack(op->tuple);
473-
} else {
474-
// No need to expand the children directly run visit.
475-
fvisit_leaf(node);
476-
stack.pop();
487+
} else if (const TupleGetItemNode* op = expr.as<TupleGetItemNode>()) {
488+
result.push_back(op->tuple);
477489
}
478-
}
490+
return std::move(result);
491+
};
492+
ExpandDataflow(expr, fcheck_visited, fvisit_leaf, fexpand_expr);
479493
}
480494

481495
void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit,

src/printer/relay_text_printer.cc

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,34 @@ bool RelayTextPrinter::AlwaysInline(const Expr& expr) {
236236
expr.as<VarNode>() || expr.as<ConstructorNode>();
237237
}
238238

239+
Doc RelayTextPrinter::VisitLeaf(const Expr& expr) {
240+
if (!CheckVisited(expr)) {
241+
Doc result = ExprFunctor<Doc(const Expr&)>::VisitExpr(expr);
242+
// Add if not added after visiting
243+
if (!CheckVisited(expr)) {
244+
memo_[expr] = result;
245+
} else {
246+
result_memo_[expr] = result;
247+
}
248+
return result;
249+
}
250+
return memo_[expr];
251+
}
252+
253+
bool RelayTextPrinter::CheckVisited(const Expr& expr) { return (memo_.count(expr)); }
254+
255+
Doc RelayTextPrinter::VisitExpr(const Expr& expr) {
256+
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
257+
auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
258+
259+
if (fcheck_visited(expr)) {
260+
return memo_[expr];
261+
} else {
262+
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
263+
return memo_[expr];
264+
}
265+
}
266+
239267
//------------------------------------
240268
// Overload of Expr printing functions
241269
//------------------------------------
@@ -252,9 +280,6 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo
252280
inline_expr |= IsUnique(expr);
253281
}
254282

255-
auto it = memo_.find(expr);
256-
if (it != memo_.end()) return it->second;
257-
258283
Doc printed_expr;
259284

260285
if (meta) {
@@ -277,13 +302,19 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo
277302
if (expr.as<VarNode>()) {
278303
// This is our first time visiting the var and we hit the VarNode case
279304
// in the visitor. Thus the variable is free.
280-
doc_stack_.back() << "free_var " << printed_expr << ";" << Doc::NewLine();
305+
if (var_memo_.insert(expr).second && result_memo_.count(expr)) {
306+
doc_stack_.back() << "free_var " << result_memo_[expr] << ";" << Doc::NewLine();
307+
}
281308
// Memoization is done in AllocVar.
282309
return memo_[expr];
283310
} else if (inline_expr) {
284311
memo_[expr] = printed_expr;
285312
return printed_expr;
286313
} else {
314+
// Already exists. Reuse
315+
if (!var_memo_.insert(expr).second) {
316+
return memo_[expr];
317+
}
287318
Doc temp_var = AllocTemp();
288319
memo_[expr] = temp_var;
289320
doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine();

src/printer/text_printer.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
#include <string>
3939
#include <unordered_map>
40+
#include <unordered_set>
4041
#include <vector>
4142

4243
#include "../ir/attr_functor.h"
@@ -60,6 +61,9 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
6061
explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta,
6162
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate)
6263
: show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {}
64+
Doc VisitExpr(const Expr& expr) override;
65+
virtual Doc VisitLeaf(const Expr& expr);
66+
virtual bool CheckVisited(const Expr& expr);
6367

6468
/*!
6569
* \brief Print additional info about expr in comment.
@@ -145,7 +149,7 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
145149
Doc PrintType(const Type& type, bool meta);
146150
Doc VisitTypeDefault_(const Object* node) final;
147151
Doc VisitType_(const TypeVarNode* node) final;
148-
Doc VisitType_(const GlobalTypeVarNode* node);
152+
Doc VisitType_(const GlobalTypeVarNode* node) final;
149153
Doc VisitType_(const TypeCallNode* node) final;
150154
Doc PrintDType(DataType dtype);
151155
Doc VisitType_(const TensorTypeNode* node) final;
@@ -170,6 +174,10 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
170174
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
171175
/*! \brief Stack of docs to implement scoped GNFing. */
172176
std::vector<Doc> doc_stack_{};
177+
/*! \brief Set for introduced vars */
178+
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
179+
/*! \brief Map for result and memo_ diffs for visited expression */
180+
std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> result_memo_;
173181
/*! \brief Map from Expr to Doc */
174182
std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> memo_;
175183
/*! \brief Map from Type to Doc */

src/relay/analysis/dependency_graph.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace tvm {
3232
namespace relay {
3333

3434
// Creator of DependencyGraph
35-
class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
35+
class DependencyGraph::Creator : private MixedModeVisitor {
3636
public:
3737
explicit Creator(support::Arena* arena) : arena_(arena) {}
3838

@@ -73,13 +73,13 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
7373
return ret;
7474
}
7575

76-
void VisitExpr(const Expr& e) final {
76+
void VisitLeaf(const Expr& e) override {
7777
if (visited_.count(e) == 0) {
7878
if (graph_.expr_node.count(e) == 0) {
7979
graph_.expr_node[e] = NewNode(false);
8080
}
8181
visited_.insert(e);
82-
ExprFunctor<void(const Expr&)>::VisitExpr(e);
82+
MixedModeVisitor::VisitLeaf(e);
8383
graph_.post_dfs_order.push_back(graph_.expr_node[e]);
8484
}
8585
}

0 commit comments

Comments
 (0)