From 83ed47c5e86d3c478f18a836ca54a405336c912c Mon Sep 17 00:00:00 2001 From: gitoleg Date: Mon, 20 Nov 2023 22:24:43 +0300 Subject: [PATCH] [CIR][CIRGen][Lowering] supports functions pointers (#316) This PR adds a support of the function pointers in CIR. From the implementation point of view, we emit an address of a function as a `GlobalViewAttr`. --- clang/lib/CIR/CodeGen/CIRGenExprConst.cpp | 10 +++- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 3 ++ clang/test/CIR/CodeGen/fun-ptr.c | 47 +++++++++++++++++++ 3 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 clang/test/CIR/CodeGen/fun-ptr.c diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConst.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConst.cpp index f2db6408eef542..79621c31e3dd67 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprConst.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprConst.cpp @@ -1155,8 +1155,14 @@ ConstantLValueEmitter::tryEmitBase(const APValue::LValueBase &base) { if (D->hasAttr()) llvm_unreachable("emit pointer base for weakref is NYI"); - if (auto *FD = dyn_cast(D)) - llvm_unreachable("emit pointer base for fun decl is NYI"); + if (auto *FD = dyn_cast(D)) { + auto fop = CGM.GetAddrOfFunction(FD); + auto builder = CGM.getBuilder(); + auto ctxt = builder.getContext(); + return mlir::cir::GlobalViewAttr::get( + builder.getPointerTo(fop.getFunctionType()), + mlir::FlatSymbolRefAttr::get(ctxt, fop.getSymNameAttr())); + } if (auto *VD = dyn_cast(D)) { // We can never refer to a variable with local storage. diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index d9f9dbed8d5e63..5e4a3d6eb26063 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -236,6 +236,9 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, } else if (auto cirSymbol = dyn_cast(sourceSymbol)) { sourceType = converter->convertType(cirSymbol.getSymType()); symName = cirSymbol.getSymName(); + } else if (auto llvmFun = dyn_cast(sourceSymbol)) { + sourceType = llvmFun.getFunctionType(); + symName = llvmFun.getSymName(); } else { llvm_unreachable("Unexpected GlobalOp type"); } diff --git a/clang/test/CIR/CodeGen/fun-ptr.c b/clang/test/CIR/CodeGen/fun-ptr.c new file mode 100644 index 00000000000000..d9d4a7809bc2dd --- /dev/null +++ b/clang/test/CIR/CodeGen/fun-ptr.c @@ -0,0 +1,47 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s -check-prefix=CIR +// RUN: %clang_cc1 -x c++ -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s -check-prefix=CIR +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o - | FileCheck %s -check-prefix=LLVM +// RUN: %clang_cc1 -x c++ -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o - | FileCheck %s -check-prefix=LLVM +// XFAIL: * + +typedef struct { + int a; + int b; +} Data; + +typedef int (*fun_t)(Data* d); + +int extract_a(Data* d) { + return d->a; +} + +// CIR: cir.func {{@.*foo.*}}(%arg0: !cir.ptr +// CIR: [[TMP0:%.*]] = cir.alloca !cir.ptr, cir.ptr >, ["d", init] +// CIR: [[TMP1:%.*]] = cir.alloca !s32i, cir.ptr , ["__retval"] +// CIR: [[TMP2:%.*]] = cir.alloca !cir.ptr)>>, cir.ptr )>>>, ["f", init] +// CIR: cir.store %arg0, [[TMP0]] : !cir.ptr, cir.ptr > +// CIR: [[TMP3:%.*]] = cir.const(#cir.ptr : !cir.ptr)>>) : !cir.ptr)>> +// CIR: cir.store [[TMP3]], [[TMP2]] : !cir.ptr)>>, cir.ptr )>>> +// CIR: [[TMP4:%.*]] = cir.get_global {{@.*extract_a.*}} : cir.ptr )>> +// CIR: cir.store [[TMP4]], [[TMP2]] : !cir.ptr)>>, cir.ptr )>>> +// CIR: [[TMP5:%.*]] = cir.load [[TMP2]] : cir.ptr )>>>, !cir.ptr)>> +// CIR: [[TMP6:%.*]] = cir.load [[TMP0]] : cir.ptr >, !cir.ptr +// CIR: [[TMP7:%.*]] = cir.call [[TMP5]]([[TMP6]]) : (!cir.ptr)>>, !cir.ptr) -> !s32i +// CIR: cir.store [[TMP7]], [[TMP1]] : !s32i, cir.ptr + +// LLVM: define i32 {{@.*foo.*}}(ptr %0) +// LLVM: [[TMP1:%.*]] = alloca ptr, i64 1 +// LLVM: [[TMP2:%.*]] = alloca i32, i64 1 +// LLVM: [[TMP3:%.*]] = alloca ptr, i64 1 +// LLVM: store ptr %0, ptr [[TMP1]] +// LLVM: store ptr null, ptr [[TMP3]] +// LLVM: store ptr {{@.*extract_a.*}}, ptr [[TMP3]] +// LLVM: [[TMP4:%.*]] = load ptr, ptr [[TMP3]] +// LLVM: [[TMP5:%.*]] = load ptr, ptr [[TMP1]] +// LLVM: [[TMP6:%.*]] = call i32 [[TMP4]](ptr [[TMP5]]) +// LLVM: store i32 [[TMP6]], ptr [[TMP2]] +int foo(Data* d) { + fun_t f = 0; + f = extract_a; + return f(d); +}