Skip to content

Commit df3842f

Browse files
YuchenJinjunrushao
authored andcommitted
Generic dispatching in Visitor (apache#39)
1 parent 6816cbb commit df3842f

File tree

5 files changed

+183
-112
lines changed

5 files changed

+183
-112
lines changed

include/tvm/relax/expr_functor.h

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,19 @@ class ExprFunctor<R(const Expr& n, Args...)> {
136136
}
137137
};
138138

139+
139140
/*!
140141
* \brief A simple visitor wrapper around ExprFunctor.
141142
* Recursively visit the content.
142143
*/
143-
class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
144+
class ExprVisitor : public ExprFunctor<void(const Expr&)> {
144145
public:
146+
/*!
147+
* \brief Generic dispatcher for Expr.
148+
* \param expr The expr to be visited.
149+
*/
145150
void VisitExpr(const Expr& expr) override;
151+
// specific leaf level visitor functions
146152
void VisitExpr_(const ConstantNode* op) override;
147153
void VisitExpr_(const TupleNode* op) override;
148154
void VisitExpr_(const VarNode* op) override;
@@ -157,13 +163,36 @@ class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
157163
void VisitExpr_(const OpNode* op) override;
158164
void VisitExpr_(const TupleGetItemNode* op) override;
159165

160-
virtual void VisitType(const Type& t);
161-
virtual void VisitSpan(const Span& span);
166+
/*!
167+
* \brief Generic dispatcher for bindings.
168+
* \param binding The binding to be visited.
169+
*/
162170
virtual void VisitBinding(const Binding& binding);
163-
virtual void VisitVarBinding(const VarBinding& binding);
164-
virtual void VisitMatchShape(const MatchShape& binding);
171+
// specific leaf level visitor functions
172+
virtual void VisitBinding_(const VarBindingNode* binding);
173+
virtual void VisitBinding_(const MatchShapeNode* binding);
174+
175+
/*!
176+
* \brief Generic dispatcher for binding blocks.
177+
* \param block The binding block to be visited.
178+
*/
165179
virtual void VisitBindingBlock(const BindingBlock& block);
166-
virtual void VisitDataflowBlock(const DataflowBlock& block);
180+
// specific leaf level visitor functions
181+
virtual void VisitBindingBlock_(const BindingBlockNode* block);
182+
virtual void VisitBindingBlock_(const DataflowBlockNode* block);
183+
184+
/*!
185+
* \brief Generic dispatcher for visiting the var definition site.
186+
* \param var The var to be visited.
187+
* \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
188+
*/
189+
virtual void VisitVarDef(const Var& var);
190+
// specific leaf level visitor functions
191+
virtual void VisitVarDef_(const VarNode* var);
192+
virtual void VisitVarDef_(const DataflowVarNode* var);
193+
194+
virtual void VisitType(const Type& t);
195+
virtual void VisitSpan(const Span& span);
167196
};
168197

169198
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
@@ -205,20 +234,35 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
205234
*/
206235
virtual Type VisitType(const Type& t);
207236

237+
/*!
238+
* \brief Generic dispatcher for bindings.
239+
* \param binding The binding to be visited.
240+
*/
208241
virtual void VisitBinding(const Binding& binding);
209-
virtual void VisitVarBinding(const VarBinding& binding);
210-
virtual void VisitMatchShape(const MatchShape& binding);
242+
// specific leaf level visitor functions
243+
virtual void VisitBinding_(const VarBindingNode* binding);
244+
virtual void VisitBinding_(const MatchShapeNode* binding);
245+
246+
/*!
247+
* \brief Generic dispatcher for binding blocks.
248+
* \param block The binding block to be visited.
249+
* \return The binding block after transformation.
250+
*/
251+
virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
252+
// specific leaf level visitor functions
253+
virtual BindingBlock VisitBindingBlock_(const BindingBlockNode* block);
254+
virtual BindingBlock VisitBindingBlock_(const DataflowBlockNode* block);
211255

212256
/*!
213-
* \brief Rewrite the var definition site.
257+
* \brief Generic dispatcher for rewriting the var definition site.
214258
* \param var The var to be visited.
215259
* \return The var after post-order rewritten.
216260
* \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
217261
*/
218262
virtual Var VisitVarDef(const Var& var);
219-
220-
virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
221-
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);
263+
// specific leaf level visitor functions
264+
virtual Var VisitVarDef_(const VarNode* var);
265+
virtual Var VisitVarDef_(const DataflowVarNode* var);
222266

223267
protected:
224268
class ExprNormalizer;
@@ -265,16 +309,6 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
265309
std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
266310
};
267311

268-
// TODO(@yuchen, @altan): Refactor to enforce dataflow mutator only rewrite stuff in dataflow blocks
269-
/*! \brief Dataflow Graph Rewriting for Custom Rewriting Passes
270-
*/
271-
class DataflowMutator : public ExprMutator {
272-
public:
273-
void VisitBinding(const Binding& binding) final;
274-
275-
virtual void VisitDataflowVarBinding(const VarBinding& binding);
276-
};
277-
278312
} // namespace relax
279313
} // namespace tvm
280314
#endif // TVM_RELAX_EXPR_FUNCTOR_H_

src/relax/backend/vm/vm_shape_lower.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class VMShapeLowerMutator : public ExprMutator {
5757
return ret_mod_;
5858
}
5959

60-
void VisitMatchShape(const MatchShape& binding) override {
60+
void VisitBinding_(const MatchShapeNode* binding) override {
6161
Expr shape = ExprMutator::VisitExpr(binding->value);
6262
static const Op& store_shape_op = Op::Get("relax.vm.builtin.store_shape");
6363
auto store_shape_attr = make_object<ShapeHeapAttrs>();

0 commit comments

Comments
 (0)