@@ -59,14 +59,16 @@ static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
5959 /* resultTypes=*/ {});
6060}
6161
62+ static mlir::Type getVoidPtrType (mlir::MLIRContext *context) {
63+ return fir::ReferenceType::get (mlir::NoneType::get (context));
64+ }
65+
6266// / This is for function result types that are of type C_PTR from ISO_C_BINDING.
6367// / Follow the ABI for interoperability with C.
6468static mlir::FunctionType getCPtrFunctionType (mlir::FunctionType funcTy) {
65- auto resultType = funcTy.getResult (0 );
66- assert (fir::isa_builtin_cptr_type (resultType));
67- llvm::SmallVector<mlir::Type> outputTypes;
68- auto recTy = mlir::dyn_cast<fir::RecordType>(resultType);
69- outputTypes.emplace_back (recTy.getTypeList ()[0 ].second );
69+ assert (fir::isa_builtin_cptr_type (funcTy.getResult (0 )));
70+ llvm::SmallVector<mlir::Type> outputTypes{
71+ getVoidPtrType (funcTy.getContext ())};
7072 return mlir::FunctionType::get (funcTy.getContext (), funcTy.getInputs (),
7173 outputTypes);
7274}
@@ -109,15 +111,11 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
109111 saveResult.getTypeparams ());
110112
111113 llvm::SmallVector<mlir::Type> newResultTypes;
112- // TODO: This should be generalized for derived types, and it is
113- // architecture and OS dependent.
114114 bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type (result.getType ());
115- Op newOp;
116- if (isResultBuiltinCPtr) {
117- auto recTy = mlir::dyn_cast<fir::RecordType>(result.getType ());
118- newResultTypes.emplace_back (recTy.getTypeList ()[0 ].second );
119- }
115+ if (isResultBuiltinCPtr)
116+ newResultTypes.emplace_back (getVoidPtrType (result.getContext ()));
120117
118+ Op newOp;
121119 // fir::CallOp specific handling.
122120 if constexpr (std::is_same_v<Op, fir::CallOp>) {
123121 if (op.getCallee ()) {
@@ -175,7 +173,7 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
175173 FirOpBuilder builder (rewriter, module );
176174 mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr (
177175 builder, loc, save, result.getType ());
178- rewriter. create <fir::StoreOp> (loc, newOp->getResult (0 ), saveAddr);
176+ builder. createStoreWithConvert (loc, newOp->getResult (0 ), saveAddr);
179177 }
180178 op->dropAllReferences ();
181179 rewriter.eraseOp (op);
@@ -210,42 +208,52 @@ class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
210208 mlir::PatternRewriter &rewriter) const override {
211209 auto loc = ret.getLoc ();
212210 rewriter.setInsertionPoint (ret);
213- auto returnedValue = ret.getOperand (0 );
214- bool replacedStorage = false ;
215- if (auto *op = returnedValue.getDefiningOp ())
216- if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
217- auto resultStorage = load.getMemref ();
218- // The result alloca may be behind a fir.declare, if any.
219- if (auto declare = mlir::dyn_cast_or_null<fir::DeclareOp>(
220- resultStorage.getDefiningOp ()))
221- resultStorage = declare.getMemref ();
222- // TODO: This should be generalized for derived types, and it is
223- // architecture and OS dependent.
224- if (fir::isa_builtin_cptr_type (returnedValue.getType ())) {
225- rewriter.eraseOp (load);
226- auto module = ret->getParentOfType <mlir::ModuleOp>();
227- FirOpBuilder builder (rewriter, module );
228- mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr (
229- builder, loc, resultStorage, returnedValue.getType ());
230- mlir::Value retValue = rewriter.create <fir::LoadOp>(
231- loc, fir::unwrapRefType (retAddr.getType ()), retAddr);
232- rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(
233- ret, mlir::ValueRange{retValue});
234- return mlir::success ();
235- }
236- resultStorage.replaceAllUsesWith (newArg);
237- replacedStorage = true ;
238- if (auto *alloc = resultStorage.getDefiningOp ())
239- if (alloc->use_empty ())
240- rewriter.eraseOp (alloc);
211+ mlir::Value resultValue = ret.getOperand (0 );
212+ fir::LoadOp resultLoad;
213+ mlir::Value resultStorage;
214+ // Identify result local storage.
215+ if (auto load = resultValue.getDefiningOp <fir::LoadOp>()) {
216+ resultLoad = load;
217+ resultStorage = load.getMemref ();
218+ // The result alloca may be behind a fir.declare, if any.
219+ if (auto declare = resultStorage.getDefiningOp <fir::DeclareOp>())
220+ resultStorage = declare.getMemref ();
221+ }
222+ // Replace old local storage with new storage argument, unless
223+ // the derived type is C_PTR/C_FUN_PTR, in which case the return
224+ // type is updated to return void* (no new argument is passed).
225+ if (fir::isa_builtin_cptr_type (resultValue.getType ())) {
226+ auto module = ret->getParentOfType <mlir::ModuleOp>();
227+ FirOpBuilder builder (rewriter, module );
228+ mlir::Value cptr = resultValue;
229+ if (resultLoad) {
230+ // Replace whole derived type load by component load.
231+ cptr = resultLoad.getMemref ();
232+ rewriter.setInsertionPoint (resultLoad);
241233 }
242- // The result storage may have been optimized out by a memory to
243- // register pass, this is possible for fir.box results, or fir.record
244- // with no length parameters. Simply store the result in the result storage.
245- // at the return point.
246- if (!replacedStorage)
247- rewriter.create <fir::StoreOp>(loc, returnedValue, newArg);
248- rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(ret);
234+ mlir::Value newResultValue =
235+ fir::factory::genCPtrOrCFunptrValue (builder, loc, cptr);
236+ newResultValue = builder.createConvert (
237+ loc, getVoidPtrType (ret.getContext ()), newResultValue);
238+ rewriter.setInsertionPoint (ret);
239+ rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(
240+ ret, mlir::ValueRange{newResultValue});
241+ } else if (resultStorage) {
242+ resultStorage.replaceAllUsesWith (newArg);
243+ rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(ret);
244+ } else {
245+ // The result storage may have been optimized out by a memory to
246+ // register pass, this is possible for fir.box results, or fir.record
247+ // with no length parameters. Simply store the result in the result
248+ // storage. at the return point.
249+ rewriter.create <fir::StoreOp>(loc, resultValue, newArg);
250+ rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(ret);
251+ }
252+ // Delete result old local storage if unused.
253+ if (resultStorage)
254+ if (auto alloc = resultStorage.getDefiningOp <fir::AllocaOp>())
255+ if (alloc->use_empty ())
256+ rewriter.eraseOp (alloc);
249257 return mlir::success ();
250258 }
251259
@@ -263,8 +271,6 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
263271 mlir::PatternRewriter &rewriter) const override {
264272 auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType ());
265273 mlir::FunctionType newFuncTy;
266- // TODO: This should be generalized for derived types, and it is
267- // architecture and OS dependent.
268274 if (oldFuncTy.getNumResults () != 0 &&
269275 fir::isa_builtin_cptr_type (oldFuncTy.getResult (0 )))
270276 newFuncTy = getCPtrFunctionType (oldFuncTy);
@@ -298,8 +304,6 @@ class AbstractResultOpt
298304 // Convert function type itself if it has an abstract result.
299305 auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType ());
300306 if (hasAbstractResult (funcTy)) {
301- // TODO: This should be generalized for derived types, and it is
302- // architecture and OS dependent.
303307 if (fir::isa_builtin_cptr_type (funcTy.getResult (0 ))) {
304308 func.setType (getCPtrFunctionType (funcTy));
305309 patterns.insert <ReturnOpConversion>(context, mlir::Value{});
0 commit comments