Skip to content

Commit 5f8189c

Browse files
jeanPeriertru
authored andcommitted
[flang] fix C_PTR function result lowering (#100082)
Functions returning C_PTR were lowered to function returning intptr (i64 on 64bit arch). This caused conflicts when these functions were defined as returning !fir.ref<none>/llvm.ptr in other compiler generated contexts (e.g., malloc). Lower them to return !fir.ref<none>. This should deal with #97325 and #98644. (cherry picked from commit 1ead51a)
1 parent fc0b1ce commit 5f8189c

File tree

3 files changed

+110
-88
lines changed

3 files changed

+110
-88
lines changed

flang/lib/Optimizer/Builder/FIRBuilder.cpp

+34-20
Original file line numberDiff line numberDiff line change
@@ -1541,21 +1541,44 @@ mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder,
15411541
zero);
15421542
}
15431543

1544+
static std::pair<mlir::Value, mlir::Type>
1545+
genCPtrOrCFunptrFieldIndex(fir::FirOpBuilder &builder, mlir::Location loc,
1546+
mlir::Type cptrTy) {
1547+
auto recTy = mlir::cast<fir::RecordType>(cptrTy);
1548+
assert(recTy.getTypeList().size() == 1);
1549+
auto addrFieldName = recTy.getTypeList()[0].first;
1550+
mlir::Type addrFieldTy = recTy.getTypeList()[0].second;
1551+
auto fieldIndexType = fir::FieldType::get(cptrTy.getContext());
1552+
mlir::Value addrFieldIndex = builder.create<fir::FieldIndexOp>(
1553+
loc, fieldIndexType, addrFieldName, recTy,
1554+
/*typeParams=*/mlir::ValueRange{});
1555+
return {addrFieldIndex, addrFieldTy};
1556+
}
1557+
15441558
mlir::Value fir::factory::genCPtrOrCFunptrAddr(fir::FirOpBuilder &builder,
15451559
mlir::Location loc,
15461560
mlir::Value cPtr,
15471561
mlir::Type ty) {
1548-
assert(mlir::isa<fir::RecordType>(ty));
1549-
auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
1550-
assert(recTy.getTypeList().size() == 1);
1551-
auto fieldName = recTy.getTypeList()[0].first;
1552-
mlir::Type fieldTy = recTy.getTypeList()[0].second;
1553-
auto fieldIndexType = fir::FieldType::get(ty.getContext());
1554-
mlir::Value field =
1555-
builder.create<fir::FieldIndexOp>(loc, fieldIndexType, fieldName, recTy,
1556-
/*typeParams=*/mlir::ValueRange{});
1557-
return builder.create<fir::CoordinateOp>(loc, builder.getRefType(fieldTy),
1558-
cPtr, field);
1562+
auto [addrFieldIndex, addrFieldTy] =
1563+
genCPtrOrCFunptrFieldIndex(builder, loc, ty);
1564+
return builder.create<fir::CoordinateOp>(loc, builder.getRefType(addrFieldTy),
1565+
cPtr, addrFieldIndex);
1566+
}
1567+
1568+
mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder,
1569+
mlir::Location loc,
1570+
mlir::Value cPtr) {
1571+
mlir::Type cPtrTy = fir::unwrapRefType(cPtr.getType());
1572+
if (fir::isa_ref_type(cPtr.getType())) {
1573+
mlir::Value cPtrAddr =
1574+
fir::factory::genCPtrOrCFunptrAddr(builder, loc, cPtr, cPtrTy);
1575+
return builder.create<fir::LoadOp>(loc, cPtrAddr);
1576+
}
1577+
auto [addrFieldIndex, addrFieldTy] =
1578+
genCPtrOrCFunptrFieldIndex(builder, loc, cPtrTy);
1579+
auto arrayAttr =
1580+
builder.getArrayAttr({builder.getIntegerAttr(builder.getIndexType(), 0)});
1581+
return builder.create<fir::ExtractValueOp>(loc, addrFieldTy, cPtr, arrayAttr);
15591582
}
15601583

