@@ -240,6 +240,47 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
240240 return mlir::success ();
241241 }
242242
243+ void makeYieldIf (mlir::cir::YieldOpKind kind, mlir::cir::YieldOp &op,
244+ mlir::Block *to,
245+ mlir::ConversionPatternRewriter &rewriter) const {
246+ if (op.getKind () == kind) {
247+ rewriter.setInsertionPoint (op);
248+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(op, op.getArgs (), to);
249+ }
250+ }
251+
252+ void
253+ lowerNestedBreakContinue (mlir::Region &loopBody, mlir::Block *exitBlock,
254+ mlir::Block *continueBlock,
255+ mlir::ConversionPatternRewriter &rewriter) const {
256+
257+ auto processBreak = [&](mlir::Operation *op) {
258+ if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(
259+ *op)) // don't process breaks in nested loops and switches
260+ return mlir::WalkResult::skip ();
261+
262+ if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
263+ makeYieldIf (mlir::cir::YieldOpKind::Break, yield, exitBlock, rewriter);
264+
265+ return mlir::WalkResult::advance ();
266+ };
267+
268+ auto processContinue = [&](mlir::Operation *op) {
269+ if (isa<mlir::cir::LoopOp>(
270+ *op)) // don't process continues in nested loops
271+ return mlir::WalkResult::skip ();
272+
273+ if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
274+ makeYieldIf (mlir::cir::YieldOpKind::Continue, yield, continueBlock,
275+ rewriter);
276+
277+ return mlir::WalkResult::advance ();
278+ };
279+
280+ loopBody.walk <mlir::WalkOrder::PreOrder>(processBreak);
281+ loopBody.walk <mlir::WalkOrder::PreOrder>(processContinue);
282+ }
283+
243284 mlir::LogicalResult
244285 matchAndRewrite (mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
245286 mlir::ConversionPatternRewriter &rewriter) const override {
@@ -267,6 +308,9 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
267308 auto &stepFrontBlock = stepRegion.front ();
268309 auto stepYield =
269310 dyn_cast<mlir::cir::YieldOp>(stepRegion.back ().getTerminator ());
311+ auto &stepBlock = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
312+
313+ lowerNestedBreakContinue (bodyRegion, continueBlock, &stepBlock, rewriter);
270314
271315 // Move loop op region contents to current CFG.
272316 rewriter.inlineRegionBefore (condRegion, continueBlock);
@@ -289,8 +333,7 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
289333
290334 // Branch from body to condition or to step on for-loop cases.
291335 rewriter.setInsertionPoint (bodyYield);
292- auto &bodyExit = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
293- rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(bodyYield, &bodyExit);
336+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(bodyYield, &stepBlock);
294337
295338 // Is a for loop: branch from step to condition.
296339 if (kind == LoopKind::For) {
@@ -488,6 +531,11 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
488531 }
489532};
490533
534+ static bool isLoopYield (mlir::cir::YieldOp &op) {
535+ return op.getKind () == mlir::cir::YieldOpKind::Break ||
536+ op.getKind () == mlir::cir::YieldOpKind::Continue;
537+ }
538+
491539class CIRIfLowering : public mlir ::OpConversionPattern<mlir::cir::IfOp> {
492540public:
493541 using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern;
@@ -516,8 +564,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
516564 rewriter.setInsertionPointToEnd (thenAfterBody);
517565 if (auto thenYieldOp =
518566 dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator ())) {
519- rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
520- thenYieldOp, thenYieldOp.getArgs (), continueBlock);
567+ if (!isLoopYield (thenYieldOp)) // lowering of parent loop yields is
568+ // deferred to loop lowering
569+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
570+ thenYieldOp, thenYieldOp.getArgs (), continueBlock);
521571 } else if (!dyn_cast<mlir::cir::ReturnOp>(thenAfterBody->getTerminator ())) {
522572 llvm_unreachable (" what are we terminating with?" );
523573 }
@@ -545,8 +595,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
545595 rewriter.setInsertionPointToEnd (elseAfterBody);
546596 if (auto elseYieldOp =
547597 dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator ())) {
548- rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
549- elseYieldOp, elseYieldOp.getArgs (), continueBlock);
598+ if (!isLoopYield (elseYieldOp)) // lowering of parent loop yields is
599+ // deferred to loop lowering
600+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
601+ elseYieldOp, elseYieldOp.getArgs (), continueBlock);
550602 } else if (!dyn_cast<mlir::cir::ReturnOp>(
551603 elseAfterBody->getTerminator ())) {
552604 llvm_unreachable (" what are we terminating with?" );
@@ -1109,6 +1161,9 @@ class CIRSwitchOpLowering
11091161 case mlir::cir::YieldOpKind::Break:
11101162 rewriteYieldOp (rewriter, yieldOp, exitBlock);
11111163 break ;
1164+ case mlir::cir::YieldOpKind::Continue: // Continue is handled only in
1165+ // loop lowering
1166+ break ;
11121167 default :
11131168 return op->emitError (" invalid yield kind in case statement" );
11141169 }
@@ -1692,8 +1747,7 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
16921747 CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering,
16931748 CIRVAArgLowering, CIRBrOpLowering, CIRTernaryOpLowering,
16941749 CIRStructElementAddrOpLowering, CIRSwitchOpLowering,
1695- CIRPtrDiffOpLowering>(
1696- converter, patterns.getContext ());
1750+ CIRPtrDiffOpLowering>(converter, patterns.getContext ());
16971751}
16981752
16991753namespace {
0 commit comments