Skip to content

Commit 98222a8

Browse files
gitoleglanza
authored andcommitted
[CIR][Lowering] Fixed break/continue lowering for loops (#211)
This PR fixes lowering for `break/continue` in loops. The idea is to replace `cir.yield break` and `cir.yield continue` with the branch operations to the corresponding blocks. Note, that we need to ignore nesting loops and don't touch `break` in switch operations. Also, `yield` from `if` need to be considered only when it's not the loop `yield` and `continue` in switch is ignored since it's processed in the loops lowering. Fixes #160
1 parent 4c0daac commit 98222a8

File tree

3 files changed

+702
-8
lines changed

3 files changed

+702
-8
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,47 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
238238
return mlir::success();
239239
}
240240

241+
void makeYieldIf(mlir::cir::YieldOpKind kind, mlir::cir::YieldOp &op,
242+
mlir::Block *to,
243+
mlir::ConversionPatternRewriter &rewriter) const {
244+
if (op.getKind() == kind) {
245+
rewriter.setInsertionPoint(op);
246+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(op, op.getArgs(), to);
247+
}
248+
}
249+
250+
void
251+
lowerNestedBreakContinue(mlir::Region &loopBody, mlir::Block *exitBlock,
252+
mlir::Block *continueBlock,
253+
mlir::ConversionPatternRewriter &rewriter) const {
254+
255+
auto processBreak = [&](mlir::Operation *op) {
256+
if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(
257+
*op)) // don't process breaks in nested loops and switches
258+
return mlir::WalkResult::skip();
259+
260+
if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
261+
makeYieldIf(mlir::cir::YieldOpKind::Break, yield, exitBlock, rewriter);
262+
263+
return mlir::WalkResult::advance();
264+
};
265+
266+
auto processContinue = [&](mlir::Operation *op) {
267+
if (isa<mlir::cir::LoopOp>(
268+
*op)) // don't process continues in nested loops
269+
return mlir::WalkResult::skip();
270+
271+
if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
272+
makeYieldIf(mlir::cir::YieldOpKind::Continue, yield, continueBlock,
273+
rewriter);
274+
275+
return mlir::WalkResult::advance();
276+
};
277+
278+
loopBody.walk<mlir::WalkOrder::PreOrder>(processBreak);
279+
loopBody.walk<mlir::WalkOrder::PreOrder>(processContinue);
280+
}
281+
241282
mlir::LogicalResult
242283
matchAndRewrite(mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
243284
mlir::ConversionPatternRewriter &rewriter) const override {
@@ -265,6 +306,9 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
265306
auto &stepFrontBlock = stepRegion.front();
266307
auto stepYield =
267308
dyn_cast<mlir::cir::YieldOp>(stepRegion.back().getTerminator());
309+
auto &stepBlock = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
310+
311+
lowerNestedBreakContinue(bodyRegion, continueBlock, &stepBlock, rewriter);
268312

269313
// Move loop op region contents to current CFG.
270314
rewriter.inlineRegionBefore(condRegion, continueBlock);
@@ -287,8 +331,7 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
287331

288332
// Branch from body to condition or to step on for-loop cases.
289333
rewriter.setInsertionPoint(bodyYield);
290-
auto &bodyExit = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
291-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &bodyExit);
334+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &stepBlock);
292335

293336
// Is a for loop: branch from step to condition.
294337
if (kind == LoopKind::For) {
@@ -480,6 +523,11 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
480523
}
481524
};
482525

526+
static bool isLoopYield(mlir::cir::YieldOp &op) {
527+
return op.getKind() == mlir::cir::YieldOpKind::Break ||
528+
op.getKind() == mlir::cir::YieldOpKind::Continue;
529+
}
530+
483531
class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
484532
public:
485533
using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern;
@@ -508,8 +556,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
508556
rewriter.setInsertionPointToEnd(thenAfterBody);
509557
if (auto thenYieldOp =
510558
dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator())) {
511-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
512-
thenYieldOp, thenYieldOp.getArgs(), continueBlock);
559+
if (!isLoopYield(thenYieldOp)) // lowering of parent loop yields is
560+
// deferred to loop lowering
561+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
562+
thenYieldOp, thenYieldOp.getArgs(), continueBlock);
513563
} else if (!dyn_cast<mlir::cir::ReturnOp>(thenAfterBody->getTerminator())) {
514564
llvm_unreachable("what are we terminating with?");
515565
}
@@ -537,8 +587,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
537587
rewriter.setInsertionPointToEnd(elseAfterBody);
538588
if (auto elseYieldOp =
539589
dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator())) {
540-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
541-
elseYieldOp, elseYieldOp.getArgs(), continueBlock);
590+
if (!isLoopYield(elseYieldOp)) // lowering of parent loop yields is
591+
// deferred to loop lowering
592+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
593+
elseYieldOp, elseYieldOp.getArgs(), continueBlock);
542594
} else if (!dyn_cast<mlir::cir::ReturnOp>(
543595
elseAfterBody->getTerminator())) {
544596
llvm_unreachable("what are we terminating with?");
@@ -1097,6 +1149,9 @@ class CIRSwitchOpLowering
10971149
case mlir::cir::YieldOpKind::Break:
10981150
rewriteYieldOp(rewriter, yieldOp, exitBlock);
10991151
break;
1152+
case mlir::cir::YieldOpKind::Continue: // Continue is handled only in
1153+
// loop lowering
1154+
break;
11001155
default:
11011156
return op->emitError("invalid yield kind in case statement");
11021157
}
@@ -1676,8 +1731,7 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
16761731
CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering,
16771732
CIRVAArgLowering, CIRBrOpLowering, CIRTernaryOpLowering,
16781733
CIRStructElementAddrOpLowering, CIRSwitchOpLowering,
1679-
CIRPtrDiffOpLowering>(
1680-
converter, patterns.getContext());
1734+
CIRPtrDiffOpLowering>(converter, patterns.getContext());
16811735
}
16821736

16831737
namespace {

0 commit comments

Comments
 (0)