Skip to content

Commit 7133448

Browse files
authored
[VISITOR] New ExprFunctor, StmtFunctor Interface. Modular analysis (#58)
* [ARITH/VISITOR] Modular Analysis, ExprFunctor, StmtFunctor * retrigger * [IRFunctor] Migrated CodegenC * [IRFUNCTOR] Migrate CodeGenLLVM * [IRFunctor] Migrate canonical * [IRFunctor] Migrate vectorize * [IRFunctor] migrate CodeGenStackVM
1 parent e438794 commit 7133448

25 files changed

+2028
-1471
lines changed

.travis.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ after_failure:
5959
- tests/travis/travis_after_failure.sh
6060

6161
notifications:
62-
# Emails are sent to the committer's git-configured email address by default,
6362
email:
6463
on_success: change
6564
on_failure: always

include/tvm/ir_functor_ext.h

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
/*!
2+
* Copyright (c) 2017 by Contributors
3+
* \file ir_functor_ext.h
4+
* \brief More powerful Visitor that allows define function signatures.
5+
*/
6+
#ifndef TVM_IR_FUNCTOR_EXT_H_
7+
#define TVM_IR_FUNCTOR_EXT_H_
8+
9+
#include <tvm/ir_functor.h>
10+
#include "./ir.h"
11+
12+
namespace tvm {
13+
namespace ir {
14+
15+
/*!
16+
* \brief A dynamical functor that dispatches on in the first Expr argument.
17+
* You can use this as a more powerful Visitor, since it allows you to
18+
* define function signatures of Visit Function.
19+
*
20+
* \code
21+
* // A functor that set variable to b. and calculate results.
22+
* class MyExprFunctor
23+
* : public ir::ExprFunctor<int(const Expr&, int)> {
24+
* public:
25+
* int VisitExpr_(const Variable* op, int b) final {
26+
* return b;
27+
* }
28+
* int VisitExpr_(const IntImm* op, int b) final {
29+
* return op->value;
30+
* }
31+
* int VisitExpr_(const Add* op, int b) final {
32+
* return Visit(op->a, b) + Visit(op->b, b);
33+
* }
34+
* };
35+
* MyExprFunctor f;
36+
* Var x("x");
37+
* CHECK_EQ(f(x + 1, 2), 3);
38+
* \endcode
39+
*
40+
* \note Why do we need this more powerful Functor:
41+
*
42+
* We often need to implement a transformer tasks.
43+
* Say we want to take Expr and transform it to some analysis result,
44+
* This easily be done incorrectly using plain Visitor. See IRVisitor's
45+
* document for possible error cases.
46+
*
47+
* \tparam FType function signiture
48+
* This type if only defined for FType with function signiture R(const Expr&, Args...)
49+
*/
50+
template<typename FType>
51+
class ExprFunctor;
52+
/*!
53+
* \brief Same as ExprFunctor except it is applied on statements
54+
* \tparam FType The function signature.
55+
*/
56+
template<typename FType>
57+
class StmtFunctor;
58+
59+
// functions to be overriden.
60+
#define EXPR_FUNCTOR_DEFAULT { \
61+
return VisitExprDefault_(op, std::forward<Args>(args)...); \
62+
}
63+
#define STMT_FUNCTOR_DEFAULT { \
64+
return VisitStmtDefault_(op, std::forward<Args>(args)...); \
65+
}
66+
67+
#define IR_EXPR_FUNCTOR_DISPATCH(OP) \
68+
vtable.template set_dispatch<OP>( \
69+
[](const NodeRef& n, TSelf* self, Args... args) { \
70+
return self->VisitExpr_(static_cast<const OP*>(n.node_.get()), \
71+
std::forward<Args>(args)...); \
72+
}); \
73+
74+
#define IR_STMT_FUNCTOR_DISPATCH(OP) \
75+
vtable.template set_dispatch<OP>( \
76+
[](const NodeRef& n, TSelf* self, Args... args) { \
77+
return self->VisitStmt_(static_cast<const OP*>(n.node_.get()), \
78+
std::forward<Args>(args)...); \
79+
}); \
80+
81+
template<typename R, typename ...Args>
82+
class ExprFunctor<R(const Expr& n, Args...)> {
83+
private:
84+
using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
85+
using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
86+
87+
public:
88+
/*! \brief the result type of this functor */
89+
using result_type = R;
90+
/*! \brief virtual destructor */
91+
virtual ~ExprFunctor() {}
92+
/*!
93+
* \brief Same as call.
94+
* \param n The expression node.
95+
* \param args Additional arguments.
96+
* \return The result of the call
97+
*/
98+
R operator()(const Expr& n, Args... args) {
99+
return VisitExpr(n, std::forward<Args>(args)...);
100+
}
101+
/*!
102+
* \brief The functor call.
103+
* \param n The expression node.
104+
* \param args Additional arguments.
105+
* \return The result of the call
106+
*/
107+
virtual R VisitExpr(const Expr& n, Args... args) {
108+
static FType vtable = InitVTable();
109+
return vtable(n, this, std::forward<Args>(args)...);
110+
}
111+
// Functions that can be overriden by subclass
112+
virtual R VisitExpr_(const Variable* op, Args... args) EXPR_FUNCTOR_DEFAULT;
113+
virtual R VisitExpr_(const Load* op, Args... args) EXPR_FUNCTOR_DEFAULT;
114+
virtual R VisitExpr_(const Let* op, Args... args) EXPR_FUNCTOR_DEFAULT;
115+
virtual R VisitExpr_(const Call* op, Args... args) EXPR_FUNCTOR_DEFAULT;
116+
virtual R VisitExpr_(const Add* op, Args... args) EXPR_FUNCTOR_DEFAULT;
117+
virtual R VisitExpr_(const Sub* op, Args... args) EXPR_FUNCTOR_DEFAULT;
118+
virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT;
119+
virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT;
120+
virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
121+
virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT;
122+
virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT;
123+
virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT;
124+
virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
125+
virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
126+
virtual R VisitExpr_(const LE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
127+
virtual R VisitExpr_(const GT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
128+
virtual R VisitExpr_(const GE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
129+
virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT;
130+
virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT;
131+
virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT;
132+
virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
133+
virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT;
134+
virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT;
135+
virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT;
136+
virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
137+
virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
138+
virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
139+
virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
140+
virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
141+
virtual R VisitExprDefault_(const Node* op, Args ...) {
142+
LOG(FATAL) << "Do not have a default for " << op->type_key();
143+
return R();
144+
}
145+
146+
private:
147+
// initialize the vtable.
148+
static FType InitVTable() {
149+
FType vtable;
150+
// Set dispatch
151+
IR_EXPR_FUNCTOR_DISPATCH(Variable);
152+
IR_EXPR_FUNCTOR_DISPATCH(Load);
153+
IR_EXPR_FUNCTOR_DISPATCH(Let);
154+
IR_EXPR_FUNCTOR_DISPATCH(Call);
155+
IR_EXPR_FUNCTOR_DISPATCH(Add);
156+
IR_EXPR_FUNCTOR_DISPATCH(Sub);
157+
IR_EXPR_FUNCTOR_DISPATCH(Mul);
158+
IR_EXPR_FUNCTOR_DISPATCH(Div);
159+
IR_EXPR_FUNCTOR_DISPATCH(Mod);
160+
IR_EXPR_FUNCTOR_DISPATCH(Min);
161+
IR_EXPR_FUNCTOR_DISPATCH(Max);
162+
IR_EXPR_FUNCTOR_DISPATCH(EQ);
163+
IR_EXPR_FUNCTOR_DISPATCH(NE);
164+
IR_EXPR_FUNCTOR_DISPATCH(LT);
165+
IR_EXPR_FUNCTOR_DISPATCH(LE);
166+
IR_EXPR_FUNCTOR_DISPATCH(GT);
167+
IR_EXPR_FUNCTOR_DISPATCH(GE);
168+
IR_EXPR_FUNCTOR_DISPATCH(And);
169+
IR_EXPR_FUNCTOR_DISPATCH(Or);
170+
IR_EXPR_FUNCTOR_DISPATCH(Reduce);
171+
IR_EXPR_FUNCTOR_DISPATCH(Cast);
172+
IR_EXPR_FUNCTOR_DISPATCH(Not);
173+
IR_EXPR_FUNCTOR_DISPATCH(Select);
174+
IR_EXPR_FUNCTOR_DISPATCH(Ramp);
175+
IR_EXPR_FUNCTOR_DISPATCH(Broadcast);
176+
IR_EXPR_FUNCTOR_DISPATCH(IntImm);
177+
IR_EXPR_FUNCTOR_DISPATCH(UIntImm);
178+
IR_EXPR_FUNCTOR_DISPATCH(FloatImm);
179+
IR_EXPR_FUNCTOR_DISPATCH(StringImm);
180+
return vtable;
181+
}
182+
};
183+
184+
template<typename R, typename ...Args>
185+
class StmtFunctor<R(const Stmt& n, Args... args)> {
186+
private:
187+
using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
188+
using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args... args)>;
189+
190+
public:
191+
/*! \brief the result type of this functor */
192+
using result_type = R;
193+
/*! \brief virtual destructor */
194+
virtual ~StmtFunctor() {}
195+
/*!
196+
* \brief Same as call.
197+
* \param n The stmt node.
198+
* \param args Additional arguments.
199+
* \return The result of the call
200+
*/
201+
R operator()(const Stmt& n, Args... args) {
202+
return VisitStmt(n, std::forward<Args>(args)...);
203+
}
204+
/*!
205+
* \brief The functor call.
206+
* \param n The stmt node.
207+
* \param args Additional arguments.
208+
* \return The result of the call
209+
*/
210+
virtual R VisitStmt(const Stmt& n, Args... args) {
211+
static FType vtable = InitVTable();
212+
return vtable(n, this, std::forward<Args>(args)...);
213+
}
214+
// Functions that can be overriden by subclass
215+
virtual R VisitStmt_(const LetStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
216+
virtual R VisitStmt_(const AttrStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
217+
virtual R VisitStmt_(const IfThenElse* op, Args... args) STMT_FUNCTOR_DEFAULT;
218+
virtual R VisitStmt_(const For* op, Args... args) STMT_FUNCTOR_DEFAULT;
219+
virtual R VisitStmt_(const Allocate* op, Args... args) STMT_FUNCTOR_DEFAULT;
220+
virtual R VisitStmt_(const Store* op, Args... args) STMT_FUNCTOR_DEFAULT;
221+
virtual R VisitStmt_(const Free* op, Args... args) STMT_FUNCTOR_DEFAULT;
222+
virtual R VisitStmt_(const AssertStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
223+
virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT;
224+
virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
225+
virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
226+
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
227+
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
228+
virtual R VisitStmtDefault_(const Node* op, Args ...) {
229+
LOG(FATAL) << "Do not have a default for " << op->type_key();
230+
return R();
231+
}
232+
233+
private:
234+
// initialize the vtable.
235+
static FType InitVTable() {
236+
FType vtable;
237+
IR_STMT_FUNCTOR_DISPATCH(LetStmt);
238+
IR_STMT_FUNCTOR_DISPATCH(AttrStmt);
239+
IR_STMT_FUNCTOR_DISPATCH(IfThenElse);
240+
IR_STMT_FUNCTOR_DISPATCH(For);
241+
IR_STMT_FUNCTOR_DISPATCH(Allocate);
242+
IR_STMT_FUNCTOR_DISPATCH(Store);
243+
IR_STMT_FUNCTOR_DISPATCH(Free);
244+
IR_STMT_FUNCTOR_DISPATCH(AssertStmt);
245+
IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer);
246+
IR_STMT_FUNCTOR_DISPATCH(Provide);
247+
IR_STMT_FUNCTOR_DISPATCH(Realize);
248+
IR_STMT_FUNCTOR_DISPATCH(Block);
249+
IR_STMT_FUNCTOR_DISPATCH(Evaluate);
250+
return vtable;
251+
}
252+
};
253+
254+
#undef IR_STMT_FUNCTOR_DISPATCH
255+
#undef IR_EXPR_FUNCTOR_DISPATCH
256+
#undef EXPR_FUNCTOR_DEFAULT
257+
#undef STMT_FUNCTOR_DEFAULT
258+
259+
} // namespace ir
260+
} // namespace tvm
261+
#endif // TVM_IR_FUNCTOR_EXT_H_

include/tvm/ir_mutator.h

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -55,59 +55,23 @@ class IRMutator {
5555
static FMutateStmt& vtable_stmt(); // NOLINT(*)
5656
// Set of overloadable functions
5757
// The underscore allows Mutate not to be shadowed by inheritance
58-
virtual Stmt Mutate_(const Variable* op, const Stmt& s);
5958
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
6059
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
6160
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
6261
virtual Stmt Mutate_(const For* op, const Stmt& s);
6362
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
64-
virtual Stmt Mutate_(const Load* op, const Stmt& s);
6563
virtual Stmt Mutate_(const Store* op, const Stmt& s);
66-
virtual Stmt Mutate_(const Let* op, const Stmt& s);
6764
virtual Stmt Mutate_(const Free* op, const Stmt& s);
68-
virtual Stmt Mutate_(const Call* op, const Stmt& s);
69-
virtual Stmt Mutate_(const Add* op, const Stmt& e);
70-
virtual Stmt Mutate_(const Sub* op, const Stmt& e);
71-
virtual Stmt Mutate_(const Mul* op, const Stmt& e);
72-
virtual Stmt Mutate_(const Div* op, const Stmt& e);
73-
virtual Stmt Mutate_(const Mod* op, const Stmt& e);
74-
virtual Stmt Mutate_(const Min* op, const Stmt& e);
75-
virtual Stmt Mutate_(const Max* op, const Stmt& e);
76-
virtual Stmt Mutate_(const EQ* op, const Stmt& e);
77-
virtual Stmt Mutate_(const NE* op, const Stmt& e);
78-
virtual Stmt Mutate_(const LT* op, const Stmt& e);
79-
virtual Stmt Mutate_(const LE* op, const Stmt& e);
80-
virtual Stmt Mutate_(const GT* op, const Stmt& e);
81-
virtual Stmt Mutate_(const GE* op, const Stmt& e);
82-
virtual Stmt Mutate_(const And* op, const Stmt& e);
83-
virtual Stmt Mutate_(const Or* op, const Stmt& e);
84-
virtual Stmt Mutate_(const Reduce* op, const Stmt& s);
85-
virtual Stmt Mutate_(const Cast* op, const Stmt& s);
86-
virtual Stmt Mutate_(const Not* op, const Stmt& s);
87-
virtual Stmt Mutate_(const Select* op, const Stmt& s);
88-
virtual Stmt Mutate_(const Ramp* op, const Stmt& s);
89-
virtual Stmt Mutate_(const Broadcast* op, const Stmt& e);
9065
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e);
9166
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e);
9267
virtual Stmt Mutate_(const Provide* op, const Stmt& e);
9368
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
9469
virtual Stmt Mutate_(const Block* op, const Stmt& s);
9570
virtual Stmt Mutate_(const Evaluate* op, const Stmt& e);
96-
virtual Stmt Mutate_(const IntImm* op, const Stmt& e);
97-
virtual Stmt Mutate_(const UIntImm* op, const Stmt& e);
98-
virtual Stmt Mutate_(const FloatImm* op, const Stmt& e);
99-
virtual Stmt Mutate_(const StringImm* op, const Stmt& e);
10071

10172
virtual Expr Mutate_(const Variable* op, const Expr& e);
102-
virtual Expr Mutate_(const LetStmt* op, const Expr& e);
103-
virtual Expr Mutate_(const AttrStmt* op, const Expr& e);
104-
virtual Expr Mutate_(const IfThenElse* op, const Expr& e);
105-
virtual Expr Mutate_(const For* op, const Expr& e);
106-
virtual Expr Mutate_(const Allocate* op, const Expr& e);
10773
virtual Expr Mutate_(const Load* op, const Expr& e);
108-
virtual Expr Mutate_(const Store* op, const Expr& e);
10974
virtual Expr Mutate_(const Let* op, const Expr& e);
110-
virtual Expr Mutate_(const Free* op, const Expr& e);
11175
virtual Expr Mutate_(const Call* op, const Expr& e);
11276
virtual Expr Mutate_(const Add* op, const Expr& e);
11377
virtual Expr Mutate_(const Sub* op, const Expr& e);
@@ -130,38 +94,12 @@ class IRMutator {
13094
virtual Expr Mutate_(const Select* op, const Expr& e);
13195
virtual Expr Mutate_(const Ramp* op, const Expr& e);
13296
virtual Expr Mutate_(const Broadcast* op, const Expr& e);
133-
virtual Expr Mutate_(const AssertStmt* op, const Expr& e);
134-
virtual Expr Mutate_(const ProducerConsumer* op, const Expr& e);
135-
virtual Expr Mutate_(const Provide* op, const Expr& e);
136-
virtual Expr Mutate_(const Realize* op, const Expr& e);
137-
virtual Expr Mutate_(const Block* op, const Expr& e);
138-
virtual Expr Mutate_(const Evaluate* op, const Expr& e);
13997
virtual Expr Mutate_(const IntImm* op, const Expr& e);
14098
virtual Expr Mutate_(const UIntImm* op, const Expr& e);
14199
virtual Expr Mutate_(const FloatImm* op, const Expr& e);
142100
virtual Expr Mutate_(const StringImm* op, const Expr& e);
143101
};
144102

145-
/*!
146-
* \brief Example on how to subclass and override behavior of IRMutator
147-
*/
148-
class IRMutatorExample : public IRMutator {
149-
public:
150-
Expr Mutate(Expr expr) final {
151-
static const FMutateExpr& f = IRMutatorExample::vtable_expr();
152-
return (f.can_dispatch(expr) ?
153-
f(expr, expr, this) : IRMutator::Mutate(expr));
154-
}
155-
Stmt Mutate(Stmt stmt) final {
156-
static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
157-
return (f.can_dispatch(stmt) ?
158-
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
159-
}
160-
// to be implemented by child class
161-
static FMutateExpr& vtable_expr(); // NOLINT(*)
162-
static FMutateStmt& vtable_stmt(); // NOLINT(*)
163-
};
164-
165103
} // namespace ir
166104
} // namespace tvm
167105
#endif // TVM_IR_MUTATOR_H_

0 commit comments

Comments
 (0)