Skip to content

Commit 27abb6d

Browse files
sitio-coutolanza
authored andcommitted
[CIR][Lowering] Lower structured while loops
Essentially converts a `cir.loop` op of the `while` kind to a CFG. The implementation, however, was only tested with structured loops, so if breaks, continues, or returns are found in the body, it is likely to break. ghstack-source-id: 32d2624 Pull Request resolved: #145
1 parent f63cbf2 commit 27abb6d

File tree

2 files changed

+138
-3
lines changed

2 files changed

+138
-3
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "mlir/IR/BuiltinDialect.h"
3535
#include "mlir/IR/BuiltinTypes.h"
3636
#include "mlir/IR/IRMapping.h"
37+
#include "mlir/IR/Operation.h"
3738
#include "mlir/IR/Value.h"
3839
#include "mlir/Pass/Pass.h"
3940
#include "mlir/Pass/PassManager.h"
@@ -111,12 +112,86 @@ class CIRPtrStrideOpLowering
111112
class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
112113
public:
113114
using mlir::OpConversionPattern<mlir::cir::LoopOp>::OpConversionPattern;
115+
using LoopKind = mlir::cir::LoopOpKind;
116+
117+
mlir::LogicalResult
118+
fetchCondRegionYields(mlir::Region &condRegion,
119+
mlir::cir::YieldOp &yieldToBody,
120+
mlir::cir::YieldOp &yieldToCont) const {
121+
for (auto &bb : condRegion) {
122+
if (auto yieldOp = dyn_cast<mlir::cir::YieldOp>(bb.getTerminator())) {
123+
if (!yieldOp.getKind().has_value())
124+
yieldToCont = yieldOp;
125+
else if (yieldOp.getKind() == mlir::cir::YieldOpKind::Continue)
126+
yieldToBody = yieldOp;
127+
else
128+
return mlir::failure();
129+
}
130+
}
131+
132+
// Succeed only if both yields are found.
133+
if (!yieldToBody || !yieldToCont)
134+
return mlir::failure();
135+
return mlir::success();
136+
}
137+
138+
mlir::LogicalResult
139+
rewriteWhileLoop(mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
140+
mlir::ConversionPatternRewriter &rewriter) const {
141+
auto *currentBlock = rewriter.getInsertionBlock();
142+
auto *continueBlock =
143+
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
144+
145+
// Fetch required info from the condition region.
146+
auto &condRegion = loopOp.getCond();
147+
auto &condFrontBlock = condRegion.front();
148+
mlir::cir::YieldOp yieldToBody, yieldToCont;
149+
if (fetchCondRegionYields(condRegion, yieldToBody, yieldToCont).failed())
150+
return loopOp.emitError("failed to fetch yields in cond region");
151+
152+
// Fetch required info from the condition region.
153+
auto &bodyRegion = loopOp.getBody();
154+
auto &bodyFrontBlock = bodyRegion.front();
155+
auto bodyYield =
156+
dyn_cast<mlir::cir::YieldOp>(bodyRegion.back().getTerminator());
157+
assert(bodyYield && "unstructured while loops are NYI");
158+
159+
// Move loop op region contents to current CFG.
160+
rewriter.inlineRegionBefore(condRegion, continueBlock);
161+
rewriter.inlineRegionBefore(bodyRegion, continueBlock);
162+
163+
// Set loop entry point to condition block.
164+
rewriter.setInsertionPointToEnd(currentBlock);
165+
rewriter.create<mlir::cir::BrOp>(loopOp.getLoc(), &condFrontBlock);
166+
167+
// Set loop exit point to continue block.
168+
rewriter.setInsertionPoint(yieldToCont);
169+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(yieldToCont, continueBlock);
170+
171+
// Branch from condition to body.
172+
rewriter.setInsertionPoint(yieldToBody);
173+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(yieldToBody, &bodyFrontBlock);
174+
175+
// Branch from body to condition.
176+
rewriter.setInsertionPoint(bodyYield);
177+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &condFrontBlock);
178+
179+
// Remove the loop op.
180+
rewriter.eraseOp(loopOp);
181+
return mlir::success();
182+
}
114183

115184
mlir::LogicalResult
116185
matchAndRewrite(mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
117186
mlir::ConversionPatternRewriter &rewriter) const override {
118-
if (loopOp.getKind() != mlir::cir::LoopOpKind::For)
187+
switch (loopOp.getKind()) {
188+
case LoopKind::For:
189+
break;
190+
case LoopKind::While:
191+
return rewriteWhileLoop(loopOp, adaptor, rewriter);
192+
case LoopKind::DoWhile:
119193
llvm_unreachable("NYI");
194+
}
120195

121196
auto loc = loopOp.getLoc();
122197

clang/test/CIR/Lowering/for.cir renamed to clang/test/CIR/Lowering/loop.cir

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ module {
2727
}
2828
cir.return
2929
}
30-
}
3130

