Skip to content

Commit d6c7fd3

Browse files
gitoleglanza
authored andcommitted
[CIR][Lowering] Fix loop lowering for top-level break/continue (llvm#349)
This PR fixes a couple of corner cases connected with the `YieldOp` lowering in loops. Previously, in llvm#211 we introduced `lowerNestedBreakContinue` but we didn't check that `YieldOp` may belong to the same region, i.e. it is not nested, e.g. ``` while(1) { break; } ``` Hence the error `op already replaced`. Next, we fix `yield` lowering for `ifOp` and `switchOp` but didn't cover `scopeOp`, and the same error occurred. This PR fixes this as well. I added two tests - with no checks actually, just to make sure no more crashes happen. fixes llvm#324
1 parent 27230db commit d6c7fd3

File tree

2 files changed

+100
-5
lines changed

2 files changed

+100
-5
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,15 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
408408
lowerNestedBreakContinue(mlir::Region &loopBody, mlir::Block *exitBlock,
409409
mlir::Block *continueBlock,
410410
mlir::ConversionPatternRewriter &rewriter) const {
411+
// top-level yields are lowered in matchAndRewrite
412+
auto isNested = [&](mlir::Operation *op) {
413+
return op->getParentRegion() != &loopBody;
414+
};
411415

412416
auto processBreak = [&](mlir::Operation *op) {
417+
if (!isNested(op))
418+
return mlir::WalkResult::advance();
419+
413420
if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(
414421
*op)) // don't process breaks in nested loops and switches
415422
return mlir::WalkResult::skip();
@@ -421,6 +428,9 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
421428
};
422429

423430
auto processContinue = [&](mlir::Operation *op) {
431+
if (!isNested(op))
432+
return mlir::WalkResult::advance();
433+
424434
if (isa<mlir::cir::LoopOp>(
425435
*op)) // don't process continues in nested loops
426436
return mlir::WalkResult::skip();
@@ -490,7 +500,10 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
490500

491501
// Branch from body to condition or to step on for-loop cases.
492502
rewriter.setInsertionPoint(bodyYield);
493-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &stepBlock);
503+
auto bodyYieldDest = bodyYield.getKind() == mlir::cir::YieldOpKind::Break
504+
? continueBlock
505+
: &stepBlock;
506+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, bodyYieldDest);
494507

495508
// Is a for loop: branch from step to condition.
496509
if (kind == LoopKind::For) {
@@ -822,11 +835,15 @@ class CIRScopeOpLowering
822835
// Stack restore before leaving the body region.
823836
rewriter.setInsertionPointToEnd(afterBody);
824837
auto yieldOp = cast<mlir::cir::YieldOp>(afterBody->getTerminator());
825-
auto branchOp = rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
826-
yieldOp, yieldOp.getArgs(), continueBlock);
827838

828-
// // Insert stack restore before jumping out of the body of the region.
829-
rewriter.setInsertionPoint(branchOp);
839+
if (!isLoopYield(yieldOp)) {
840+
auto branchOp = rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
841+
yieldOp, yieldOp.getArgs(), continueBlock);
842+
843+
// // Insert stack restore before jumping out of the body of the region.
844+
rewriter.setInsertionPoint(branchOp);
845+
}
846+
830847
// TODO(CIR): stackrestore?
831848
// rewriter.create<mlir::LLVM::StackRestoreOp>(loc, stackSaveOp);
832849

clang/test/CIR/Lowering/loop.cir

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,4 +217,82 @@ module {
217217
// MLIR-NEXT: llvm.br ^bb6
218218
// MLIR-NEXT: ^bb6:
219219
// MLIR-NEXT: llvm.return
220+
221+
// test corner case
222+
// while (1) {
223+
// break;
224+
// }
225+
cir.func @whileCornerCase() {
226+
cir.scope {
227+
cir.loop while(cond : {
228+
%0 = cir.const(#cir.int<1> : !s32i) : !s32i
229+
%1 = cir.cast(int_to_bool, %0 : !s32i), !cir.bool
230+
cir.brcond %1 ^bb1, ^bb2
231+
^bb1: // pred: ^bb0
232+
cir.yield continue
233+
^bb2: // pred: ^bb0
234+
cir.yield
235+
}, step : {
236+
cir.yield
237+
}) {
238+
cir.yield break
239+
}
240+
}
241+
cir.return
242+
}
243+
// MLIR: llvm.func @whileCornerCase()
244+
// MLIR: %0 = llvm.mlir.constant(1 : i32) : i32
245+
// MLIR-NEXT: %1 = llvm.mlir.constant(0 : i32) : i32
246+
// MLIR-NEXT: %2 = llvm.icmp "ne" %0, %1 : i32
247+
// MLIR-NEXT: %3 = llvm.zext %2 : i1 to i8
248+
// MLIR-NEXT: %4 = llvm.trunc %3 : i8 to i
249+
// MLIR-NEXT: llvm.cond_br %4, ^bb3, ^bb4
250+
// MLIR-NEXT: ^bb3: // pred: ^bb2
251+
// MLIR-NEXT: llvm.br ^bb5
252+
// MLIR-NEXT: ^bb4: // pred: ^bb2
253+
// MLIR-NEXT: llvm.br ^bb6
254+
// MLIR-NEXT: ^bb5: // pred: ^bb3
255+
// MLIR-NEXT: llvm.br ^bb6
256+
// MLIR-NEXT: ^bb6: // 2 preds: ^bb4, ^bb5
257+
// MLIR-NEXT: llvm.br ^bb7
258+
// MLIR-NEXT: ^bb7: // pred: ^bb6
259+
// MLIR-NEXT: llvm.return
260+
261+
// test corner case - no fails during the lowering
262+
// for (;;) {
263+
// break;
264+
// }
265+
cir.func @forCornerCase() {
266+
cir.scope {
267+
cir.loop for(cond : {
268+
cir.yield continue
269+
}, step : {
270+
cir.yield
271+
}) {
272+
cir.scope {
273+
cir.yield break
274+
}
275+
cir.yield
276+
}
277+
}
278+
cir.return
279+
}
280+
// MLIR: llvm.func @forCornerCase()
281+
// MLIR: llvm.br ^bb1
282+
// MLIR-NEXT: ^bb1: // pred: ^bb0
283+
// MLIR-NEXT: llvm.br ^bb2
284+
// MLIR-NEXT: ^bb2: // 2 preds: ^bb1, ^bb6
285+
// MLIR-NEXT: llvm.br ^bb3
286+
// MLIR-NEXT: ^bb3: // pred: ^bb2
287+
// MLIR-NEXT: llvm.br ^bb4
288+
// MLIR-NEXT: ^bb4: // pred: ^bb3
289+
// MLIR-NEXT: llvm.br ^bb7
290+
// MLIR-NEXT: ^bb5: // no predecessors
291+
// MLIR-NEXT: llvm.br ^bb6
292+
// MLIR-NEXT: ^bb6: // pred: ^bb5
293+
// MLIR-NEXT: llvm.br ^bb2
294+
// MLIR-NEXT: ^bb7: // pred: ^bb4
295+
// MLIR-NEXT: llvm.br ^bb8
296+
// MLIR-NEXT: ^bb8: // pred: ^bb7
297+
// MLIR-NEXT: llvm.return
220298
}

0 commit comments

Comments
 (0)