@@ -136,13 +136,19 @@ class ExprFunctor<R(const Expr& n, Args...)> {
136
136
}
137
137
};
138
138
139
+
139
140
/* !
140
141
* \brief A simple visitor wrapper around ExprFunctor.
141
142
* Recursively visit the content.
142
143
*/
143
- class ExprVisitor : public ExprFunctor <void (const Expr& n )> {
144
+ class ExprVisitor : public ExprFunctor <void (const Expr&)> {
144
145
public:
146
+ /* !
147
+ * \brief Generic dispatcher for Expr.
148
+ * \param expr The expr to be visited.
149
+ */
145
150
void VisitExpr (const Expr& expr) override ;
151
+ // specific leaf level visitor functions
146
152
void VisitExpr_ (const ConstantNode* op) override ;
147
153
void VisitExpr_ (const TupleNode* op) override ;
148
154
void VisitExpr_ (const VarNode* op) override ;
@@ -157,13 +163,36 @@ class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
157
163
void VisitExpr_ (const OpNode* op) override ;
158
164
void VisitExpr_ (const TupleGetItemNode* op) override ;
159
165
160
- virtual void VisitType (const Type& t);
161
- virtual void VisitSpan (const Span& span);
166
+ /* !
167
+ * \brief Generic dispatcher for bindings.
168
+ * \param binding The binding to be visited.
169
+ */
162
170
virtual void VisitBinding (const Binding& binding);
163
- virtual void VisitVarBinding (const VarBinding& binding);
164
- virtual void VisitMatchShape (const MatchShape& binding);
171
+ // specific leaf level visitor functions
172
+ virtual void VisitBinding_ (const VarBindingNode* binding);
173
+ virtual void VisitBinding_ (const MatchShapeNode* binding);
174
+
175
+ /* !
176
+ * \brief Generic dispatcher for binding blocks.
177
+ * \param block The binding block to be visited.
178
+ */
165
179
virtual void VisitBindingBlock (const BindingBlock& block);
166
- virtual void VisitDataflowBlock (const DataflowBlock& block);
180
+ // specific leaf level visitor functions
181
+ virtual void VisitBindingBlock_ (const BindingBlockNode* block);
182
+ virtual void VisitBindingBlock_ (const DataflowBlockNode* block);
183
+
184
+ /* !
185
+ * \brief Generic dispatcher for visiting the var definition site.
186
+ * \param var The var to be visited.
187
+ * \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
188
+ */
189
+ virtual void VisitVarDef (const Var& var);
190
+ // specific leaf level visitor functions
191
+ virtual void VisitVarDef_ (const VarNode* var);
192
+ virtual void VisitVarDef_ (const DataflowVarNode* var);
193
+
194
+ virtual void VisitType (const Type& t);
195
+ virtual void VisitSpan (const Span& span);
167
196
};
168
197
169
198
void PostOrderVisit (const Expr& node, std::function<void (const Expr&)> fvisit);
@@ -205,20 +234,35 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
205
234
*/
206
235
virtual Type VisitType (const Type& t);
207
236
237
+ /* !
238
+ * \brief Generic dispatcher for bindings.
239
+ * \param binding The binding to be visited.
240
+ */
208
241
virtual void VisitBinding (const Binding& binding);
209
- virtual void VisitVarBinding (const VarBinding& binding);
210
- virtual void VisitMatchShape (const MatchShape& binding);
242
+ // specific leaf level visitor functions
243
+ virtual void VisitBinding_ (const VarBindingNode* binding);
244
+ virtual void VisitBinding_ (const MatchShapeNode* binding);
245
+
246
+ /* !
247
+ * \brief Generic dispatcher for binding blocks.
248
+ * \param block The binding block to be visited.
249
+ * \return The binding block after transformation.
250
+ */
251
+ virtual BindingBlock VisitBindingBlock (const BindingBlock& block);
252
+ // specific leaf level visitor functions
253
+ virtual BindingBlock VisitBindingBlock_ (const BindingBlockNode* block);
254
+ virtual BindingBlock VisitBindingBlock_ (const DataflowBlockNode* block);
211
255
212
256
/* !
213
- * \brief Rewrite the var definition site.
257
+ * \brief Generic dispatcher for rewriting the var definition site.
214
258
* \param var The var to be visited.
215
259
* \return The var after post-order rewritten.
216
260
* \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
217
261
*/
218
262
virtual Var VisitVarDef (const Var& var);
219
-
220
- virtual BindingBlock VisitBindingBlock (const BindingBlock& block );
221
- virtual BindingBlock VisitDataflowBlock (const DataflowBlock& block );
263
+ // specific leaf level visitor functions
264
+ virtual Var VisitVarDef_ (const VarNode* var );
265
+ virtual Var VisitVarDef_ (const DataflowVarNode* var );
222
266
223
267
protected:
224
268
class ExprNormalizer ;
@@ -265,16 +309,6 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
265
309
std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
266
310
};
267
311
268
- // TODO(@yuchen, @altan): Refactor to enforce dataflow mutator only rewrite stuff in dataflow blocks
269
- /* ! \brief Dataflow Graph Rewriting for Custom Rewriting Passes
270
- */
271
- class DataflowMutator : public ExprMutator {
272
- public:
273
- void VisitBinding (const Binding& binding) final ;
274
-
275
- virtual void VisitDataflowVarBinding (const VarBinding& binding);
276
- };
277
-
278
312
} // namespace relax
279
313
} // namespace tvm
280
314
#endif // TVM_RELAX_EXPR_FUNCTOR_H_
0 commit comments