15611584
fir::BoxValue fir::factory::createBoxValue(fir::FirOpBuilder &builder,
@@ -1596,15 +1619,6 @@ fir::BoxValue fir::factory::createBoxValue(fir::FirOpBuilder &builder,
15961619
return fir::BoxValue(box, lbounds, explicitTypeParams);
15971620
}
15981621

1599-
mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder,
1600-
mlir::Location loc,
1601-
mlir::Value cPtr) {
1602-
mlir::Type cPtrTy = fir::unwrapRefType(cPtr.getType());
1603-
mlir::Value cPtrAddr =
1604-
fir::factory::genCPtrOrCFunptrAddr(builder, loc, cPtr, cPtrTy);
1605-
return builder.create<fir::LoadOp>(loc, cPtrAddr);
1606-
}
1607-
16081622
mlir::Value fir::factory::createNullBoxProc(fir::FirOpBuilder &builder,
16091623
mlir::Location loc,
16101624
mlir::Type boxType) {

flang/lib/Optimizer/Transforms/AbstractResult.cpp

+56-52
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,16 @@ static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
5959
/*resultTypes=*/{});
6060
}
6161

62+
static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
63+
return fir::ReferenceType::get(mlir::NoneType::get(context));
64+
}
65+
6266
/// This is for function result types that are of type C_PTR from ISO_C_BINDING.
6367
/// Follow the ABI for interoperability with C.
6468
static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
65-
auto resultType = funcTy.getResult(0);
66-
assert(fir::isa_builtin_cptr_type(resultType));
67-
llvm::SmallVector<mlir::Type> outputTypes;
68-
auto recTy = mlir::dyn_cast<fir::RecordType>(resultType);
69-
outputTypes.emplace_back(recTy.getTypeList()[0].second);
69+
assert(fir::isa_builtin_cptr_type(funcTy.getResult(0)));
70+
llvm::SmallVector<mlir::Type> outputTypes{
71+
getVoidPtrType(funcTy.getContext())};
7072
return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
7173
outputTypes);
7274
}
@@ -109,15 +111,11 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
109111
saveResult.getTypeparams());
110112

111113
llvm::SmallVector<mlir::Type> newResultTypes;
112-
// TODO: This should be generalized for derived types, and it is
113-
// architecture and OS dependent.
114114
bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
115-
Op newOp;
116-
if (isResultBuiltinCPtr) {
117-
auto recTy = mlir::dyn_cast<fir::RecordType>(result.getType());
118-
newResultTypes.emplace_back(recTy.getTypeList()[0].second);
119-
}
115+
if (isResultBuiltinCPtr)
116+
newResultTypes.emplace_back(getVoidPtrType(result.getContext()));
120117

