Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][cuda] Translate cuf.register_kernel and cuf.register_module #112972

Merged
merged 3 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/OpDefinition.h"

#define GET_OP_CLASSES
Expand Down
20 changes: 18 additions & 2 deletions flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include "flang/Optimizer/Dialect/CUF/CUFDialect.td"
include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
include "flang/Optimizer/Dialect/FIRTypes.td"
include "flang/Optimizer/Dialect/FIRAttr.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/BuiltinAttributes.td"

Expand Down Expand Up @@ -288,15 +289,30 @@ def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,
let hasVerifier = 1;
}

def cuf_RegisterModuleOp : cuf_Op<"register_module", []> {
let summary = "Register a CUDA module";

let arguments = (ins
SymbolRefAttr:$name
);

let assemblyFormat = [{
$name attr-dict `->` type($modulePtr)
}];

let results = (outs LLVM_AnyPointer:$modulePtr);
}

def cuf_RegisterKernelOp : cuf_Op<"register_kernel", []> {
let summary = "Register a CUDA kernel";

let arguments = (ins
SymbolRefAttr:$name
SymbolRefAttr:$name,
LLVM_AnyPointer:$modulePtr
);

let assemblyFormat = [{
$name attr-dict
$name `(` $modulePtr `:` type($modulePtr) `)`attr-dict
}];

let hasVerifier = 1;
Expand Down
29 changes: 29 additions & 0 deletions flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//===- CUFToLLVMIRTranslation.h - CUF Dialect to LLVM IR --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This provides registration calls for GPU dialect to LLVM IR translation.
//
//===----------------------------------------------------------------------===//

#ifndef FLANG_OPTIMIZER_DIALECT_CUF_GPUTOLLVMIRTRANSLATION_H_
#define FLANG_OPTIMIZER_DIALECT_CUF_GPUTOLLVMIRTRANSLATION_H_

namespace mlir {
class DialectRegistry;
class MLIRContext;
} // namespace mlir

namespace cuf {

/// Register the CUF dialect and the translation from it to the LLVM IR in
/// the given registry.
void registerCUFDialectTranslation(mlir::DialectRegistry &registry);

} // namespace cuf

#endif // FLANG_OPTIMIZER_DIALECT_CUF_GPUTOLLVMIRTRANSLATION_H_
2 changes: 2 additions & 0 deletions flang/include/flang/Optimizer/Support/InitFIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define FORTRAN_OPTIMIZER_SUPPORT_INITFIR_H

#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "mlir/Conversion/Passes.h"
Expand Down Expand Up @@ -61,6 +62,7 @@ inline void addFIRExtensions(mlir::DialectRegistry &registry,
if (addFIRInlinerInterface)
addFIRInlinerExtension(registry);
addFIRToLLVMIRExtension(registry);
cuf::registerCUFDialectTranslation(registry);
}

inline void loadNonCodegenDialects(mlir::MLIRContext &context) {
Expand Down
28 changes: 28 additions & 0 deletions flang/include/flang/Runtime/CUDA/registration.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===-- include/flang/Runtime/CUDA/registration.h ---------------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
#define FORTRAN_RUNTIME_CUDA_REGISTRATION_H_

#include "flang/Runtime/entry-names.h"
#include <cstddef>

namespace Fortran::runtime::cuda {

extern "C" {

/// Register a CUDA module.
void *RTDECL(CUFRegisterModule)(void *data);

/// Register a device function.
void RTDECL(CUFRegisterFunction)(void **module, const char *fct);

} // extern "C"

} // namespace Fortran::runtime::cuda
#endif // FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_subdirectory(Attributes)
add_flang_library(CUFDialect
CUFDialect.cpp
CUFOps.cpp
CUFToLLVMIRTranslation.cpp

DEPENDS
MLIRIR
Expand Down
104 changes: 104 additions & 0 deletions flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//===- CUFToLLVMIRTranslation.cpp - Translate CUF dialect to LLVM IR ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between the MLIR CUF dialect and LLVM IR.
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Runtime/entry-names.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

namespace {

LogicalResult registerModule(cuf::RegisterModuleOp op,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
std::string binaryIdentifier =
op.getName().getLeafReference().str() + "_bin_cst";
llvm::Module *module = moduleTranslation.getLLVMModule();
llvm::Value *binary = module->getGlobalVariable(binaryIdentifier, true);
if (!binary)
return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;

llvm::Type *ptrTy = builder.getPtrTy(0);
llvm::FunctionCallee fct = module->getOrInsertFunction(
RTNAME_STRING(CUFRegisterModule),
llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy}), false));
auto *handle = builder.CreateCall(fct, {binary});
moduleTranslation.mapValue(op->getResults().front()) = handle;
return mlir::success();
}

llvm::Value *getOrCreateFunctionName(llvm::Module *module,
llvm::IRBuilderBase &builder,
llvm::StringRef moduleName,
llvm::StringRef kernelName) {
std::string globalName =
std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, kernelName));

if (llvm::GlobalVariable *gv = module->getGlobalVariable(globalName))
return gv;

return builder.CreateGlobalString(kernelName, globalName);
}

