32
32
#include < tvm/relay/function.h>
33
33
#include < tvm/relay/op.h>
34
34
35
+ #include < deque>
35
36
#include < stack>
36
37
#include < string>
37
38
#include < unordered_map>
38
39
#include < utility>
39
-
40
+ # include < vector >
40
41
namespace tvm {
41
42
namespace relay {
42
43
@@ -276,7 +277,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
276
277
*/
277
278
class MixedModeMutator : public ::tvm::relay::ExprMutator {
278
279
public:
280
+ MixedModeMutator (bool pre = false ) : pre_{pre } {};
279
281
Expr VisitExpr (const Expr& expr) final ;
282
+
280
283
virtual Expr DispatchVisitExpr (const Expr& expr);
281
284
Expr VisitExpr_ (const TupleNode* op) final { return Rewrite (op); };
282
285
Expr VisitExpr_ (const CallNode* call_node) final { return Rewrite (call_node); };
@@ -294,6 +297,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator {
294
297
virtual Expr Rewrite_ (const TupleGetItemNode* pre , const Expr& post ) { return post ; }
295
298
296
299
protected:
300
+ bool pre_;
297
301
/* ! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
298
302
* changed inputs.
299
303
*/
@@ -410,72 +414,82 @@ Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
410
414
*/
411
415
void PostOrderVisit (const Expr& node, std::function<void (const Expr&)> fvisit);
412
416
417
+ /* !
418
+ * \brief A struct to keep info of traversed expr in ExpandDataflow function
419
+ */
420
+ struct v_info {
421
+ explicit v_info (Expr node_) : node{node_} {}
422
+ v_info (Expr node_, bool children_expanded_)
423
+ : node{node_}, children_expanded{children_expanded_} {};
424
+ Expr node{};
425
+ bool children_expanded{false };
426
+ };
427
+
413
428
/* !
414
429
* \brief A function to iteratively traverse dataflow regions of a graph
415
430
*
416
431
* ExpandDataflow manually manages a stack and performs DFS to determine the processing
417
432
* order of nodes in an input graph.
418
433
*
419
- * If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
420
- * need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
421
- * and continues iteratively to process the top of the stack. When it finds a node that doesn't
422
- * match the dataflow types, or a node who's inputs have all been processed, it visits the current
423
- * leaf via fvisit_leaf.
434
+ * By default fexpand_expr implemented in a way that if it finds a dataflow node (Call, Tuple,
435
+ * TupleGetItem), it checks if the arguments to that node need to be processed via fcheck_visited.
436
+ * If so, the function pushes those arguments to the stack and continues iteratively to process
437
+ * the top of the stack. When it finds a node that doesn't match the dataflow types, or a node who's
438
+ * inputs have all been processed, it visits the current leaf via fvisit_leaf.
424
439
*
425
440
* This function should be used internally to other classes to implement mixed-mode traversals. The
426
441
* expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
427
442
* hits a non-dataflow node.
428
443
*
429
- * fcheck_visited and fvisit_leaf are templated to encourage compiler inlining .
444
+ * fcheck_visited, fvisit_leaf and fexpand_expr are templated to encourage reusing .
430
445
*/
431
- template <typename FCheckVisited, typename FVisitLeaf>
432
- void ExpandDataflow (Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
433
- std::stack<std::pair<Expr, bool >> stack;
446
+ template <typename FCheckVisited, typename FVisitLeaf, typename FExpandExpr>
447
+ void ExpandDataflow (Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf,
448
+ FExpandExpr fexpand_expr) {
449
+ std::deque<v_info> stack;
434
450
auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
435
- // The second state of the stack indicate whether the child has been
436
- // expanded in the pre-order.
437
- // NOTE: function will be inlined.
438
451
if (!fcheck_visited (expr)) {
439
- stack.push ({ expr, false } );
452
+ stack.push_front ( std::move ( v_info ( expr)) );
440
453
}
441
454
};
455
+
442
456
fpush_to_stack (expr);
443
457
while (stack.size () > 0 ) {
444
- auto node = stack.top ().first ;
445
- if (fcheck_visited (node)) {
446
- // if this node was visited through another path
447
- // after being added to the stack ignore it.
448
- stack.pop ();
449
- } else if (stack.top ().second ) {
450
- // all the children have already been expanded.
451
- // we can just run post order visit on it.
452
- fvisit_leaf (node);
453
- stack.pop ();
454
- } else if (const CallNode* op = node.as <CallNode>()) {
455
- // mark expanded = true
456
- stack.top ().second = true ;
457
- // push the children to the stack in reverse order
458
- // to match recursive processing order
458
+ v_info* front = &stack.front ();
459
+ if (fcheck_visited (front->node )) {
460
+ stack.pop_front ();
461
+ } else if (front->children_expanded ) {
462
+ fvisit_leaf (front->node );
463
+ // TODO(d-smirnov): this is for compatibility with current implementation of MixedModeVisitor
464
+ stack.pop_front ();
465
+ } else {
466
+ front->children_expanded = true ;
467
+ for (auto e : fexpand_expr (front->node )) {
468
+ fpush_to_stack (e);
469
+ }
470
+ }
471
+ }
472
+ }
473
+
474
+ template <typename FCheckVisited, typename FVisitLeaf>
475
+ void ExpandDataflow (Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
476
+ auto fexpand_expr = [](const Expr& expr) {
477
+ std::vector<Expr> result;
478
+ if (const CallNode* op = expr.as <CallNode>()) {
459
479
for (auto it = op->args .rbegin (); it != op->args .rend (); ++it) {
460
- fpush_to_stack (*it);
480
+ result. push_back (*it);
461
481
}
462
- fpush_to_stack (op->op );
463
- } else if (const TupleNode* op = node.as <TupleNode>()) {
464
- stack.top ().second = true ;
465
- // push the children to the stack in reverse order
466
- // to match recursive processing order
482
+ result.push_back (op->op );
483
+ } else if (const TupleNode* op = expr.as <TupleNode>()) {
467
484
for (auto it = op->fields .rbegin (); it != op->fields .rend (); ++it) {
468
- fpush_to_stack (*it);
485
+ result. push_back (*it);
469
486
}
470
- } else if (const TupleGetItemNode* op = node.as <TupleGetItemNode>()) {
471
- stack.top ().second = true ;
472
- fpush_to_stack (op->tuple );
473
- } else {
474
- // No need to expand the children directly run visit.
475
- fvisit_leaf (node);
476
- stack.pop ();
487
+ } else if (const TupleGetItemNode* op = expr.as <TupleGetItemNode>()) {
488
+ result.push_back (op->tuple );
477
489
}
478
- }
490
+ return std::move (result);
491
+ };
492
+ ExpandDataflow (expr, fcheck_visited, fvisit_leaf, fexpand_expr);
479
493
}
480
494
481
495
void ExpandANormalForm (const LetNode* op, std::function<void (const LetNode*)> pre_visit,
0 commit comments