3231
// MLIR: module {
3332
// MLIR-NEXT: llvm.func @foo() {
@@ -61,7 +60,6 @@ module {
6160
// MLIR-NEXT: ^bb6: // pred: ^bb3
6261
// MLIR-NEXT: llvm.return
6362
// MLIR-NEXT: }
64-
// MLIR-NEXT: }
6563

6664
// LLVM: define void @foo() {
6765
// LLVM-NEXT: %1 = alloca i32, i64 1, align 4
@@ -95,3 +93,65 @@ module {
9593
// LLVM-NEXT: 15:
9694
// LLVM-NEXT: ret void
9795
// LLVM-NEXT: }
96+
97+
// Test while cir.loop operation lowering.
98+
cir.func @testWhile(%arg0: !s32i) {
99+
%0 = cir.alloca !s32i, cir.ptr <!s32i>, ["i", init] {alignment = 4 : i64}
100+
cir.store %arg0, %0 : !s32i, cir.ptr <!s32i>
101+
cir.scope {
102+
cir.loop while(cond : {
103+
%1 = cir.load %0 : cir.ptr <!s32i>, !s32i
104+
%2 = cir.const(#cir.int<10> : !s32i) : !s32i
105+
%3 = cir.cmp(lt, %1, %2) : !s32i, !s32i
106+
%4 = cir.cast(int_to_bool, %3 : !s32i), !cir.bool
107+
cir.brcond %4 ^bb1, ^bb2
108+
^bb1: // pred: ^bb0
109+
cir.yield continue
110+
^bb2: // pred: ^bb0
111+
cir.yield
112+
}, step : {
113+
cir.yield
114+
}) {
115+
%1 = cir.load %0 : cir.ptr <!s32i>, !s32i
116+
%2 = cir.unary(inc, %1) : !s32i, !s32i
117+
cir.store %2, %0 : !s32i, cir.ptr <!s32i>
118+
cir.yield
119+
}
120+
}
121+
cir.return
122+
}
123+
124+
// MLIR: llvm.func @testWhile(%arg0: i32) {
125+
// MLIR-NEXT: %0 = llvm.mlir.constant(1 : index) : i64
126+
// MLIR-NEXT: %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i64) -> !llvm.ptr
127+
// MLIR-NEXT: llvm.store %arg0, %1 : i32, !llvm.ptr
128+
// MLIR-NEXT: llvm.br ^bb1
129+
// MLIR-NEXT: ^bb1:
130+
// MLIR-NEXT: llvm.br ^bb2
131+
// ============= Condition block =============
132+
// MLIR-NEXT: ^bb2: // 2 preds: ^bb1, ^bb5
133+
// MLIR-NEXT: %2 = llvm.load %1 : !llvm.ptr
134+
// MLIR-NEXT: %3 = llvm.mlir.constant(10 : i32) : i32
135+
// MLIR-NEXT: %4 = llvm.icmp "slt" %2, %3 : i32
136+
// MLIR-NEXT: %5 = llvm.zext %4 : i1 to i32
137+
// MLIR-NEXT: %6 = llvm.mlir.constant(0 : i32) : i32
138+
// MLIR-NEXT: %7 = llvm.icmp "ne" %5, %6 : i32
139+
// MLIR-NEXT: %8 = llvm.zext %7 : i1 to i8
140+
// MLIR-NEXT: %9 = llvm.trunc %8 : i8 to i1
141+
// MLIR-NEXT: llvm.cond_br %9, ^bb3, ^bb4
142+
// MLIR-NEXT: ^bb3: // pred: ^bb2
143+
// MLIR-NEXT: llvm.br ^bb5
144+
// MLIR-NEXT: ^bb4: // pred: ^bb2
145+
// MLIR-NEXT: llvm.br ^bb6
146+
// ============= Body block =============
147+
// MLIR-NEXT: ^bb5: // pred: ^bb3
148+
// MLIR-NEXT: %10 = llvm.load %1 : !llvm.ptr
149+
// MLIR-NEXT: %11 = llvm.mlir.constant(1 : i32) : i32
150+
// MLIR-NEXT: %12 = llvm.add %10, %11 : i32
151+
// MLIR-NEXT: llvm.store %12, %1 : i32, !llvm.ptr
152+
// MLIR-NEXT: llvm.br ^bb2
153+
// ============= Exit block =============
154+
// MLIR-NEXT: ^bb6: // pred: ^bb4
155+
// MLIR-NEXT: llvm.br ^bb7
156+
157+
}

0 commit comments

Comments
 (0)