@@ -94,106 +94,117 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
94
94
bool abortOnFailedAssert = true ;
95
95
};
96
96
97
- // / The cf->LLVM lowerings for branching ops require that the blocks they jump
98
- // / to first have updated types which should be handled by a pattern operating
99
- // / on the parent op.
100
- static LogicalResult verifyMatchingValues (ConversionPatternRewriter &rewriter,
101
- ValueRange operands,
102
- ValueRange blockArgs, Location loc,
103
- llvm::StringRef messagePrefix) {
104
- for (const auto &idxAndTypes :
105
- llvm::enumerate (llvm::zip (blockArgs, operands))) {
106
- int64_t i = idxAndTypes.index ();
107
- Value argValue =
108
- rewriter.getRemappedValue (std::get<0 >(idxAndTypes.value ()));
109
- Type operandType = std::get<1 >(idxAndTypes.value ()).getType ();
110
- // In the case of an invalid jump, the block argument will have been
111
- // remapped to an UnrealizedConversionCast. In the case of a valid jump,
112
- // there might still be a no-op conversion cast with both types being equal.
113
- // Consider both of these details to see if the jump would be invalid.
114
- if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
115
- argValue.getDefiningOp ())) {
116
- if (op.getOperandTypes ().front () != operandType) {
117
- return rewriter.notifyMatchFailure (loc, [&](Diagnostic &diag) {
118
- diag << messagePrefix;
119
- diag << " mismatched types from operand # " << i << " " ;
120
- diag << operandType;
121
- diag << " not compatible with destination block argument type " ;
122
- diag << op.getOperandTypes ().front ();
123
- diag << " which should be converted with the parent op." ;
124
- });
125
- }
126
- }
127
- }
128
- return success ();
97
+ // / Helper function for converting branch ops. This function converts the
98
+ // / signature of the given block. If the new block signature is different from
99
+ // / `expectedTypes`, returns "failure".
100
+ static FailureOr<Block *> getConvertedBlock (ConversionPatternRewriter &rewriter,
101
+ const TypeConverter *converter,
102
+ Operation *branchOp, Block *block,
103
+ TypeRange expectedTypes) {
104
+ assert (converter && " expected non-null type converter" );
105
+ assert (!block->isEntryBlock () && " entry blocks have no predecessors" );
106
+
107
+ // There is nothing to do if the types already match.
108
+ if (block->getArgumentTypes () == expectedTypes)
109
+ return block;
110
+
111
+ // Compute the new block argument types and convert the block.
112
+ std::optional<TypeConverter::SignatureConversion> conversion =
113
+ converter->convertBlockSignature (block);
114
+ if (!conversion)
115
+ return rewriter.notifyMatchFailure (branchOp,
116
+ " could not compute block signature" );
117
+ if (expectedTypes != conversion->getConvertedTypes ())
118
+ return rewriter.notifyMatchFailure (
119
+ branchOp,
120
+ " mismatch between adaptor operand types and computed block signature" );
121
+ return rewriter.applySignatureConversion (block, *conversion, converter);
129
122
}
130
123
131
- // / Ensure that all block types were updated and then create an LLVM::BrOp
124
+ // / Convert the destination block signature (if necessary) and lower the branch
125
+ // / op to llvm.br.
132
126
struct BranchOpLowering : public ConvertOpToLLVMPattern <cf::BranchOp> {
133
127
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
134
128
135
129
LogicalResult
136
130
matchAndRewrite (cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
137
131
ConversionPatternRewriter &rewriter) const override {
138
- if ( failed ( verifyMatchingValues (rewriter, adaptor. getDestOperands (),
139
- op.getSuccessor ()-> getArguments (),
140
- op. getLoc (),
141
- /* messagePrefix= */ " " ) ))
132
+ FailureOr<Block *> convertedBlock =
133
+ getConvertedBlock (rewriter, getTypeConverter (), op, op.getSuccessor (),
134
+ TypeRange (adaptor. getOperands ()));
135
+ if ( failed (convertedBlock ))
142
136
return failure ();
143
-
144
- rewriter.replaceOpWithNewOp <LLVM::BrOp>(
145
- op, adaptor.getOperands (), op->getSuccessors (), op->getAttrs ());
137
+ Operation *newOp = rewriter.replaceOpWithNewOp <LLVM::BrOp>(
138
+ op, adaptor.getOperands (), *convertedBlock);
139
+ // TODO: We should not just forward all attributes like that. But there are
140
+ // existing Flang tests that depend on this behavior.
141
+ newOp->setAttrs (op->getAttrDictionary ());
146
142
return success ();
147
143
}
148
144
};
149
145
150
- // / Ensure that all block types were updated and then create an LLVM::CondBrOp
146
+ // / Convert the destination block signatures (if necessary) and lower the
147
+ // / branch op to llvm.cond_br.
151
148
struct CondBranchOpLowering : public ConvertOpToLLVMPattern <cf::CondBranchOp> {
152
149
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
153
150
154
151
LogicalResult
155
152
matchAndRewrite (cf::CondBranchOp op,
156
153
typename cf::CondBranchOp::Adaptor adaptor,
157
154
ConversionPatternRewriter &rewriter) const override {
158
- if (failed (verifyMatchingValues (rewriter, adaptor.getFalseDestOperands (),
159
- op.getFalseDest ()->getArguments (),
160
- op.getLoc (), " in false case branch " )))
155
+ FailureOr<Block *> convertedTrueBlock =
156
+ getConvertedBlock (rewriter, getTypeConverter (), op, op.getTrueDest (),
157
+ TypeRange (adaptor.getTrueDestOperands ()));
158
+ if (failed (convertedTrueBlock))
161
159
return failure ();
162
- if (failed (verifyMatchingValues (rewriter, adaptor.getTrueDestOperands (),
163
- op.getTrueDest ()->getArguments (),
164
- op.getLoc (), " in true case branch " )))
160
+ FailureOr<Block *> convertedFalseBlock =
161
+ getConvertedBlock (rewriter, getTypeConverter (), op, op.getFalseDest (),
162
+ TypeRange (adaptor.getFalseDestOperands ()));
163
+ if (failed (convertedFalseBlock))
165
164
return failure ();
166
-
167
- rewriter.replaceOpWithNewOp <LLVM::CondBrOp>(
168
- op, adaptor.getOperands (), op->getSuccessors (), op->getAttrs ());
165
+ Operation *newOp = rewriter.replaceOpWithNewOp <LLVM::CondBrOp>(
166
+ op, adaptor.getCondition (), *convertedTrueBlock,
167
+ adaptor.getTrueDestOperands (), *convertedFalseBlock,
168
+ adaptor.getFalseDestOperands ());
169
+ // TODO: We should not just forward all attributes like that. But there are
170
+ // existing Flang tests that depend on this behavior.
171
+ newOp->setAttrs (op->getAttrDictionary ());
169
172
return success ();
170
173
}
171
174
};
172
175
173
- // / Ensure that all block types were updated and then create an LLVM::SwitchOp
176
+ // / Convert the destination block signatures (if necessary) and lower the
177
+ // / switch op to llvm.switch.
174
178
struct SwitchOpLowering : public ConvertOpToLLVMPattern <cf::SwitchOp> {
175
179
using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
176
180
177
181
LogicalResult
178
182
matchAndRewrite (cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
179
183
ConversionPatternRewriter &rewriter) const override {
180
- if (failed (verifyMatchingValues (rewriter, adaptor.getDefaultOperands (),
181
- op.getDefaultDestination ()->getArguments (),
182
- op.getLoc (), " in switch default case " )))
184
+ // Get or convert default block.
185
+ FailureOr<Block *> convertedDefaultBlock = getConvertedBlock (
186
+ rewriter, getTypeConverter (), op, op.getDefaultDestination (),
187
+ TypeRange (adaptor.getDefaultOperands ()));
188
+ if (failed (convertedDefaultBlock))
183
189
return failure ();
184
190
185
- for (const auto &i : llvm::enumerate (
186
- llvm::zip (adaptor.getCaseOperands (), op.getCaseDestinations ()))) {
187
- if (failed (verifyMatchingValues (
188
- rewriter, std::get<0 >(i.value ()),
189
- std::get<1 >(i.value ())->getArguments (), op.getLoc (),
190
- " in switch case " + std::to_string (i.index ()) + " " ))) {
191
+ // Get or convert all case blocks.
192
+ SmallVector<Block *> caseDestinations;
193
+ SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands ();
194
+ for (auto it : llvm::enumerate (op.getCaseDestinations ())) {
195
+ Block *b = it.value ();
196
+ FailureOr<Block *> convertedBlock =
197
+ getConvertedBlock (rewriter, getTypeConverter (), op, b,
198
+ TypeRange (caseOperands[it.index ()]));
199
+ if (failed (convertedBlock))
191
200
return failure ();
192
- }
201
+ caseDestinations. push_back (*convertedBlock);
193
202
}
194
203
195
204
rewriter.replaceOpWithNewOp <LLVM::SwitchOp>(
196
- op, adaptor.getOperands (), op->getSuccessors (), op->getAttrs ());
205
+ op, adaptor.getFlag (), *convertedDefaultBlock,
206
+ adaptor.getDefaultOperands (), adaptor.getCaseValuesAttr (),
207
+ caseDestinations, caseOperands);
197
208
return success ();
198
209
}
199
210
};
@@ -230,14 +241,22 @@ struct ConvertControlFlowToLLVM
230
241
231
242
// / Run the dialect converter on the module.
232
243
void runOnOperation () override {
233
- LLVMConversionTarget target (getContext ());
234
- RewritePatternSet patterns (&getContext ());
235
-
236
- LowerToLLVMOptions options (&getContext ());
244
+ MLIRContext *ctx = &getContext ();
245
+ LLVMConversionTarget target (*ctx);
246
+ // This pass lowers only CF dialect ops, but it also modifies block
247
+ // signatures inside other ops. These ops should be treated as legal. They
248
+ // are lowered by other passes.
249
+ target.markUnknownOpDynamicallyLegal ([&](Operation *op) {
250
+ return op->getDialect () !=
251
+ ctx->getLoadedDialect <cf::ControlFlowDialect>();
252
+ });
253
+
254
+ LowerToLLVMOptions options (ctx);
237
255
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout )
238
256
options.overrideIndexBitwidth (indexBitwidth);
239
257
240
- LLVMTypeConverter converter (&getContext (), options);
258
+ LLVMTypeConverter converter (ctx, options);
259
+ RewritePatternSet patterns (ctx);
241
260
mlir::cf::populateControlFlowToLLVMConversionPatterns (converter, patterns);
242
261
243
262
if (failed (applyPartialConversion (getOperation (), target,
0 commit comments