-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][cuda] Translate cuf.register_kernel and cuf.register_module #112972
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-runtime Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesAdd LLVM IR Translation for Full diff: https://github.com/llvm/llvm-project/pull/112972.diff 8 Files Affected:
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h b/flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h
new file mode 100644
index 00000000000000..f3edb7fca649d0
--- /dev/null
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h
@@ -0,0 +1,29 @@
+//===- CUFToLLVMIRTranslation.h - CUF Dialect to LLVM IR --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for GPU dialect to LLVM IR translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FLANG_OPTIMIZER_DIALECT_CUF_GPUTOLLVMIRTRANSLATION_H_
+#define FLANG_OPTIMIZER_DIALECT_CUF_GPUTOLLVMIRTRANSLATION_H_
+
+namespace mlir {
+class DialectRegistry;
+class MLIRContext;
+} // namespace mlir
+
+namespace cuf {
+
+/// Register the CUF dialect and the translation from it to the LLVM IR in
+/// the given registry.
+void registerCUFDialectTranslation(mlir::DialectRegistry ®istry);
+
+} // namespace cuf
+
+#endif // FLANG_OPTIMIZER_DIALECT_CUF_GPUTOLLVMIRTRANSLATION_H_
diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index 04a5dd323e5508..1c61c367199923 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -14,6 +14,7 @@
#define FORTRAN_OPTIMIZER_SUPPORT_INITFIR_H
#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
+#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "mlir/Conversion/Passes.h"
@@ -61,6 +62,7 @@ inline void addFIRExtensions(mlir::DialectRegistry ®istry,
if (addFIRInlinerInterface)
addFIRInlinerExtension(registry);
addFIRToLLVMIRExtension(registry);
+ cuf::registerCUFDialectTranslation(registry);
}
inline void loadNonCodegenDialects(mlir::MLIRContext &context) {
diff --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h
new file mode 100644
index 00000000000000..cbe202c4d23e0d
--- /dev/null
+++ b/flang/include/flang/Runtime/CUDA/registration.h
@@ -0,0 +1,28 @@
+//===-- include/flang/Runtime/CUDA/registration.h ---------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
+#define FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
+
+#include "flang/Runtime/entry-names.h"
+#include <cstddef>
+
+namespace Fortran::runtime::cuda {
+
+extern "C" {
+
+/// Register a CUDA module.
+void *RTDECL(CUFRegisterModule)(void *data);
+
+/// Register a device function.
+void RTDECL(CUFRegisterFunction)(void **module, const char *fct);
+
+} // extern "C"
+
+} // namespace Fortran::runtime::cuda
+#endif // FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
diff --git a/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt b/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
index b2221199995d58..5d4bd0785971f7 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
+++ b/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(Attributes)
add_flang_library(CUFDialect
CUFDialect.cpp
CUFOps.cpp
+ CUFToLLVMIRTranslation.cpp
DEPENDS
MLIRIR
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
new file mode 100644
index 00000000000000..c6c9f96b811352
--- /dev/null
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
@@ -0,0 +1,104 @@
+//===- CUFToLLVMIRTranslation.cpp - Translate CUF dialect to LLVM IR ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR CUF dialect and LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Runtime/entry-names.h"
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+namespace {
+
+LogicalResult registerModule(cuf::RegisterModuleOp op,
+ llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ std::string binaryIdentifier =
+ op.getName().getLeafReference().str() + "_bin_cst";
+ llvm::Module *module = moduleTranslation.getLLVMModule();
+ llvm::Value *binary = module->getGlobalVariable(binaryIdentifier, true);
+ if (!binary)
+ return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;
+
+ llvm::Type *ptrTy = builder.getPtrTy(0);
+ llvm::FunctionCallee fct = module->getOrInsertFunction(
+ RTNAME_STRING(CUFRegisterModule),
+ llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy}), false));
+ auto *handle = builder.CreateCall(fct, {binary});
+ moduleTranslation.mapValue(op->getResults().front()) = handle;
+ return mlir::success();
+}
+
+llvm::Value *getOrCreateFunctionName(llvm::Module *module,
+ llvm::IRBuilderBase &builder,
+ llvm::StringRef moduleName,
+ llvm::StringRef kernelName) {
+ std::string globalName =
+ std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, kernelName));
+
+ if (llvm::GlobalVariable *gv = module->getGlobalVariable(globalName))
+ return gv;
+
+ return builder.CreateGlobalString(kernelName, globalName);
+}
+
+LogicalResult registerKernel(cuf::RegisterKernelOp op,
+ llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::Module *module = moduleTranslation.getLLVMModule();
+ llvm::Type *ptrTy = builder.getPtrTy(0);
+ llvm::FunctionCallee fct = module->getOrInsertFunction(
+ RTNAME_STRING(CUFRegisterFunction),
+ llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy}),
+ false));
+ llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
+ builder.CreateCall(
+ fct, {modulePtr, getOrCreateFunctionName(module, builder,
+ op.getKernelModuleName().str(),
+ op.getKernelName().str())});
+ return mlir::success();
+}
+
+class CUFDialectLLVMIRTranslationInterface
+ : public LLVMTranslationDialectInterface {
+public:
+ using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+ LogicalResult
+ convertOperation(Operation *operation, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) const override {
+ return llvm::TypeSwitch<Operation *, LogicalResult>(operation)
+ .Case([&](cuf::RegisterModuleOp op) {
+ return registerModule(op, builder, moduleTranslation);
+ })
+ .Case([&](cuf::RegisterKernelOp op) {
+ return registerKernel(op, builder, moduleTranslation);
+ })
+ .Default([&](Operation *op) {
+ return op->emitError("unsupported GPU operation: ") << op->getName();
+ });
+ }
+};
+
+} // namespace
+
+void cuf::registerCUFDialectTranslation(DialectRegistry ®istry) {
+ registry.insert<cuf::CUFDialect>();
+ registry.addExtension(+[](MLIRContext *ctx, cuf::CUFDialect *dialect) {
+ dialect->addInterfaces<CUFDialectLLVMIRTranslationInterface>();
+ });
+}
diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index 91ef1259332de9..e81fafb529a27d 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -20,6 +20,7 @@
#include "flang/Runtime/CUDA/descriptor.h"
#include "flang/Runtime/CUDA/memory.h"
#include "flang/Runtime/allocatable.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/flang/runtime/CUDA/CMakeLists.txt b/flang/runtime/CUDA/CMakeLists.txt
index 193dd77e934558..86523b419f8711 100644
--- a/flang/runtime/CUDA/CMakeLists.txt
+++ b/flang/runtime/CUDA/CMakeLists.txt
@@ -18,6 +18,7 @@ add_flang_library(${CUFRT_LIBNAME}
allocatable.cpp
descriptor.cpp
memory.cpp
+ registration.cpp
)
if (BUILD_SHARED_LIBS)
diff --git a/flang/runtime/CUDA/registration.cpp b/flang/runtime/CUDA/registration.cpp
new file mode 100644
index 00000000000000..971192b16156be
--- /dev/null
+++ b/flang/runtime/CUDA/registration.cpp
@@ -0,0 +1,31 @@
+//===-- runtime/CUDA/registration.cpp -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Runtime/CUDA/registration.h"
+
+#include "cuda_runtime.h"
+
+namespace Fortran::runtime::cuda {
+
+extern "C" {
+
+extern void **__cudaRegisterFatBinary(void *data);
+extern void __cudaRegisterFunction(void **fatCubinHandle, const char *hostFun,
+ char *deviceFun, const char *deviceName, int thread_limit, uint3 *tid,
+ uint3 *bid, dim3 *bDim, dim3 *gDim, int *wSize);
+
+void *RTDECL(CUFRegisterModule)(void *data) {
+ return __cudaRegisterFatBinary(data);
+}
+
+void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
+ __cudaRegisterFunction(module, fct, (char *)fct, fct, -1, (uint3 *)0,
+ (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
+}
+}
+} // namespace Fortran::runtime::cuda
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks really good. Thank you!
Add LLVM IR Translation for
cuf.register_module
andcuf.register_kernel
. These are lowered to function call to the CUF runtime entries.