118+
Op newOp;
121119
// fir::CallOp specific handling.
122120
if constexpr (std::is_same_v<Op, fir::CallOp>) {
123121
if (op.getCallee()) {
@@ -175,7 +173,7 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
175173
FirOpBuilder builder(rewriter, module);
176174
mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
177175
builder, loc, save, result.getType());
178-
rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
176+
builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr);
179177
}
180178
op->dropAllReferences();
181179
rewriter.eraseOp(op);
@@ -210,42 +208,52 @@ class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
210208
mlir::PatternRewriter &rewriter) const override {
211209
auto loc = ret.getLoc();
212210
rewriter.setInsertionPoint(ret);
213-
auto returnedValue = ret.getOperand(0);
214-
bool replacedStorage = false;
215-
if (auto *op = returnedValue.getDefiningOp())
216-
if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
217-
auto resultStorage = load.getMemref();
218-
// The result alloca may be behind a fir.declare, if any.
219-
if (auto declare = mlir::dyn_cast_or_null<fir::DeclareOp>(
220-
resultStorage.getDefiningOp()))
221-
resultStorage = declare.getMemref();
222-
// TODO: This should be generalized for derived types, and it is
223-
// architecture and OS dependent.
224-
if (fir::isa_builtin_cptr_type(returnedValue.getType())) {
225-
rewriter.eraseOp(load);
226-
auto module = ret->getParentOfType<mlir::ModuleOp>();
227-
FirOpBuilder builder(rewriter, module);
228-
mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr(
229-
builder, loc, resultStorage, returnedValue.getType());
230-
mlir::Value retValue = rewriter.create<fir::LoadOp>(
231-
loc, fir::unwrapRefType(retAddr.getType()), retAddr);
232-
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
233-
ret, mlir::ValueRange{retValue});
234-
return mlir::success();
235-
}
236-
resultStorage.replaceAllUsesWith(newArg);
237-
replacedStorage = true;
238-
if (auto *alloc = resultStorage.getDefiningOp())
239-
if (alloc->use_empty())
240-
rewriter.eraseOp(alloc);
211+
mlir::Value resultValue = ret.getOperand(0);
212+
fir::LoadOp resultLoad;
213+
mlir::Value resultStorage;
214+
// Identify result local storage.
215+
if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
216+
resultLoad = load;
217+
resultStorage = load.getMemref();
218+
// The result alloca may be behind a fir.declare, if any.
219+
if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
220+
resultStorage = declare.getMemref();
221+
}
222+
// Replace old local storage with new storage argument, unless
223+
// the derived type is C_PTR/C_FUN_PTR, in which case the return
224+
// type is updated to return void* (no new argument is passed).
225+
if (fir::isa_builtin_cptr_type(resultValue.getType())) {
226+
auto module = ret->getParentOfType<mlir::ModuleOp>();
227+
FirOpBuilder builder(rewriter, module);
228+
mlir::Value cptr = resultValue;
229+
if (resultLoad) {
230+
// Replace whole derived type load by component load.
231+
cptr = resultLoad.getMemref();
232+
rewriter.setInsertionPoint(resultLoad);
241233
}
242-
// The result storage may have been optimized out by a memory to
243-
// register pass, this is possible for fir.box results, or fir.record
244-
// with no length parameters. Simply store the result in the result storage.
245-
// at the return point.
246-
if (!replacedStorage)
247-
rewriter.create<fir::StoreOp>(loc, returnedValue, newArg);
248-
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
234+
mlir::Value newResultValue =
235+
fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
236+
newResultValue = builder.createConvert(
237+
loc, getVoidPtrType(ret.getContext()), newResultValue);
238+
rewriter.setInsertionPoint(ret);
239+
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
240+
ret, mlir::ValueRange{newResultValue});
241+
} else if (resultStorage) {
242+
resultStorage.replaceAllUsesWith(newArg);
243+
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
244+
} else {
245+
// The result storage may have been optimized out by a memory to
246+
// register pass, this is possible for fir.box results, or fir.record
247+
// with no length parameters. Simply store the result in the result
248+
// storage. at the return point.
249+
rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
250+
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
251+
}
252+
// Delete result old local storage if unused.
253+
if (resultStorage)
254+
if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
255+
if (alloc->use_empty())
256+
rewriter.eraseOp(alloc);
249257
return mlir::success();
250258
}
251259

