@@ -136,13 +136,19 @@ class ExprFunctor<R(const Expr& n, Args...)> {
136136 }
137137};
138138
139+
139140/* !
140141 * \brief A simple visitor wrapper around ExprFunctor.
141142 * Recursively visit the content.
142143 */
143- class ExprVisitor : public ExprFunctor <void (const Expr& n )> {
144+ class ExprVisitor : public ExprFunctor <void (const Expr&)> {
144145 public:
146+ /* !
147+ * \brief Generic dispatcher for Expr.
148+ * \param expr The expr to be visited.
149+ */
145150 void VisitExpr (const Expr& expr) override ;
151+ // specific leaf level visitor functions
146152 void VisitExpr_ (const ConstantNode* op) override ;
147153 void VisitExpr_ (const TupleNode* op) override ;
148154 void VisitExpr_ (const VarNode* op) override ;
@@ -157,13 +163,36 @@ class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
157163 void VisitExpr_ (const OpNode* op) override ;
158164 void VisitExpr_ (const TupleGetItemNode* op) override ;
159165
160- virtual void VisitType (const Type& t);
161- virtual void VisitSpan (const Span& span);
166+ /* !
167+ * \brief Generic dispatcher for bindings.
168+ * \param binding The binding to be visited.
169+ */
162170 virtual void VisitBinding (const Binding& binding);
163- virtual void VisitVarBinding (const VarBinding& binding);
164- virtual void VisitMatchShape (const MatchShape& binding);
171+ // specific leaf level visitor functions
172+ virtual void VisitBinding_ (const VarBindingNode* binding);
173+ virtual void VisitBinding_ (const MatchShapeNode* binding);
174+
175+ /* !
176+ * \brief Generic dispatcher for binding blocks.
177+ * \param block The binding block to be visited.
178+ */
165179 virtual void VisitBindingBlock (const BindingBlock& block);
166- virtual void VisitDataflowBlock (const DataflowBlock& block);
180+ // specific leaf level visitor functions
181+ virtual void VisitBindingBlock_ (const BindingBlockNode* block);
182+ virtual void VisitBindingBlock_ (const DataflowBlockNode* block);
183+
184+ /* !
185+ * \brief Generic dispatcher for visiting the var definition site.
186+ * \param var The var to be visited.
187+ * \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
188+ */
189+ virtual void VisitVarDef (const Var& var);
190+ // specific leaf level visitor functions
191+ virtual void VisitVarDef_ (const VarNode* var);
192+ virtual void VisitVarDef_ (const DataflowVarNode* var);
193+
194+ virtual void VisitType (const Type& t);
195+ virtual void VisitSpan (const Span& span);
167196};
168197
169198void PostOrderVisit (const Expr& node, std::function<void (const Expr&)> fvisit);
@@ -205,20 +234,35 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
205234 */
206235 virtual Type VisitType (const Type& t);
207236
237+ /* !
238+ * \brief Generic dispatcher for bindings.
239+ * \param binding The binding to be visited.
240+ */
208241 virtual void VisitBinding (const Binding& binding);
209- virtual void VisitVarBinding (const VarBinding& binding);
210- virtual void VisitMatchShape (const MatchShape& binding);
242+ // specific leaf level visitor functions
243+ virtual void VisitBinding_ (const VarBindingNode* binding);
244+ virtual void VisitBinding_ (const MatchShapeNode* binding);
245+
246+ /* !
247+ * \brief Generic dispatcher for binding blocks.
248+ * \param block The binding block to be visited.
249+ * \return The binding block after transformation.
250+ */
251+ virtual BindingBlock VisitBindingBlock (const BindingBlock& block);
252+ // specific leaf level visitor functions
253+ virtual BindingBlock VisitBindingBlock_ (const BindingBlockNode* block);
254+ virtual BindingBlock VisitBindingBlock_ (const DataflowBlockNode* block);
211255
212256 /* !
213- * \brief Rewrite the var definition site.
257+ * \brief Generic dispatcher for rewriting the var definition site.
214258 * \param var The var to be visited.
215259 * \return The var after post-order rewritten.
216260 * \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
217261 */
218262 virtual Var VisitVarDef (const Var& var);
219-
220- virtual BindingBlock VisitBindingBlock (const BindingBlock& block );
221- virtual BindingBlock VisitDataflowBlock (const DataflowBlock& block );
263+ // specific leaf level visitor functions
264+ virtual Var VisitVarDef_ (const VarNode* var );
265+ virtual Var VisitVarDef_ (const DataflowVarNode* var );
222266
223267 protected:
224268 class ExprNormalizer ;
@@ -265,16 +309,6 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
265309 std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
266310};
267311
268- // TODO(@yuchen, @altan): Refactor to enforce dataflow mutator only rewrite stuff in dataflow blocks
269- /* ! \brief Dataflow Graph Rewriting for Custom Rewriting Passes
270- */
271- class DataflowMutator : public ExprMutator {
272- public:
273- void VisitBinding (const Binding& binding) final ;
274-
275- virtual void VisitDataflowVarBinding (const VarBinding& binding);
276- };
277-
278312} // namespace relax
279313} // namespace tvm
280314#endif // TVM_RELAX_EXPR_FUNCTOR_H_
0 commit comments