Skip to content

Commit eb6c419

Browse files
[mlir][CF] Split cf-to-llvm from func-to-llvm (#120580)
Do not run `cf-to-llvm` as part of `func-to-llvm`. This commit fixes #70982. This commit changes the way how `func.func` ops are lowered to LLVM. Previously, the signature of the entire region (i.e., entry block and all other blocks in the `func.func` op) was converted as part of the `func.func` lowering pattern. Now, only the entry block is converted. The remaining block signatures are converted together with `cf.br` and `cf.cond_br` as part of `cf-to-llvm`. All unstructured control flow is not converted as part of a single pass (`cf-to-llvm`). `func-to-llvm` no longer deals with unstructured control flow. Also add more test cases for control flow dialect ops. Note: This PR is in preparation of #120431, which adds an additional GPU-specific lowering for `cf.assert`. This was a problem because `cf.assert` used to be converted as part of `func-to-llvm`. Note for LLVM integration: If you see failures, add `-convert-cf-to-llvm` to your pass pipeline.
1 parent cf7b3f8 commit eb6c419

File tree

92 files changed

+410
-226
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

92 files changed

+410
-226
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3287,10 +3287,40 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
32873287
}
32883288
};
32893289

3290+
/// Helper function for converting select ops. This function converts the
3291+
/// signature of the given block. If the new block signature is different from
3292+
/// `expectedTypes`, returns "failure".
3293+
static llvm::FailureOr<mlir::Block *>
3294+
getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
3295+
const mlir::TypeConverter *converter,
3296+
mlir::Operation *branchOp, mlir::Block *block,
3297+
mlir::TypeRange expectedTypes) {
3298+
assert(converter && "expected non-null type converter");
3299+
assert(!block->isEntryBlock() && "entry blocks have no predecessors");
3300+
3301+
// There is nothing to do if the types already match.
3302+
if (block->getArgumentTypes() == expectedTypes)
3303+
return block;
3304+
3305+
// Compute the new block argument types and convert the block.
3306+
std::optional<mlir::TypeConverter::SignatureConversion> conversion =
3307+
converter->convertBlockSignature(block);
3308+
if (!conversion)
3309+
return rewriter.notifyMatchFailure(branchOp,
3310+
"could not compute block signature");
3311+
if (expectedTypes != conversion->getConvertedTypes())
3312+
return rewriter.notifyMatchFailure(
3313+
branchOp,
3314+
"mismatch between adaptor operand types and computed block signature");
3315+
return rewriter.applySignatureConversion(block, *conversion, converter);
3316+
}
3317+
32903318
template <typename OP>
3291-
static void selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering,
3292-
OP select, typename OP::Adaptor adaptor,
3293-
mlir::ConversionPatternRewriter &rewriter) {
3319+
static llvm::LogicalResult
3320+
selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, OP select,
3321+
typename OP::Adaptor adaptor,
3322+
mlir::ConversionPatternRewriter &rewriter,
3323+
const mlir::TypeConverter *converter) {
32943324
unsigned conds = select.getNumConditions();
32953325
auto cases = select.getCases().getValue();
32963326
mlir::Value selector = adaptor.getSelector();
@@ -3308,15 +3338,24 @@ static void selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering,
33083338
auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
33093339
const mlir::Attribute &attr = cases[t];
33103340
if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
3311-
destinations.push_back(dest);
33123341
destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
3342+
auto convertedBlock =
3343+
getConvertedBlock(rewriter, converter, select, dest,
3344+
mlir::TypeRange(destinationsOperands.back()));
3345+
if (mlir::failed(convertedBlock))
3346+
return mlir::failure();
3347+
destinations.push_back(*convertedBlock);
33133348
caseValues.push_back(intAttr.getInt());
33143349
continue;
33153350
}
33163351
assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
33173352
assert((t + 1 == conds) && "unit must be last");
3318-
defaultDestination = dest;
33193353
defaultOperands = destOps ? *destOps : mlir::ValueRange{};
3354+
auto convertedBlock = getConvertedBlock(rewriter, converter, select, dest,
3355+
mlir::TypeRange(defaultOperands));
3356+
if (mlir::failed(convertedBlock))
3357+
return mlir::failure();
3358+
defaultDestination = *convertedBlock;
33203359
}
33213360