@@ -263,8 +271,6 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
263271
mlir::PatternRewriter &rewriter) const override {
264272
auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
265273
mlir::FunctionType newFuncTy;
266-
// TODO: This should be generalized for derived types, and it is
267-
// architecture and OS dependent.
268274
if (oldFuncTy.getNumResults() != 0 &&
269275
fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
270276
newFuncTy = getCPtrFunctionType(oldFuncTy);
@@ -298,8 +304,6 @@ class AbstractResultOpt
298304
// Convert function type itself if it has an abstract result.
299305
auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
300306
if (hasAbstractResult(funcTy)) {
301-
// TODO: This should be generalized for derived types, and it is
302-
// architecture and OS dependent.
303307
if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
304308
func.setType(getCPtrFunctionType(funcTy));
305309
patterns.insert<ReturnOpConversion>(context, mlir::Value{});

flang/test/Fir/abstract-results.fir

+20-16
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ func.func @boxfunc_callee() -> !fir.box<!fir.heap<f64>> {
8787
// FUNC-BOX: return
8888
}
8989

90-
// FUNC-REF-LABEL: func @retcptr() -> i64
91-
// FUNC-BOX-LABEL: func @retcptr() -> i64
90+
// FUNC-REF-LABEL: func @retcptr() -> !fir.ref<none>
91+
// FUNC-BOX-LABEL: func @retcptr() -> !fir.ref<none>
9292
func.func @retcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {
9393
%0 = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"}
9494
%1 = fir.load %0 : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
@@ -98,12 +98,14 @@ func.func @retcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__addres
9898
// FUNC-REF: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
9999
// FUNC-REF: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
100100
// FUNC-REF: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64>
101-
// FUNC-REF: return %[[VAL]] : i64
101+
// FUNC-REF: %[[CAST:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref<none>
102+
// FUNC-REF: return %[[CAST]] : !fir.ref<none>
102103
// FUNC-BOX: %[[ALLOC:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"}
103104
// FUNC-BOX: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
104105
// FUNC-BOX: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
105106
// FUNC-BOX: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64>
106-
// FUNC-BOX: return %[[VAL]] : i64
107+
// FUNC-BOX: %[[CAST:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref<none>
108+
// FUNC-BOX: return %[[CAST]] : !fir.ref<none>
107109
}
108110

109111
// FUNC-REF-LABEL: func private @arrayfunc_callee_declare(
@@ -311,8 +313,8 @@ func.func @test_address_of() {
311313

312314
}
313315

314-
// FUNC-REF-LABEL: func.func private @returns_null() -> i64
315-
// FUNC-BOX-LABEL: func.func private @returns_null() -> i64
316+
// FUNC-REF-LABEL: func.func private @returns_null() -> !fir.ref<none>
317+
// FUNC-BOX-LABEL: func.func private @returns_null() -> !fir.ref<none>
316318
func.func private @returns_null() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
317319

318320
// FUNC-REF-LABEL: func @test_address_of_cptr
@@ -323,12 +325,12 @@ func.func @test_address_of_cptr() {
323325
fir.call @_QMtest_c_func_modPsubr(%1) : (() -> ()) -> ()
324326
return
325327

326-
// FUNC-REF: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64
327-
// FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
328+
// FUNC-REF: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> !fir.ref<none>
329+
// FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> !fir.ref<none>) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
328330
// FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ())
329331
// FUNC-REF: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> ()
330-
// FUNC-BOX: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64
331-
// FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
332+
// FUNC-BOX: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> !fir.ref<none>
333+
// FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> !fir.ref<none>) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
332334
// FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ())
333335
// FUNC-BOX: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> ()
334336
}
@@ -380,18 +382,20 @@ func.func @test_indirect_calls_return_cptr(%arg0: () -> ()) {
380382

381383
// FUNC-REF: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
382384
// FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
383-
// FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64)
384-
// FUNC-REF: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64
385+
// FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> !fir.ref<none>)
386+
// FUNC-REF: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> !fir.ref<none>
385387
// FUNC-REF: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
386388
// FUNC-REF: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
387-
// FUNC-REF: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i64>
389+
// FUNC-REF: %[[CAST:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<none>) -> i64
390+
// FUNC-REF: fir.store %[[CAST]] to %[[VAL_5]] : !fir.ref<i64>
388391
// FUNC-BOX: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
389392
// FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
390-
// FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64)
391-
// FUNC-BOX: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64
393+
// FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> !fir.ref<none>)
394+
// FUNC-BOX: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> !fir.ref<none>
392395
// FUNC-BOX: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
393396
// FUNC-BOX: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
394-
// FUNC-BOX: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i64>
397+
// FUNC-BOX: %[[CAST:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<none>) -> i64
398+
// FUNC-BOX: fir.store %[[CAST]] to %[[VAL_5]] : !fir.ref<i64>
395399
}
396400

397401
// ----------------------- Test GlobalOp rewrite ------------------------

0 commit comments

Comments
 (0)