@@ -59,14 +59,16 @@ static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
59
59
/* resultTypes=*/ {});
60
60
}
61
61
62
+ static mlir::Type getVoidPtrType (mlir::MLIRContext *context) {
63
+ return fir::ReferenceType::get (mlir::NoneType::get (context));
64
+ }
65
+
62
66
// / This is for function result types that are of type C_PTR from ISO_C_BINDING.
63
67
// / Follow the ABI for interoperability with C.
64
68
static mlir::FunctionType getCPtrFunctionType (mlir::FunctionType funcTy) {
65
- auto resultType = funcTy.getResult (0 );
66
- assert (fir::isa_builtin_cptr_type (resultType));
67
- llvm::SmallVector<mlir::Type> outputTypes;
68
- auto recTy = mlir::dyn_cast<fir::RecordType>(resultType);
69
- outputTypes.emplace_back (recTy.getTypeList ()[0 ].second );
69
+ assert (fir::isa_builtin_cptr_type (funcTy.getResult (0 )));
70
+ llvm::SmallVector<mlir::Type> outputTypes{
71
+ getVoidPtrType (funcTy.getContext ())};
70
72
return mlir::FunctionType::get (funcTy.getContext (), funcTy.getInputs (),
71
73
outputTypes);
72
74
}
@@ -109,15 +111,11 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
109
111
saveResult.getTypeparams ());
110
112
111
113
llvm::SmallVector<mlir::Type> newResultTypes;
112
- // TODO: This should be generalized for derived types, and it is
113
- // architecture and OS dependent.
114
114
bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type (result.getType ());
115
- Op newOp;
116
- if (isResultBuiltinCPtr) {
117
- auto recTy = mlir::dyn_cast<fir::RecordType>(result.getType ());
118
- newResultTypes.emplace_back (recTy.getTypeList ()[0 ].second );
119
- }
115
+ if (isResultBuiltinCPtr)
116
+ newResultTypes.emplace_back (getVoidPtrType (result.getContext ()));
120
117
118
+ Op newOp;
121
119
// fir::CallOp specific handling.
122
120
if constexpr (std::is_same_v<Op, fir::CallOp>) {
123
121
if (op.getCallee ()) {
@@ -175,7 +173,7 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
175
173
FirOpBuilder builder (rewriter, module);
176
174
mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr (
177
175
builder, loc, save, result.getType ());
178
- rewriter. create <fir::StoreOp> (loc, newOp->getResult (0 ), saveAddr);
176
+ builder. createStoreWithConvert (loc, newOp->getResult (0 ), saveAddr);
179
177
}
180
178
op->dropAllReferences ();
181
179
rewriter.eraseOp (op);
@@ -210,42 +208,52 @@ class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
210
208
mlir::PatternRewriter &rewriter) const override {
211
209
auto loc = ret.getLoc ();
212
210
rewriter.setInsertionPoint (ret);
213
- auto returnedValue = ret.getOperand (0 );
214
- bool replacedStorage = false ;
215
- if (auto *op = returnedValue.getDefiningOp ())
216
- if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
217
- auto resultStorage = load.getMemref ();
218
- // The result alloca may be behind a fir.declare, if any.
219
- if (auto declare = mlir::dyn_cast_or_null<fir::DeclareOp>(
220
- resultStorage.getDefiningOp ()))
221
- resultStorage = declare.getMemref ();
222
- // TODO: This should be generalized for derived types, and it is
223
- // architecture and OS dependent.
224
- if (fir::isa_builtin_cptr_type (returnedValue.getType ())) {
225
- rewriter.eraseOp (load);
226
- auto module = ret->getParentOfType <mlir::ModuleOp>();
227
- FirOpBuilder builder (rewriter, module);
228
- mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr (
229
- builder, loc, resultStorage, returnedValue.getType ());
230
- mlir::Value retValue = rewriter.create <fir::LoadOp>(
231
- loc, fir::unwrapRefType (retAddr.getType ()), retAddr);
232
- rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(
233
- ret, mlir::ValueRange{retValue});
234
- return mlir::success ();
235
- }
236
- resultStorage.replaceAllUsesWith (newArg);
237
- replacedStorage = true ;
238
- if (auto *alloc = resultStorage.getDefiningOp ())
239
- if (alloc->use_empty ())
240
- rewriter.eraseOp (alloc);
211
+ mlir::Value resultValue = ret.getOperand (0 );
212
+ fir::LoadOp resultLoad;
213
+ mlir::Value resultStorage;
214
+ // Identify result local storage.
215
+ if (auto load = resultValue.getDefiningOp <fir::LoadOp>()) {
216
+ resultLoad = load;
217
+ resultStorage = load.getMemref ();
218
+ // The result alloca may be behind a fir.declare, if any.
219
+ if (auto declare = resultStorage.getDefiningOp <fir::DeclareOp>())
220
+ resultStorage = declare.getMemref ();
221
+ }
222
+ // Replace old local storage with new storage argument, unless
223
+ // the derived type is C_PTR/C_FUN_PTR, in which case the return
224
+ // type is updated to return void* (no new argument is passed).
225
+ if (fir::isa_builtin_cptr_type (resultValue.getType ())) {
226
+ auto module = ret->getParentOfType <mlir::ModuleOp>();
227
+ FirOpBuilder builder (rewriter, module);
228
+ mlir::Value cptr = resultValue;
229
+ if (resultLoad) {
230
+ // Replace whole derived type load by component load.
231
+ cptr = resultLoad.getMemref ();
232
+ rewriter.setInsertionPoint (resultLoad);
241
233
}
242
- // The result storage may have been optimized out by a memory to
243
- // register pass, this is possible for fir.box results, or fir.record
244
- // with no length parameters. Simply store the result in the result storage.
245
- // at the return point.
246
- if (!replacedStorage)
247
- rewriter.create <fir::StoreOp>(loc, returnedValue, newArg);
248
- rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(ret);
234
+ mlir::Value newResultValue =
235
+ fir::factory::genCPtrOrCFunptrValue (builder, loc, cptr);
236
+ newResultValue = builder.createConvert (
237
+ loc, getVoidPtrType (ret.getContext ()), newResultValue);
238
+ rewriter.setInsertionPoint (ret);
239
+ rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(
240
+ ret, mlir::ValueRange{newResultValue});
241
+ } else if (resultStorage) {
242
+ resultStorage.replaceAllUsesWith (newArg);
243
+ rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(ret);
244
+ } else {
245
+ // The result storage may have been optimized out by a memory to
246
+ // register pass, this is possible for fir.box results, or fir.record
247
+ // with no length parameters. Simply store the result in the result
248
+ // storage. at the return point.
249
+ rewriter.create <fir::StoreOp>(loc, resultValue, newArg);
250
+ rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(ret);
251
+ }
252
+ // Delete result old local storage if unused.
253
+ if (resultStorage)
254
+ if (auto alloc = resultStorage.getDefiningOp <fir::AllocaOp>())
255
+ if (alloc->use_empty ())
256
+ rewriter.eraseOp (alloc);
249
257
return mlir::success ();
250
258
}
251
259
@@ -263,8 +271,6 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
263
271
mlir::PatternRewriter &rewriter) const override {
264
272
auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType ());
265
273
mlir::FunctionType newFuncTy;
266
- // TODO: This should be generalized for derived types, and it is
267
- // architecture and OS dependent.
268
274
if (oldFuncTy.getNumResults () != 0 &&
269
275
fir::isa_builtin_cptr_type (oldFuncTy.getResult (0 )))
270
276
newFuncTy = getCPtrFunctionType (oldFuncTy);
@@ -298,8 +304,6 @@ class AbstractResultOpt
298
304
// Convert function type itself if it has an abstract result.
299
305
auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType ());
300
306
if (hasAbstractResult (funcTy)) {
301
- // TODO: This should be generalized for derived types, and it is
302
- // architecture and OS dependent.
303
307
if (fir::isa_builtin_cptr_type (funcTy.getResult (0 ))) {
304
308
func.setType (getCPtrFunctionType (funcTy));
305
309
patterns.insert <ReturnOpConversion>(context, mlir::Value{});
0 commit comments