33223361
// LLVM::SwitchOp takes a i32 type for the selector.
@@ -3332,6 +3371,7 @@ static void selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering,
33323371
/*caseDestinations=*/destinations,
33333372
/*caseOperands=*/destinationsOperands,
33343373
/*branchWeights=*/llvm::ArrayRef<std::int32_t>());
3374+
return mlir::success();
33353375
}
33363376

33373377
/// conversion of fir::SelectOp to an if-then-else ladder
@@ -3341,8 +3381,8 @@ struct SelectOpConversion : public fir::FIROpConversion<fir::SelectOp> {
33413381
llvm::LogicalResult
33423382
matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
33433383
mlir::ConversionPatternRewriter &rewriter) const override {
3344-
selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter);
3345-
return mlir::success();
3384+
return selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor,
3385+
rewriter, getTypeConverter());
33463386
}
33473387
};
33483388

@@ -3353,8 +3393,8 @@ struct SelectRankOpConversion : public fir::FIROpConversion<fir::SelectRankOp> {
33533393
llvm::LogicalResult
33543394
matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
33553395
mlir::ConversionPatternRewriter &rewriter) const override {
3356-
selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter);
3357-
return mlir::success();
3396+
return selectMatchAndRewrite<fir::SelectRankOp>(
3397+
lowerTy(), op, adaptor, rewriter, getTypeConverter());
33583398
}
33593399
};
33603400

mlir/include/mlir/Conversion/Passes.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -460,10 +460,6 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
460460
1 value is returned, packed into an LLVM IR struct type. Function calls and
461461
returns are updated accordingly. Block argument types are updated to use
462462
LLVM IR types.
463-
464-
Note that until https://github.com/llvm/llvm-project/issues/70982 is resolved,
465-
this pass includes patterns that lower `arith` and `cf` to LLVM. This is legacy
466-
code due to when they were all converted in the same pass.
467463
}];
468464
let dependentDialects = ["LLVM::LLVMDialect"];
469465
let options = [

mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

Lines changed: 86 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -94,106 +94,117 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
9494
bool abortOnFailedAssert = true;
9595
};
9696

97-
/// The cf->LLVM lowerings for branching ops require that the blocks they jump
98-
/// to first have updated types which should be handled by a pattern operating
99-
/// on the parent op.
100-
static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
101-
ValueRange operands,
102-
ValueRange blockArgs, Location loc,
103-
llvm::StringRef messagePrefix) {
104-
for (const auto &idxAndTypes :
105-
llvm::enumerate(llvm::zip(blockArgs, operands))) {
106-
int64_t i = idxAndTypes.index();
107-
Value argValue =
108-
rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
109-
Type operandType = std::get<1>(idxAndTypes.value()).getType();
110-
// In the case of an invalid jump, the block argument will have been
111-
// remapped to an UnrealizedConversionCast. In the case of a valid jump,
112-
// there might still be a no-op conversion cast with both types being equal.
113-
// Consider both of these details to see if the jump would be invalid.
114-
if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
115-
argValue.getDefiningOp())) {
116-
if (op.getOperandTypes().front() != operandType) {
117-
return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
118-
diag << messagePrefix;
119-
diag << "mismatched types from operand # " << i << " ";
120-
diag << operandType;
121-
diag << " not compatible with destination block argument type ";
122-
diag << op.getOperandTypes().front();
123-
diag << " which should be converted with the parent op.";
124-
});
125-
}
126-
}
127-
}
128-
return success();
97+
/// Helper function for converting branch ops. This function converts the
98+
/// signature of the given block. If the new block signature is different from
99+
/// `expectedTypes`, returns "failure".
100+
static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
101+
const TypeConverter *converter,
102+
Operation *branchOp, Block *block,
103+
TypeRange expectedTypes) {
104+
assert(converter && "expected non-null type converter");
105+
assert(!block->isEntryBlock() && "entry blocks have no predecessors");
106+
107+
// There is nothing to do if the types already match.
108+
if (block->getArgumentTypes() == expectedTypes)
109+
return block;
110+
111+
// Compute the new block argument types and convert the block.
112+
std::optional<TypeConverter::SignatureConversion> conversion =
113+
converter->convertBlockSignature(block);
114+
if (!conversion)
115+
return rewriter.notifyMatchFailure(branchOp,
116+
"could not compute block signature");
117+
if (expectedTypes != conversion->getConvertedTypes())
118+
return rewriter.notifyMatchFailure(
119+
branchOp,
120+
"mismatch between adaptor operand types and computed block signature");
121+
return rewriter.applySignatureConversion(block, *conversion, converter);
129122
}
130123

