@@ -241,6 +241,42 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
241241 return mlir::success ();
242242 }
243243
244+ void makeYieldIf (mlir::cir::YieldOpKind kind, mlir::cir::YieldOp &op, mlir::Block *to,
245+ mlir::ConversionPatternRewriter &rewriter) const {
246+ if (op.getKind () == kind) {
247+ rewriter.setInsertionPoint (op);
248+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(op, op.getArgs (), to);
249+ }
250+ }
251+
252+ void lowerNestedBreakContinue (mlir::Region &loopBody, mlir::Block *exitBlock,
253+ mlir::Block *continueBlock,
254+ mlir::ConversionPatternRewriter &rewriter) const {
255+
256+ auto processBreak = [&](mlir::Operation *op) {
257+ if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(*op)) // don't process breaks in nested loops and switches
258+ return mlir::WalkResult::skip ();
259+
260+ if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
261+ makeYieldIf (mlir::cir::YieldOpKind::Break, yield, exitBlock, rewriter);
262+
263+ return mlir::WalkResult::advance ();
264+ };
265+
266+ auto processContinue = [&](mlir::Operation *op) {
267+ if (isa<mlir::cir::LoopOp>(*op)) // don't process continues in nested loops
268+ return mlir::WalkResult::skip ();
269+
270+ if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
271+ makeYieldIf (mlir::cir::YieldOpKind::Continue, yield, continueBlock, rewriter);
272+
273+ return mlir::WalkResult::advance ();
274+ };
275+
276+ loopBody.walk <mlir::WalkOrder::PreOrder>(processBreak);
277+ loopBody.walk <mlir::WalkOrder::PreOrder>(processContinue);
278+ }
279+
244280 mlir::LogicalResult
245281 matchAndRewrite (mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
246282 mlir::ConversionPatternRewriter &rewriter) const override {
@@ -268,6 +304,9 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
268304 auto &stepFrontBlock = stepRegion.front ();
269305 auto stepYield =
270306 dyn_cast<mlir::cir::YieldOp>(stepRegion.back ().getTerminator ());
307+ auto &stepBlock = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
308+
309+ lowerNestedBreakContinue (bodyRegion, continueBlock, &stepBlock, rewriter);
271310
272311 // Move loop op region contents to current CFG.
273312 rewriter.inlineRegionBefore (condRegion, continueBlock);
@@ -290,8 +329,7 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
290329
291330 // Branch from body to condition or to step on for-loop cases.
292331 rewriter.setInsertionPoint (bodyYield);
293- auto &bodyExit = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
294- rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(bodyYield, &bodyExit);
332+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(bodyYield, &stepBlock);
295333
296334 // Is a for loop: branch from step to condition.
297335 if (kind == LoopKind::For) {
@@ -483,6 +521,11 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
483521 }
484522};
485523
524+ static bool isLoopYield (mlir::cir::YieldOp &op) {
525+ return op.getKind () == mlir::cir::YieldOpKind::Break ||
526+ op.getKind () == mlir::cir::YieldOpKind::Continue;
527+ }
528+
486529class CIRIfLowering : public mlir ::OpConversionPattern<mlir::cir::IfOp> {
487530public:
488531 using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern;
@@ -511,7 +554,8 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
511554 rewriter.setInsertionPointToEnd (thenAfterBody);
512555 if (auto thenYieldOp =
513556 dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator ())) {
514- rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
557+ if (!isLoopYield (thenYieldOp)) // lowering of parent loop yields is deferred to loop lowering
558+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
515559 thenYieldOp, thenYieldOp.getArgs (), continueBlock);
516560 } else if (!dyn_cast<mlir::cir::ReturnOp>(thenAfterBody->getTerminator ())) {
517561 llvm_unreachable (" what are we terminating with?" );
@@ -540,7 +584,8 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
540584 rewriter.setInsertionPointToEnd (elseAfterBody);
541585 if (auto elseYieldOp =
542586 dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator ())) {
543- rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
587+ if (!isLoopYield (elseYieldOp)) // lowering of parent loop yields is deferred to loop lowering
588+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
544589 elseYieldOp, elseYieldOp.getArgs (), continueBlock);
545590 } else if (!dyn_cast<mlir::cir::ReturnOp>(
546591 elseAfterBody->getTerminator ())) {
@@ -1081,6 +1126,8 @@ class CIRSwitchOpLowering
10811126 case mlir::cir::YieldOpKind::Break:
10821127 rewriteYieldOp (rewriter, yieldOp, exitBlock);
10831128 break ;
1129+ case mlir::cir::YieldOpKind::Continue: // Contniue is handled only in loop lowering
1130+ break ;
10841131 default :
10851132 return op->emitError (" invalid yield kind in case statement" );
10861133 }
0 commit comments