LogicalResult registerKernel(cuf::RegisterKernelOp op,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Module *module = moduleTranslation.getLLVMModule();
llvm::Type *ptrTy = builder.getPtrTy(0);
llvm::FunctionCallee fct = module->getOrInsertFunction(
RTNAME_STRING(CUFRegisterFunction),
llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy}),
false));
llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
builder.CreateCall(
fct, {modulePtr, getOrCreateFunctionName(module, builder,
op.getKernelModuleName().str(),
op.getKernelName().str())});
return mlir::success();
}

class CUFDialectLLVMIRTranslationInterface
: public LLVMTranslationDialectInterface {
public:
using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

LogicalResult
convertOperation(Operation *operation, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) const override {
return llvm::TypeSwitch<Operation *, LogicalResult>(operation)
.Case([&](cuf::RegisterModuleOp op) {
return registerModule(op, builder, moduleTranslation);
})
.Case([&](cuf::RegisterKernelOp op) {
return registerKernel(op, builder, moduleTranslation);
})
.Default([&](Operation *op) {
return op->emitError("unsupported GPU operation: ") << op->getName();
});
}
};

} // namespace

void cuf::registerCUFDialectTranslation(DialectRegistry &registry) {
registry.insert<cuf::CUFDialect>();
registry.addExtension(+[](MLIRContext *ctx, cuf::CUFDialect *dialect) {
dialect->addInterfaces<CUFDialectLLVMIRTranslationInterface>();
});
}
5 changes: 4 additions & 1 deletion flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,15 @@ struct CUFAddConstructor
// Register kernels
auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaModName);
if (gpuMod) {
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
auto registeredMod = builder.create<cuf::RegisterModuleOp>(
loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));
for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
if (func.isKernel()) {
auto kernelName = mlir::SymbolRefAttr::get(
builder.getStringAttr(cudaModName),
{mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
builder.create<cuf::RegisterKernelOp>(loc, kernelName);
builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Transforms/CufOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "flang/Runtime/CUDA/descriptor.h"
#include "flang/Runtime/CUDA/memory.h"
#include "flang/Runtime/allocatable.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down
1 change: 1 addition & 0 deletions flang/runtime/CUDA/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_flang_library(${CUFRT_LIBNAME}
allocatable.cpp
descriptor.cpp
memory.cpp
registration.cpp
)

if (BUILD_SHARED_LIBS)
Expand Down
31 changes: 31 additions & 0 deletions flang/runtime/CUDA/registration.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===-- runtime/CUDA/registration.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Runtime/CUDA/registration.h"

#include "cuda_runtime.h"

namespace Fortran::runtime::cuda {

extern "C" {

extern void **__cudaRegisterFatBinary(void *data);
extern void __cudaRegisterFunction(void **fatCubinHandle, const char *hostFun,
char *deviceFun, const char *deviceName, int thread_limit, uint3 *tid,
uint3 *bid, dim3 *bDim, dim3 *gDim, int *wSize);

void *RTDECL(CUFRegisterModule)(void *data) {
return __cudaRegisterFatBinary(data);
}

void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
__cudaRegisterFunction(module, fct, (char *)fct, fct, -1, (uint3 *)0,
(uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
}
}
} // namespace Fortran::runtime::cuda
5 changes: 3 additions & 2 deletions flang/test/Fir/CUDA/cuda-register-func.fir
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ module attributes {gpu.container_module} {
}

// CHECK-LABEL: llvm.func internal @__cudaFortranConstructor()
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2
// CHECK: %[[MOD_HANDLE:.*]] = cuf.register_module @cuda_device_mod -> !llvm.ptr
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%[[MOD_HANDLE]] : !llvm.ptr)
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%[[MOD_HANDLE]] : !llvm.ptr)
15 changes: 10 additions & 5 deletions flang/test/Fir/cuf-invalid.fir
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ module attributes {gpu.container_module} {
}
}
llvm.func internal @__cudaFortranConstructor() {
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op only kernel gpu.func can be registered}}
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
llvm.return
}
}
Expand All @@ -150,8 +151,9 @@ module attributes {gpu.container_module} {
}
}
llvm.func internal @__cudaFortranConstructor() {
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op device function not found}}
cuf.register_kernel @cuda_device_mod::@_QPsub_device2
cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%0 : !llvm.ptr)
llvm.return
}
}
Expand All @@ -160,8 +162,9 @@ module attributes {gpu.container_module} {

module attributes {gpu.container_module} {
llvm.func internal @__cudaFortranConstructor() {
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op gpu module not found}}
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
llvm.return
}
}
Expand All @@ -170,8 +173,9 @@ module attributes {gpu.container_module} {

module attributes {gpu.container_module} {
llvm.func internal @__cudaFortranConstructor() {
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op expect a module and a kernel name}}
cuf.register_kernel @_QPsub_device1
cuf.register_kernel @_QPsub_device1(%0 : !llvm.ptr)
llvm.return
}
}
Expand All @@ -185,8 +189,9 @@ module attributes {gpu.container_module} {
}
}
llvm.func internal @__cudaFortranConstructor() {
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op only gpu.kernel llvm.func can be registered}}
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
llvm.return
}
}
Loading