131-
/// Ensure that all block types were updated and then create an LLVM::BrOp
124+
/// Convert the destination block signature (if necessary) and lower the branch
125+
/// op to llvm.br.
132126
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
133127
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
134128

135129
LogicalResult
136130
matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
137131
ConversionPatternRewriter &rewriter) const override {
138-
if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
139-
op.getSuccessor()->getArguments(),
140-
op.getLoc(),
141-
/*messagePrefix=*/"")))
132+
FailureOr<Block *> convertedBlock =
133+
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
134+
TypeRange(adaptor.getOperands()));
135+
if (failed(convertedBlock))
142136
return failure();
143-
144-
rewriter.replaceOpWithNewOp<LLVM::BrOp>(
145-
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
137+
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
138+
op, adaptor.getOperands(), *convertedBlock);
139+
// TODO: We should not just forward all attributes like that. But there are
140+
// existing Flang tests that depend on this behavior.
141+
newOp->setAttrs(op->getAttrDictionary());
146142
return success();
147143
}
148144
};
149145

150-
/// Ensure that all block types were updated and then create an LLVM::CondBrOp
146+
/// Convert the destination block signatures (if necessary) and lower the
147+
/// branch op to llvm.cond_br.
151148
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
152149
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
153150

154151
LogicalResult
155152
matchAndRewrite(cf::CondBranchOp op,
156153
typename cf::CondBranchOp::Adaptor adaptor,
157154
ConversionPatternRewriter &rewriter) const override {
158-
if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
159-
op.getFalseDest()->getArguments(),
160-
op.getLoc(), "in false case branch ")))
155+
FailureOr<Block *> convertedTrueBlock =
156+
getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
157+
TypeRange(adaptor.getTrueDestOperands()));
158+
if (failed(convertedTrueBlock))
161159
return failure();
162-
if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
163-
op.getTrueDest()->getArguments(),
164-
op.getLoc(), "in true case branch ")))
160+
FailureOr<Block *> convertedFalseBlock =
161+
getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
162+
TypeRange(adaptor.getFalseDestOperands()));
163+
if (failed(convertedFalseBlock))
165164
return failure();
166-
167-
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
168-
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
165+
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
166+
op, adaptor.getCondition(), *convertedTrueBlock,
167+
adaptor.getTrueDestOperands(), *convertedFalseBlock,
168+
adaptor.getFalseDestOperands());
169+
// TODO: We should not just forward all attributes like that. But there are
170+
// existing Flang tests that depend on this behavior.
171+
newOp->setAttrs(op->getAttrDictionary());
169172
return success();
170173
}
171174
};
172175

173-
/// Ensure that all block types were updated and then create an LLVM::SwitchOp
176+
/// Convert the destination block signatures (if necessary) and lower the
177+
/// switch op to llvm.switch.
174178
struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
175179
using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
176180

177181
LogicalResult
178182
matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
179183
ConversionPatternRewriter &rewriter) const override {
180-
if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
181-
op.getDefaultDestination()->getArguments(),
182-
op.getLoc(), "in switch default case ")))
184+
// Get or convert default block.
185+
FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
186+
rewriter, getTypeConverter(), op, op.getDefaultDestination(),
187+
TypeRange(adaptor.getDefaultOperands()));
188+
if (failed(convertedDefaultBlock))
183189
return failure();
184190

