Skip to content

Commit dac46fa

Browse files
committed
[CIR][Lowering] Fixed break/continue lowering for loops
1 parent dc8fbcf commit dac46fa

File tree

3 files changed

+947
-4
lines changed

3 files changed

+947
-4
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,42 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
241241
return mlir::success();
242242
}
243243

244+
void makeYieldIf(mlir::cir::YieldOpKind kind, mlir::cir::YieldOp &op, mlir::Block *to,
245+
mlir::ConversionPatternRewriter &rewriter) const {
246+
if (op.getKind() == kind) {
247+
rewriter.setInsertionPoint(op);
248+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(op, op.getArgs(), to);
249+
}
250+
}
251+
252+
void lowerNestedBreakContinue(mlir::Region &loopBody, mlir::Block *exitBlock,
253+
mlir::Block *continueBlock,
254+
mlir::ConversionPatternRewriter &rewriter) const {
255+
256+
auto processBreak = [&](mlir::Operation *op) {
257+
if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(*op)) // don't process breaks in nested loops and switches
258+
return mlir::WalkResult::skip();
259+
260+
if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
261+
makeYieldIf(mlir::cir::YieldOpKind::Break, yield, exitBlock, rewriter);
262+
263+
return mlir::WalkResult::advance();
264+
};
265+
266+
auto processContinue = [&](mlir::Operation *op) {
267+
if (isa<mlir::cir::LoopOp>(*op)) // don't process continues in nested loops
268+
return mlir::WalkResult::skip();
269+
270+
if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
271+
makeYieldIf(mlir::cir::YieldOpKind::Continue, yield, continueBlock, rewriter);
272+
273+
return mlir::WalkResult::advance();
274+
};
275+
276+
loopBody.walk<mlir::WalkOrder::PreOrder>(processBreak);
277+
loopBody.walk<mlir::WalkOrder::PreOrder>(processContinue);
278+
}
279+
244280
mlir::LogicalResult
245281
matchAndRewrite(mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
246282
mlir::ConversionPatternRewriter &rewriter) const override {
@@ -268,6 +304,9 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
268304
auto &stepFrontBlock = stepRegion.front();
269305
auto stepYield =
270306
dyn_cast<mlir::cir::YieldOp>(stepRegion.back().getTerminator());
307+
auto &stepBlock = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
308+
309+
lowerNestedBreakContinue(bodyRegion, continueBlock, &stepBlock, rewriter);
271310

272311
// Move loop op region contents to current CFG.
273312
rewriter.inlineRegionBefore(condRegion, continueBlock);
@@ -290,8 +329,7 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
290329

291330
// Branch from body to condition or to step on for-loop cases.
292331
rewriter.setInsertionPoint(bodyYield);
293-
auto &bodyExit = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
294-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &bodyExit);
332+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &stepBlock);
295333

296334
// Is a for loop: branch from step to condition.
297335
if (kind == LoopKind::For) {
@@ -483,6 +521,11 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
483521
}
484522
};
485523

524+
static bool isLoopYield(mlir::cir::YieldOp &op) {
525+
return op.getKind() == mlir::cir::YieldOpKind::Break ||
526+
op.getKind() == mlir::cir::YieldOpKind::Continue;
527+
}
528+
486529
class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
487530
public:
488531
using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern;
@@ -511,7 +554,8 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
511554
rewriter.setInsertionPointToEnd(thenAfterBody);
512555
if (auto thenYieldOp =
513556
dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator())) {
514-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
557+
if (!isLoopYield(thenYieldOp)) // lowering of parent loop yields is deferred to loop lowering
558+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
515559
thenYieldOp, thenYieldOp.getArgs(), continueBlock);
516560
} else if (!dyn_cast<mlir::cir::ReturnOp>(thenAfterBody->getTerminator())) {
517561
llvm_unreachable("what are we terminating with?");
@@ -540,7 +584,8 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
540584
rewriter.setInsertionPointToEnd(elseAfterBody);
541585
if (auto elseYieldOp =
542586
dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator())) {
543-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
587+
if (!isLoopYield(elseYieldOp)) // lowering of parent loop yields is deferred to loop lowering
588+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
544589
elseYieldOp, elseYieldOp.getArgs(), continueBlock);
545590
} else if (!dyn_cast<mlir::cir::ReturnOp>(
546591
elseAfterBody->getTerminator())) {
@@ -1081,6 +1126,8 @@ class CIRSwitchOpLowering
10811126
case mlir::cir::YieldOpKind::Break:
10821127
rewriteYieldOp(rewriter, yieldOp, exitBlock);
10831128
break;
1129+
case mlir::cir::YieldOpKind::Continue: // Contniue is handled only in loop lowering
1130+
break;
10841131
default:
10851132
return op->emitError("invalid yield kind in case statement");
10861133
}

0 commit comments

Comments
 (0)