|
34 | 34 | #include "mlir/IR/BuiltinDialect.h" |
35 | 35 | #include "mlir/IR/BuiltinTypes.h" |
36 | 36 | #include "mlir/IR/IRMapping.h" |
| 37 | +#include "mlir/IR/Operation.h" |
37 | 38 | #include "mlir/IR/Value.h" |
38 | 39 | #include "mlir/Pass/Pass.h" |
39 | 40 | #include "mlir/Pass/PassManager.h" |
@@ -111,12 +112,86 @@ class CIRPtrStrideOpLowering |
111 | 112 | class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> { |
112 | 113 | public: |
113 | 114 | using mlir::OpConversionPattern<mlir::cir::LoopOp>::OpConversionPattern; |
| 115 | + using LoopKind = mlir::cir::LoopOpKind; |
| 116 | + |
| 117 | + mlir::LogicalResult |
| 118 | + fetchCondRegionYields(mlir::Region &condRegion, |
| 119 | + mlir::cir::YieldOp &yieldToBody, |
| 120 | + mlir::cir::YieldOp &yieldToCont) const { |
| 121 | + for (auto &bb : condRegion) { |
| 122 | + if (auto yieldOp = dyn_cast<mlir::cir::YieldOp>(bb.getTerminator())) { |
| 123 | + if (!yieldOp.getKind().has_value()) |
| 124 | + yieldToCont = yieldOp; |
| 125 | + else if (yieldOp.getKind() == mlir::cir::YieldOpKind::Continue) |
| 126 | + yieldToBody = yieldOp; |
| 127 | + else |
| 128 | + return mlir::failure(); |
| 129 | + } |
| 130 | + } |
| 131 | + |
| 132 | + // Succeed only if both yields are found. |
| 133 | + if (!yieldToBody || !yieldToCont) |
| 134 | + return mlir::failure(); |
| 135 | + return mlir::success(); |
| 136 | + } |
| 137 | + |
| 138 | + mlir::LogicalResult |
| 139 | + rewriteWhileLoop(mlir::cir::LoopOp loopOp, OpAdaptor adaptor, |
| 140 | + mlir::ConversionPatternRewriter &rewriter) const { |
| 141 | + auto *currentBlock = rewriter.getInsertionBlock(); |
| 142 | + auto *continueBlock = |
| 143 | + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); |
| 144 | + |
| 145 | + // Fetch required info from the condition region. |
| 146 | + auto &condRegion = loopOp.getCond(); |
| 147 | + auto &condFrontBlock = condRegion.front(); |
| 148 | + mlir::cir::YieldOp yieldToBody, yieldToCont; |
| 149 | + if (fetchCondRegionYields(condRegion, yieldToBody, yieldToCont).failed()) |
| 150 | + return loopOp.emitError("failed to fetch yields in cond region"); |
| 151 | + |
| 152 | + // Fetch required info from the condition region. |
| 153 | + auto &bodyRegion = loopOp.getBody(); |
| 154 | + auto &bodyFrontBlock = bodyRegion.front(); |
| 155 | + auto bodyYield = |
| 156 | + dyn_cast<mlir::cir::YieldOp>(bodyRegion.back().getTerminator()); |
| 157 | + assert(bodyYield && "unstructured while loops are NYI"); |
| 158 | + |
| 159 | + // Move loop op region contents to current CFG. |
| 160 | + rewriter.inlineRegionBefore(condRegion, continueBlock); |
| 161 | + rewriter.inlineRegionBefore(bodyRegion, continueBlock); |
| 162 | + |
| 163 | + // Set loop entry point to condition block. |
| 164 | + rewriter.setInsertionPointToEnd(currentBlock); |
| 165 | + rewriter.create<mlir::cir::BrOp>(loopOp.getLoc(), &condFrontBlock); |
| 166 | + |
| 167 | + // Set loop exit point to continue block. |
| 168 | + rewriter.setInsertionPoint(yieldToCont); |
| 169 | + rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(yieldToCont, continueBlock); |
| 170 | + |
| 171 | + // Branch from condition to body. |
| 172 | + rewriter.setInsertionPoint(yieldToBody); |
| 173 | + rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(yieldToBody, &bodyFrontBlock); |
| 174 | + |
| 175 | + // Branch from body to condition. |
| 176 | + rewriter.setInsertionPoint(bodyYield); |
| 177 | + rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &condFrontBlock); |
| 178 | + |
| 179 | + // Remove the loop op. |
| 180 | + rewriter.eraseOp(loopOp); |
| 181 | + return mlir::success(); |
| 182 | + } |
114 | 183 |
|
115 | 184 | mlir::LogicalResult |
116 | 185 | matchAndRewrite(mlir::cir::LoopOp loopOp, OpAdaptor adaptor, |
117 | 186 | mlir::ConversionPatternRewriter &rewriter) const override { |
118 | | - if (loopOp.getKind() != mlir::cir::LoopOpKind::For) |
| 187 | + switch (loopOp.getKind()) { |
| 188 | + case LoopKind::For: |
| 189 | + break; |
| 190 | + case LoopKind::While: |
| 191 | + return rewriteWhileLoop(loopOp, adaptor, rewriter); |
| 192 | + case LoopKind::DoWhile: |
119 | 193 | llvm_unreachable("NYI"); |
| 194 | + } |
120 | 195 |
|
121 | 196 | auto loc = loopOp.getLoc(); |
122 | 197 |
|
|
0 commit comments