185-
for (const auto &i : llvm::enumerate(
186-
llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
187-
if (failed(verifyMatchingValues(
188-
rewriter, std::get<0>(i.value()),
189-
std::get<1>(i.value())->getArguments(), op.getLoc(),
190-
"in switch case " + std::to_string(i.index()) + " "))) {
191+
// Get or convert all case blocks.
192+
SmallVector<Block *> caseDestinations;
193+
SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
194+
for (auto it : llvm::enumerate(op.getCaseDestinations())) {
195+
Block *b = it.value();
196+
FailureOr<Block *> convertedBlock =
197+
getConvertedBlock(rewriter, getTypeConverter(), op, b,
198+
TypeRange(caseOperands[it.index()]));
199+
if (failed(convertedBlock))
191200
return failure();
192-
}
201+
caseDestinations.push_back(*convertedBlock);
193202
}
194203

195204
rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
196-
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
205+
op, adaptor.getFlag(), *convertedDefaultBlock,
206+
adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
207+
caseDestinations, caseOperands);
197208
return success();
198209
}
199210
};
@@ -230,14 +241,22 @@ struct ConvertControlFlowToLLVM
230241

231242
/// Run the dialect converter on the module.
232243
void runOnOperation() override {
233-
LLVMConversionTarget target(getContext());
234-
RewritePatternSet patterns(&getContext());
235-
236-
LowerToLLVMOptions options(&getContext());
244+
MLIRContext *ctx = &getContext();
245+
LLVMConversionTarget target(*ctx);
246+
// This pass lowers only CF dialect ops, but it also modifies block
247+
// signatures inside other ops. These ops should be treated as legal. They
248+
// are lowered by other passes.
249+
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
250+
return op->getDialect() !=
251+
ctx->getLoadedDialect<cf::ControlFlowDialect>();
252+
});
253+
254+
LowerToLLVMOptions options(ctx);
237255
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
238256
options.overrideIndexBitwidth(indexBitwidth);
239257

240-
LLVMTypeConverter converter(&getContext(), options);
258+
LLVMTypeConverter converter(ctx, options);
259+
RewritePatternSet patterns(ctx);
241260
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
242261

243262
if (failed(applyPartialConversion(getOperation(), target,

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -432,11 +432,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
432432

433433
rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
434434
newFuncOp.end());
435-
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), converter,
436-
&result))) {
437-
return rewriter.notifyMatchFailure(funcOp,
438-
"region types conversion failed");
439-
}
435+
// Convert just the entry block. The remaining unstructured control flow is
436+
// converted by ControlFlowToLLVM.
437+
if (!newFuncOp.getBody().empty())
438+
rewriter.applySignatureConversion(&newFuncOp.getBody().front(), result,
439+
&converter);
440440

441441
// Fix the type mismatch between the materialized `llvm.ptr` and the expected
442442
// pointee type in the function body when converting `llvm.byval`/`llvm.byref`
@@ -785,10 +785,6 @@ struct ConvertFuncToLLVMPass
785785
RewritePatternSet patterns(&getContext());
786786
populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
787787

788-
// TODO(https://github.com/llvm/llvm-project/issues/70982): Remove these in
789-
// favor of their dedicated conversion passes.
790-
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
791-
792788
LLVMConversionTarget target(getContext());
793789
if (failed(applyPartialConversion(m, target, std::move(patterns))))
794790
signalPassFailure();

mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
1010

1111
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
12+
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
1213
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
1314
#include "mlir/Conversion/Passes.h"
1415
#include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -91,6 +92,7 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
9192
createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
9293
pm.addPass(createConvertFuncToLLVMPass());
9394
pm.addPass(createArithToLLVMConversionPass());
95+
pm.addPass(createConvertControlFlowToLLVMPass());
9496

9597
// Finalize GPU code generation.
9698
if (gpuCodegen) {

0 commit comments

Comments
 (0)