32
32
#include < tvm/relay/function.h>
33
33
#include < tvm/relay/op.h>
34
34
35
- #include < stack >
35
+ #include < deque >
36
36
#include < string>
37
37
#include < unordered_map>
38
38
#include < utility>
39
-
39
+ # include < vector >
40
40
namespace tvm {
41
41
namespace relay {
42
42
@@ -276,7 +276,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
276
276
*/
277
277
class MixedModeMutator : public ::tvm::relay::ExprMutator {
278
278
public:
279
+ MixedModeMutator (bool pre = false ) : pre_{pre } {};
279
280
Expr VisitExpr (const Expr& expr) final ;
281
+
280
282
virtual Expr DispatchVisitExpr (const Expr& expr);
281
283
Expr VisitExpr_ (const TupleNode* op) final { return Rewrite (op); };
282
284
Expr VisitExpr_ (const CallNode* call_node) final { return Rewrite (call_node); };
@@ -294,6 +296,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator {
294
296
virtual Expr Rewrite_ (const TupleGetItemNode* pre , const Expr& post ) { return post ; }
295
297
296
298
protected:
299
+ bool pre_;
297
300
/* ! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
298
301
* changed inputs.
299
302
*/
@@ -410,72 +413,82 @@ Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
410
413
*/
411
414
void PostOrderVisit (const Expr& node, std::function<void (const Expr&)> fvisit);
412
415
416
+ /* !
417
+ * \brief A struct to keep info of traversed expr in ExpandDataflow function
418
+ */
419
+ struct v_info {
420
+ explicit v_info (Expr node_) : node{node_} {}
421
+ v_info (Expr node_, bool children_expanded_)
422
+ : node{node_}, children_expanded{children_expanded_} {};
423
+ Expr node{};
424
+ bool children_expanded{false };
425
+ };
426
+
413
427
/* !
414
428
* \brief A function to iteratively traverse dataflow regions of a graph
415
429
*
416
430
* ExpandDataflow manually manages a stack and performs DFS to determine the processing
417
431
* order of nodes in an input graph.
418
432
*
419
- * If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
420
- * need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
421
- * and continues iteratively to process the top of the stack. When it finds a node that doesn't
422
- * match the dataflow types, or a node who's inputs have all been processed, it visits the current
423
- * leaf via fvisit_leaf.
433
+ * By default fexpand_expr implemented in a way that if it finds a dataflow node (Call, Tuple,
434
+ * TupleGetItem), it checks if the arguments to that node need to be processed via fcheck_visited.
435
+ * If so, the function pushes those arguments to the stack and continues iteratively to process
436
+ * the top of the stack. When it finds a node that doesn't match the dataflow types, or a node who's
437
+ * inputs have all been processed, it visits the current leaf via fvisit_leaf.
424
438
*
425
439
* This function should be used internally to other classes to implement mixed-mode traversals. The
426
440
* expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
427
441
* hits a non-dataflow node.
428
442
*
429
- * fcheck_visited and fvisit_leaf are templated to encourage compiler inlining .
443
+ * fcheck_visited, fvisit_leaf and fexpand_expr are templated to encourage reusing .
430
444
*/
431
- template <typename FCheckVisited, typename FVisitLeaf>
432
- void ExpandDataflow (Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
433
- std::stack<std::pair<Expr, bool >> stack;
445
+ template <typename FCheckVisited, typename FVisitLeaf, typename FExpandExpr>
446
+ void ExpandDataflow (Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf,
447
+ FExpandExpr fexpand_expr) {
448
+ std::deque<v_info> stack;
434
449
auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
435
- // The second state of the stack indicate whether the child has been
436
- // expanded in the pre-order.
437
- // NOTE: function will be inlined.
438
450
if (!fcheck_visited (expr)) {
439
- stack.push ({ expr, false } );
451
+ stack.emplace_front ( v_info ( expr) );
440
452
}
441
453
};
454
+
442
455
fpush_to_stack (expr);
443
456
while (stack.size () > 0 ) {
444
- auto node = stack.top ().first ;
445
- if (fcheck_visited (node)) {
446
- // if this node was visited through another path
447
- // after being added to the stack ignore it.
448
- stack.pop ();
449
- } else if (stack.top ().second ) {
450
- // all the children have already been expanded.
451
- // we can just run post order visit on it.
452
- fvisit_leaf (node);
453
- stack.pop ();
454
- } else if (const CallNode* op = node.as <CallNode>()) {
455
- // mark expanded = true
456
- stack.top ().second = true ;
457
- // push the children to the stack in reverse order
458
- // to match recursive processing order
457
+ v_info* front = &stack.front ();
458
+ if (fcheck_visited (front->node )) {
459
+ stack.pop_front ();
460
+ } else if (front->children_expanded ) {
461
+ fvisit_leaf (front->node );
462
+ // TODO(d-smirnov): this is for compatibility with current implementation of MixedModeVisitor
463
+ stack.pop_front ();
464
+ } else {
465
+ front->children_expanded = true ;
466
+ for (auto e : fexpand_expr (front->node )) {
467
+ fpush_to_stack (e);
468
+ }
469
+ }
470
+ }
471
+ }
472
+
473
+ template <typename FCheckVisited, typename FVisitLeaf>
474
+ void ExpandDataflow (Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
475
+ auto fexpand_expr = [](const Expr& expr) {
476
+ std::vector<Expr> result;
477
+ if (const CallNode* op = expr.as <CallNode>()) {
459
478
for (auto it = op->args .rbegin (); it != op->args .rend (); ++it) {
460
- fpush_to_stack (*it);
479
+ result. push_back (*it);
461
480
}
462
- fpush_to_stack (op->op );
463
- } else if (const TupleNode* op = node.as <TupleNode>()) {
464
- stack.top ().second = true ;
465
- // push the children to the stack in reverse order
466
- // to match recursive processing order
481
+ result.push_back (op->op );
482
+ } else if (const TupleNode* op = expr.as <TupleNode>()) {
467
483
for (auto it = op->fields .rbegin (); it != op->fields .rend (); ++it) {
468
- fpush_to_stack (*it);
484
+ result. push_back (*it);
469
485
}
470
- } else if (const TupleGetItemNode* op = node.as <TupleGetItemNode>()) {
471
- stack.top ().second = true ;
472
- fpush_to_stack (op->tuple );
473
- } else {
474
- // No need to expand the children directly run visit.
475
- fvisit_leaf (node);
476
- stack.pop ();
486
+ } else if (const TupleGetItemNode* op = expr.as <TupleGetItemNode>()) {
487
+ result.push_back (op->tuple );
477
488
}
478
- }
489
+ return result;
490
+ };
491
+ ExpandDataflow (expr, fcheck_visited, fvisit_leaf, fexpand_expr);
479
492
}
480
493
481
494
void ExpandANormalForm (const LetNode* op, std::function<void (const LetNode*)> pre_visit,
0 commit comments