@@ -30,6 +30,9 @@ namespace mlir {
3030// / functions in your class. This class is defined in terms of statically
3131// / resolved overloading, not virtual functions.
3232// /
33+ // / The visitor is templated on its return type (`RetTy`). With a WalkResult
34+ // / return type, the visitor supports interrupting walks.
35+ // /
3336// / For example, here is a visitor that counts the number of for AffineDimExprs
3437// / in an AffineExpr.
3538// /
@@ -65,7 +68,6 @@ namespace mlir {
6568// / virtual function call overhead. Defining and using a AffineExprVisitor is
6669// / just as efficient as having your own switch instruction over the instruction
6770// / opcode.
68-
6971template <typename SubClass, typename RetTy>
7072class AffineExprVisitorBase {
7173public:
@@ -136,6 +138,8 @@ class AffineExprVisitorBase {
136138 RetTy visitSymbolExpr (AffineSymbolExpr expr) { return RetTy (); }
137139};
138140
141+ // / See documentation for AffineExprVisitorBase. This visitor supports
142+ // / interrupting walks when a `WalkResult` is used for `RetTy`.
139143template <typename SubClass, typename RetTy = void >
140144class AffineExprVisitor : public AffineExprVisitorBase <SubClass, RetTy> {
141145 // ===--------------------------------------------------------------------===//
@@ -150,27 +154,52 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
150154 switch (expr.getKind ()) {
151155 case AffineExprKind::Add: {
152156 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
153- walkOperandsPostOrder (binOpExpr);
157+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
158+ if (walkOperandsPostOrder (binOpExpr).wasInterrupted ())
159+ return WalkResult::interrupt ();
160+ } else {
161+ walkOperandsPostOrder (binOpExpr);
162+ }
154163 return self->visitAddExpr (binOpExpr);
155164 }
156165 case AffineExprKind::Mul: {
157166 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
158- walkOperandsPostOrder (binOpExpr);
167+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
168+ if (walkOperandsPostOrder (binOpExpr).wasInterrupted ())
169+ return WalkResult::interrupt ();
170+ } else {
171+ walkOperandsPostOrder (binOpExpr);
172+ }
159173 return self->visitMulExpr (binOpExpr);
160174 }
161175 case AffineExprKind::Mod: {
162176 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
163- walkOperandsPostOrder (binOpExpr);
177+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
178+ if (walkOperandsPostOrder (binOpExpr).wasInterrupted ())
179+ return WalkResult::interrupt ();
180+ } else {
181+ walkOperandsPostOrder (binOpExpr);
182+ }
164183 return self->visitModExpr (binOpExpr);
165184 }
166185 case AffineExprKind::FloorDiv: {
167186 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
168- walkOperandsPostOrder (binOpExpr);
187+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
188+ if (walkOperandsPostOrder (binOpExpr).wasInterrupted ())
189+ return WalkResult::interrupt ();
190+ } else {
191+ walkOperandsPostOrder (binOpExpr);
192+ }
169193 return self->visitFloorDivExpr (binOpExpr);
170194 }
171195 case AffineExprKind::CeilDiv: {
172196 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
173- walkOperandsPostOrder (binOpExpr);
197+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
198+ if (walkOperandsPostOrder (binOpExpr).wasInterrupted ())
199+ return WalkResult::interrupt ();
200+ } else {
201+ walkOperandsPostOrder (binOpExpr);
202+ }
174203 return self->visitCeilDivExpr (binOpExpr);
175204 }
176205 case AffineExprKind::Constant:
@@ -186,8 +215,19 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
186215private:
187216 // Walk the operands - each operand is itself walked in post order.
188217 RetTy walkOperandsPostOrder (AffineBinaryOpExpr expr) {
189- walkPostOrder (expr.getLHS ());
190- walkPostOrder (expr.getRHS ());
218+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
219+ if (walkPostOrder (expr.getLHS ()).wasInterrupted ())
220+ return WalkResult::interrupt ();
221+ } else {
222+ walkPostOrder (expr.getLHS ());
223+ }
224+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
225+ if (walkPostOrder (expr.getLHS ()).wasInterrupted ())
226+ return WalkResult::interrupt ();
227+ return WalkResult::advance ();
228+ } else {
229+ return walkPostOrder (expr.getRHS ());
230+ }
191231 }
192232};
193233
0 commit comments