@@ -238,6 +238,47 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
238238 return mlir::success ();
239239 }
240240
241+ void makeYieldIf (mlir::cir::YieldOpKind kind, mlir::cir::YieldOp &op,
242+ mlir::Block *to,
243+ mlir::ConversionPatternRewriter &rewriter) const {
244+ if (op.getKind () == kind) {
245+ rewriter.setInsertionPoint (op);
246+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(op, op.getArgs (), to);
247+ }
248+ }
249+
250+ void
251+ lowerNestedBreakContinue (mlir::Region &loopBody, mlir::Block *exitBlock,
252+ mlir::Block *continueBlock,
253+ mlir::ConversionPatternRewriter &rewriter) const {
254+
255+ auto processBreak = [&](mlir::Operation *op) {
256+ if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(
257+ *op)) // don't process breaks in nested loops and switches
258+ return mlir::WalkResult::skip ();
259+
260+ if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
261+ makeYieldIf (mlir::cir::YieldOpKind::Break, yield, exitBlock, rewriter);
262+
263+ return mlir::WalkResult::advance ();
264+ };
265+
266+ auto processContinue = [&](mlir::Operation *op) {
267+ if (isa<mlir::cir::LoopOp>(
268+ *op)) // don't process continues in nested loops
269+ return mlir::WalkResult::skip ();
270+
271+ if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
272+ makeYieldIf (mlir::cir::YieldOpKind::Continue, yield, continueBlock,
273+ rewriter);
274+
275+ return mlir::WalkResult::advance ();
276+ };
277+
278+ loopBody.walk <mlir::WalkOrder::PreOrder>(processBreak);
279+ loopBody.walk <mlir::WalkOrder::PreOrder>(processContinue);
280+ }
281+
241282 mlir::LogicalResult
242283 matchAndRewrite (mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
243284 mlir::ConversionPatternRewriter &rewriter) const override {
@@ -265,6 +306,9 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
265306 auto &stepFrontBlock = stepRegion.front ();
266307 auto stepYield =
267308 dyn_cast<mlir::cir::YieldOp>(stepRegion.back ().getTerminator ());
309+ auto &stepBlock = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
310+
311+ lowerNestedBreakContinue (bodyRegion, continueBlock, &stepBlock, rewriter);
268312
269313 // Move loop op region contents to current CFG.
270314 rewriter.inlineRegionBefore (condRegion, continueBlock);
@@ -287,8 +331,7 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
287331
288332 // Branch from body to condition or to step on for-loop cases.
289333 rewriter.setInsertionPoint (bodyYield);
290- auto &bodyExit = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
291- rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(bodyYield, &bodyExit);
334+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(bodyYield, &stepBlock);
292335
293336 // Is a for loop: branch from step to condition.
294337 if (kind == LoopKind::For) {
@@ -480,6 +523,11 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
480523 }
481524};
482525
526+ static bool isLoopYield (mlir::cir::YieldOp &op) {
527+ return op.getKind () == mlir::cir::YieldOpKind::Break ||
528+ op.getKind () == mlir::cir::YieldOpKind::Continue;
529+ }
530+
483531class CIRIfLowering : public mlir ::OpConversionPattern<mlir::cir::IfOp> {
484532public:
485533 using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern;
@@ -508,8 +556,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
508556 rewriter.setInsertionPointToEnd (thenAfterBody);
509557 if (auto thenYieldOp =
510558 dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator ())) {
511- rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
512- thenYieldOp, thenYieldOp.getArgs (), continueBlock);
559+ if (!isLoopYield (thenYieldOp)) // lowering of parent loop yields is
560+ // deferred to loop lowering
561+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
562+ thenYieldOp, thenYieldOp.getArgs (), continueBlock);
513563 } else if (!dyn_cast<mlir::cir::ReturnOp>(thenAfterBody->getTerminator ())) {
514564 llvm_unreachable (" what are we terminating with?" );
515565 }
@@ -537,8 +587,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
537587 rewriter.setInsertionPointToEnd (elseAfterBody);
538588 if (auto elseYieldOp =
539589 dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator ())) {
540- rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
541- elseYieldOp, elseYieldOp.getArgs (), continueBlock);
590+ if (!isLoopYield (elseYieldOp)) // lowering of parent loop yields is
591+ // deferred to loop lowering
592+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
593+ elseYieldOp, elseYieldOp.getArgs (), continueBlock);
542594 } else if (!dyn_cast<mlir::cir::ReturnOp>(
543595 elseAfterBody->getTerminator ())) {
544596 llvm_unreachable (" what are we terminating with?" );
@@ -1097,6 +1149,9 @@ class CIRSwitchOpLowering
10971149 case mlir::cir::YieldOpKind::Break:
10981150 rewriteYieldOp (rewriter, yieldOp, exitBlock);
10991151 break ;
1152+ case mlir::cir::YieldOpKind::Continue: // Continue is handled only in
1153+ // loop lowering
1154+ break ;
11001155 default :
11011156 return op->emitError (" invalid yield kind in case statement" );
11021157 }
@@ -1676,8 +1731,7 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
16761731 CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering,
16771732 CIRVAArgLowering, CIRBrOpLowering, CIRTernaryOpLowering,
16781733 CIRStructElementAddrOpLowering, CIRSwitchOpLowering,
1679- CIRPtrDiffOpLowering>(
1680- converter, patterns.getContext ());
1734+ CIRPtrDiffOpLowering>(converter, patterns.getContext ());
16811735}
16821736
16831737namespace {
0 commit comments