@@ -139,9 +139,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
139139/* !
140140 * \brief A simple visitor wrapper around ExprFunctor.
141141 * Recursively visit the content.
142- *
143- * ExprVisitor treats Expr as dataflow graph,
144- * and only visit each Expr node once.
145142 */
146143class ExprVisitor : public ExprFunctor <void (const Expr& n)> {
147144 public:
@@ -167,9 +164,6 @@ class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
167164 virtual void VisitMatchShape (const MatchShape& binding);
168165 virtual void VisitBindingBlock (const BindingBlock& block);
169166 virtual void VisitDataflowBlock (const DataflowBlock& block);
170-
171- protected:
172- std::unordered_map<const Object*, size_t > visit_counter_;
173167};
174168
175169void PostOrderVisit (const Expr& node, std::function<void (const Expr&)> fvisit);
@@ -221,19 +215,48 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
221215 virtual Type VisitType (const Type& t);
222216
223217 virtual void VisitBinding (const Binding& binding);
224- virtual Var VisitVarBinding (const VarBinding& binding);
218+ virtual void VisitVarBinding (const VarBinding& binding);
225219 virtual void VisitMatchShape (const MatchShape& binding);
226220
227221 virtual BindingBlock VisitBindingBlock (const BindingBlock& block);
228222 virtual BindingBlock VisitDataflowBlock (const DataflowBlock& block);
229223
230224 protected:
231225 Expr MutateWithPrologue (const Expr& expr, bool is_dataflow);
232- /* ! \brief Look up the value binded to a var. */
226+
227+ /* ! \brief Look up the value of a variable. If the variable is bound, then returns the bound
228+ * value. Otherwise, returns the rewritten expression for the variable.
229+ */
233230 Expr LookupVar (Var var);
234- // A remapping table: pre var -> post var
235- std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
236- std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> memo_;
231+
232+ inline void UpdateMemo (Expr pre , Expr post ) {
233+ if (const VarNode* var = pre .as <VarNode>()) {
234+ var_memo_[var->vid ] = post ;
235+ } else {
236+ expr_memo_[pre ] = post ;
237+ }
238+ }
239+
240+ inline Optional<Expr> LookupMemo (Expr pre ) {
241+ if (pre .as <VarNode>()) {
242+ Id vid = Downcast<Var>(pre )->vid ;
243+ if (var_memo_.count (vid)) {
244+ return var_memo_[vid];
245+ }
246+ } else {
247+ if (expr_memo_.count (pre )) {
248+ return expr_memo_[pre ];
249+ }
250+ }
251+ return NullOpt;
252+ }
253+
254+ /* ! \brief Variable memoization table using Id equality */
255+ std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
256+
257+ /* ! \brief Expr memoization table using pointer equality */
258+ std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> expr_memo_;
259+
237260 std::shared_ptr<NameTable> name_table_;
238261 BlockBuilder builder_;
239262};
@@ -245,7 +268,7 @@ class DataflowMutator : public ExprMutator {
245268 public:
246269 void VisitBinding (const Binding& binding) final ;
247270
248- virtual Var VisitDataflowVarBinding (const VarBinding& binding);
271+ virtual void VisitDataflowVarBinding (const VarBinding& binding);
249272};
250273
251274} // namespace relax
0 commit comments