Skip to content

Commit

Permalink
[CIR][CIRGen][Lowering] supports functions pointers (llvm#316)
Browse files Browse the repository at this point in the history
This PR adds a support of the function pointers in CIR.

From the implementation point of view, we emit an address of a function
as a `GlobalViewAttr`.
  • Loading branch information
gitoleg authored and lanza committed Jun 20, 2024
1 parent 1175fc0 commit 83ed47c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 2 deletions.
10 changes: 8 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExprConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1155,8 +1155,14 @@ ConstantLValueEmitter::tryEmitBase(const APValue::LValueBase &base) {
if (D->hasAttr<WeakRefAttr>())
llvm_unreachable("emit pointer base for weakref is NYI");

if (auto *FD = dyn_cast<FunctionDecl>(D))
llvm_unreachable("emit pointer base for fun decl is NYI");
if (auto *FD = dyn_cast<FunctionDecl>(D)) {
auto fop = CGM.GetAddrOfFunction(FD);
auto builder = CGM.getBuilder();
auto ctxt = builder.getContext();
return mlir::cir::GlobalViewAttr::get(
builder.getPointerTo(fop.getFunctionType()),
mlir::FlatSymbolRefAttr::get(ctxt, fop.getSymNameAttr()));
}

if (auto *VD = dyn_cast<VarDecl>(D)) {
// We can never refer to a variable with local storage.
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
} else if (auto cirSymbol = dyn_cast<mlir::cir::GlobalOp>(sourceSymbol)) {
sourceType = converter->convertType(cirSymbol.getSymType());
symName = cirSymbol.getSymName();
} else if (auto llvmFun = dyn_cast<mlir::LLVM::LLVMFuncOp>(sourceSymbol)) {
sourceType = llvmFun.getFunctionType();
symName = llvmFun.getSymName();
} else {
llvm_unreachable("Unexpected GlobalOp type");
}
Expand Down
47 changes: 47 additions & 0 deletions clang/test/CIR/CodeGen/fun-ptr.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s -check-prefix=CIR
// RUN: %clang_cc1 -x c++ -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s -check-prefix=CIR
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o - | FileCheck %s -check-prefix=LLVM
// RUN: %clang_cc1 -x c++ -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o - | FileCheck %s -check-prefix=LLVM
// XFAIL: *

typedef struct {
int a;
int b;
} Data;

typedef int (*fun_t)(Data* d);

int extract_a(Data* d) {
return d->a;
}

// CIR: cir.func {{@.*foo.*}}(%arg0: !cir.ptr<!ty_22Data22>
// CIR: [[TMP0:%.*]] = cir.alloca !cir.ptr<!ty_22Data22>, cir.ptr <!cir.ptr<!ty_22Data22>>, ["d", init]
// CIR: [[TMP1:%.*]] = cir.alloca !s32i, cir.ptr <!s32i>, ["__retval"]
// CIR: [[TMP2:%.*]] = cir.alloca !cir.ptr<!cir.func<!s32i (!cir.ptr<!ty_22Data22>)>>, cir.ptr <!cir.ptr<!cir.func<!s32i (!cir.ptr<!ty_22Data22>)>>>, ["f", init]
// CIR: cir.store %arg0, [[TMP0]] : !cir.ptr<!ty_22Data22>, cir.ptr <!cir.ptr<!ty_22Data22>>
// CIR: [[TMP3:%.*]] = cir.const(#cir.ptr<null> : !cir.ptr<!cir.func<!s32i (!cir.ptr<!ty_22Data22>)>>) : !cir.ptr<!cir.func<!s32i (!cir.ptr<!ty_22Data22>)>>
// CIR: cir.store [[TMP3]], [[TMP2]] : !cir.ptr<!cir.func<!s32i (!cir.ptr<!ty_22Data22>)>>, cir.ptr <!cir.ptr<!cir.func<!s32i (!cir.ptr<!ty_22Data22>)>>>
// CIR: [[TMP4:%.*]] = cir.get_global {{@.*extract_a.*}} : cir.ptr <!cir.func<!s32i (!cir.ptr<!ty_22Data22>)>>
// CIR: cir.store [[TMP4]], [[TMP2]] : !cir.ptr<!cir.func<!s32i (!cir.ptr<!ty_22Data22>)>>, cir.ptr <!cir.ptr<!cir.func<!s32i (!cir.ptr<!ty_22Data22>)>>>
// CIR: [[TMP5:%.*]] = cir.load [[TMP2]] : cir.ptr <!cir.ptr<!cir.func<!s32i (!cir.ptr<!ty_22Data22>)>>>, !cir.ptr<!cir.func<!s32i (!cir.ptr<!ty_22Data22>)>>
// CIR: [[TMP6:%.*]] = cir.load [[TMP0]] : cir.ptr <!cir.ptr<!ty_22Data22>>, !cir.ptr<!ty_22Data22>
// CIR: [[TMP7:%.*]] = cir.call [[TMP5]]([[TMP6]]) : (!cir.ptr<!cir.func<!s32i (!cir.ptr<!ty_22Data22>)>>, !cir.ptr<!ty_22Data22>) -> !s32i
// CIR: cir.store [[TMP7]], [[TMP1]] : !s32i, cir.ptr <!s32i>

// LLVM: define i32 {{@.*foo.*}}(ptr %0)
// LLVM: [[TMP1:%.*]] = alloca ptr, i64 1
// LLVM: [[TMP2:%.*]] = alloca i32, i64 1
// LLVM: [[TMP3:%.*]] = alloca ptr, i64 1
// LLVM: store ptr %0, ptr [[TMP1]]
// LLVM: store ptr null, ptr [[TMP3]]
// LLVM: store ptr {{@.*extract_a.*}}, ptr [[TMP3]]
// LLVM: [[TMP4:%.*]] = load ptr, ptr [[TMP3]]
// LLVM: [[TMP5:%.*]] = load ptr, ptr [[TMP1]]
// LLVM: [[TMP6:%.*]] = call i32 [[TMP4]](ptr [[TMP5]])
// LLVM: store i32 [[TMP6]], ptr [[TMP2]]
int foo(Data* d) {
fun_t f = 0;
f = extract_a;
return f(d);
}

0 comments on commit 83ed47c

Please sign in to comment.