Skip to content

[CINN]support backend compiler to link multiple modules #65916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions paddle/cinn/backends/codegen_cuda_dev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,6 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module,
if (output_kind == OutputKind::CHeader) {
GenerateHeaderFile(module);
} else if (output_kind == OutputKind::CImpl) {
PrintIncludes();

if (for_nvrtc_) {
str_ += "\nextern \"C\" {\n\n";
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/backends/codegen_invoke_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace backends {
/**
* CINN jit instructions support two kinds of invoke function, which can be
* represented like this:
* InvokeFunc = HostFunc | Switch<HostFunc>
* InvokeFunc = HostFunc | SwitchHostFunc
* HostFunc = X86Func | CudaHostFunc | HipHostFunc | ......
* CodeGenInvokeModule takes a CINN invoke Module(a module that only contains
* functions that jit instructions actually invoke) and output a LLVM module.
Expand Down Expand Up @@ -67,7 +67,7 @@ class CodeGenHost : public CodeGenInvokeModule {
};

/**
* In the Switch<HostFunc> pattern, InvokeFunc is a switch statement where
* In the SwitchHostFunc pattern, InvokeFunc is a switch statement where
* every case is a call of HostFunc. All the callee functions have the same
* parameters with the caller function.
*/
Expand Down
93 changes: 59 additions & 34 deletions paddle/cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,16 +232,24 @@ void Compiler::Build(const Module& module,
const bool end) {
auto PatternMatch = adt::match{
[&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; },
[&](common::X86Arch) { CompileX86Module(module, end); },
[&](common::X86Arch) { CompileX86Module(module); },
[&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
[&](common::NVGPUArch) { CompileCudaModule(module, code, end); },
[&](common::HygonDCUArchHIP) { CompileHipModule(module, code, end); }};
return std::visit(PatternMatch, target_.arch.variant());
[&](common::NVGPUArch) { CompileCudaModule(module, code); },
[&](common::HygonDCUArchHIP) { CompileHipModule(module, code); }};
std::visit(PatternMatch, target_.arch.variant());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

顺手把这里的代码改成target_.arch.Match(...)的形式。

if (end) {
RegisterDeviceModuleSymbol();
engine_->AddSelfModule();
}
Comment on lines +240 to +243
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end是什么含义?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里沿袭了之前扩展x86时候的参数,语义是已经链接完所有的子Module,可以整合起来。
之后PR中会考虑商量修改一下这里的参数命名

}

void Compiler::AppendCX86(const Module& module) {
void Compiler::AppendCX86(const Module& module, const bool end) {
VLOG(3) << "Start Compiler::BuildCX86" << module;
CompileX86Module(module, true);
CompileX86Module(module);
if (end) {
RegisterDeviceModuleSymbol();
engine_->AddSelfModule();
}
VLOG(3) << "Over Compiler::BuildCX86";
}

Expand Down Expand Up @@ -296,9 +304,46 @@ std::string GetFileContent(const std::string& path) {
}
} // namespace

// Registers device-side (kernel) symbols with the execution engine once all
// sub-modules have been linked. Pure-CPU architectures have no device module
// and are a no-op. Uses target_.arch.Match(...) as requested in review,
// instead of the std::visit(adt::match{...}, ...) form.
void Compiler::RegisterDeviceModuleSymbol() {
  return target_.arch.Match(
      [&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; },
      [&](common::X86Arch) { return; },
      [&](common::ARMArch) { return; },
      [&](common::NVGPUArch) { RegisterCudaModuleSymbol(); },
      [&](common::HygonDCUArchHIP) { CINN_NOT_IMPLEMENTED; });
}

// Compiles the accumulated device source (device_fn_code_, all kernels
// appended across sub-modules) into one PTX/CUBIN module via NVRTC, then
// registers every kernel's CUfunction pointer with the JIT engine under the
// "<kernel_name>_ptr_" symbol so host stubs can resolve it at link time.
void Compiler::RegisterCudaModuleSymbol() {
#ifdef CINN_WITH_CUDA
  nvrtc::Compiler compiler;
  std::string source_code =
      CodeGenCUDA_Dev::GetSourceHeader() + device_fn_code_;
  auto ptx = compiler(source_code);
  CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n"
                      << source_code;
  using runtime::cuda::CUDAModule;
  cuda_module_.reset(new CUDAModule(ptx,
                                    compiler.compile_to_cubin()
                                        ? CUDAModule::Kind::CUBIN
                                        : CUDAModule::Kind::PTX));

  RuntimeSymbols symbols;
  for (const auto& kernel_fn_name : device_fn_name_) {
    auto fn_kernel = cuda_module_->GetFunction(kernel_fn_name);
    // Fix: stream the actual kernel name into the failure message; the
    // original quoted the variable name as a string literal.
    CHECK(fn_kernel) << "Fail to get CUfunction " << kernel_fn_name;
    fn_ptr_.push_back(reinterpret_cast<void*>(fn_kernel));
    symbols.RegisterVar(kernel_fn_name + "_ptr_",
                        reinterpret_cast<void*>(fn_kernel));
  }
  engine_->RegisterModuleRuntimeSymbols(std::move(symbols));
#else
  CINN_NOT_IMPLEMENTED
#endif
}

void Compiler::CompileCudaModule(const Module& module,
const std::string& code,
bool add_module) {
const std::string& code) {
#ifdef CINN_WITH_CUDA
auto _host_module_device_module_ =
SplitDeviceAndHostModule(module); // NOLINT
Expand All @@ -324,46 +369,26 @@ void Compiler::CompileCudaModule(const Module& module,
<< device_module;
VLOG(3) << "[CUDA] C:\n" << source_code;
SourceCodePrint::GetInstance()->write(source_code);
using runtime::cuda::CUDAModule;
device_fn_code_ += source_code;

nvrtc::Compiler compiler;
auto ptx = compiler(source_code);
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n"
<< source_code;
cuda_module_.reset(new CUDAModule(ptx,
compiler.compile_to_cubin()
? CUDAModule::Kind::CUBIN
: CUDAModule::Kind::PTX));

RuntimeSymbols symbols;
for (auto& fn : device_module.functions()) {
std::string kernel_fn_name = fn->name;
auto fn_kernel = cuda_module_->GetFunction(kernel_fn_name);
CHECK(fn_kernel);

fn_ptr_.push_back(reinterpret_cast<void*>(fn_kernel));

symbols.RegisterVar(kernel_fn_name + "_ptr_",
reinterpret_cast<void*>(fn_kernel));
device_fn_name_.emplace_back(kernel_fn_name);
}

engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));
engine_->Link<CodeGenCUDA_Host>(host_module, add_module);
engine_->Link<CodeGenCUDA_Host>(host_module);

#else
CINN_NOT_IMPLEMENTED
#endif
}

void Compiler::CompileHipModule(const Module& module,
const std::string& code,
bool add_module) {
void Compiler::CompileHipModule(const Module& module, const std::string& code) {
PADDLE_THROW(
phi::errors::Unimplemented("CINN todo: new hardware HygonDCUArchHIP"));
}

// Lowers the module through the X86 codegen and links it into the engine.
// Finalization (AddSelfModule / device symbol registration) now happens in
// Build/AppendCX86 when end=true, so no add_module flag is needed here.
void Compiler::CompileX86Module(const Module& module) {
  engine_->Link<CodeGenX86>(module);
}

void Compiler::ExportObject(const std::string& path) {
Expand Down
19 changes: 12 additions & 7 deletions paddle/cinn/backends/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class Compiler final {
void Build(const ir::Module& module,
const std::string& code = "",
const bool end = true);
void AppendCX86(const ir::Module& module);
void AppendCX86(const ir::Module& module, const bool end = true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

最好不要使用bool值做参数,语义太弱,扩展性太差。
我看这个函数也不大,直接写成两个函数多好啊。就算调用处的就代码已经拿bool值做参数了,也不是新代码继续这么做的理由


void ExportObject(const std::string& path);

Expand All @@ -123,15 +123,17 @@ class Compiler final {
std::vector<void*> GetFnPtr() const { return fn_ptr_; }

private:
// Do not register device-module symbols until end=true is passed to Build.
void RegisterDeviceModuleSymbol();

void RegisterCudaModuleSymbol();

void CompileCudaModule(const ir::Module& module,
const std::string& code = "",
bool add_module = true);
const std::string& code = "");

void CompileHipModule(const ir::Module& module,
const std::string& code = "",
bool add_module = true);
void CompileHipModule(const ir::Module& module, const std::string& code = "");

void CompileX86Module(const ir::Module& module, bool add_module = true);
void CompileX86Module(const ir::Module& module);

explicit Compiler(const Target& target)
: target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {}
Expand All @@ -143,6 +145,9 @@ class Compiler final {
std::unique_ptr<ExecutionEngine> engine_;

std::vector<void*> fn_ptr_;
// Only heterogeneous targets need to record device function names and source.
std::vector<std::string> device_fn_name_;
std::string device_fn_code_;
#ifdef CINN_WITH_CUDA
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
#endif
Expand Down
51 changes: 23 additions & 28 deletions paddle/cinn/backends/llvm/execution_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,6 @@ std::unique_ptr<llvm::MemoryBuffer> NaiveObjectCache::getObject(

/*static*/ std::unique_ptr<ExecutionEngine> ExecutionEngine::Create(
const ExecutionOptions &config) {
return Create(config, {});
}

/*static*/ std::unique_ptr<ExecutionEngine> ExecutionEngine::Create(
const ExecutionOptions &config, RuntimeSymbols &&module_symbols) {
VLOG(1) << "===================== Create CINN ExecutionEngine begin "
"====================";
VLOG(1) << "initialize llvm config";
Expand All @@ -123,8 +118,7 @@ std::unique_ptr<llvm::MemoryBuffer> NaiveObjectCache::getObject(
static std::once_flag flag;
std::call_once(flag, InitializeLLVMPasses);

auto engine = std::make_unique<ExecutionEngine>(/*enable_object_cache=*/true,
std::move(module_symbols));
auto engine = std::make_unique<ExecutionEngine>(/*enable_object_cache=*/true);

auto compile_layer_creator =
[&engine](llvm::orc::JITTargetMachineBuilder jtmb)
Expand Down Expand Up @@ -160,9 +154,9 @@ std::unique_ptr<llvm::MemoryBuffer> NaiveObjectCache::getObject(
llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
engine->jit_->getDataLayout().getGlobalPrefix())));

VLOG(2) << "register runtime call symbols";
VLOG(2) << "register global runtime call symbols";

engine->RegisterRuntimeSymbols();
engine->RegisterGlobalRuntimeSymbols();

VLOG(2) << "===================== Create CINN ExecutionEngine end "
"====================";
Expand All @@ -176,7 +170,7 @@ std::unique_ptr<llvm::MemoryBuffer> NaiveObjectCache::getObject(
}

template <typename CodeGenT>
void ExecutionEngine::Link(const ir::Module &module, bool add_module) {
void ExecutionEngine::Link(const ir::Module &module) {
utils::RecordEvent("ExecutionEngine Link", utils::EventType::kOrdinary);

auto ir_emitter = std::make_unique<CodeGenT>(m.get(), b.get());
Expand All @@ -202,10 +196,6 @@ void ExecutionEngine::Link(const ir::Module &module, bool add_module) {
pass_manager, rawstream, nullptr, llvm::CGFT_ObjectFile);
pass_manager.run(*m);

if (add_module) {
AddSelfModule();
}

if (VLOG_IS_ON(5)) {
VLOG(5) << "======= dump jit execution session ======";
std::string buffer;
Expand Down Expand Up @@ -235,6 +225,20 @@ bool ExecutionEngine::AddModule(std::unique_ptr<llvm::Module> module,
llvm::cantFail(jit_->addIRModule(std::move(tsm)));
return true;
}

// Takes ownership of per-module runtime symbols and defines each one as an
// absolute, exported symbol in the JIT's execution session so host code can
// resolve it during linking.
// @param module_symbols  Symbol table moved into this engine (caller's object
//                        is left moved-from).
void ExecutionEngine::RegisterModuleRuntimeSymbols(
    RuntimeSymbols &&module_symbols) {
  // `module_symbols` is a named rvalue reference, so std::move is the correct
  // transfer (std::forward is for deduced/forwarding references).
  module_symbols_ = std::move(module_symbols);
  auto *session = &jit_->getExecutionSession();
  for (const auto &sym : module_symbols_.All()) {
    // Demoted from VLOG(0): per-symbol registration is debug detail; the rest
    // of this engine logs at VLOG(2)/VLOG(3).
    VLOG(3) << "Add symbol: {" << sym.first << ":" << sym.second << "}";
    llvm::cantFail(jit_->define(llvm::orc::absoluteSymbols(
        {{session->intern(sym.first),
          {llvm::pointerToJITTargetAddress(sym.second),
           llvm::JITSymbolFlags::Exported}}})));
  }
}

// Hands this engine's own in-progress LLVM module (m) and its context (ctx)
// over to the JIT. Both members are moved-from afterwards and must not be
// reused. Returns whatever AddModule reports.
bool ExecutionEngine::AddSelfModule() {
return AddModule(std::move(m), std::move(ctx));
}
Expand All @@ -256,8 +260,8 @@ void *ExecutionEngine::Lookup(absl::string_view name) {
return nullptr;
}

void ExecutionEngine::RegisterRuntimeSymbols() {
utils::RecordEvent("ExecutionEngine RegisterRuntimeSymbols",
void ExecutionEngine::RegisterGlobalRuntimeSymbols() {
utils::RecordEvent("ExecutionEngine RegisterGlobalRuntimeSymbols",
utils::EventType::kOrdinary);
const auto &registry = GlobalSymbolRegistry::Global();
auto *session = &jit_->getExecutionSession();
Expand All @@ -267,19 +271,10 @@ void ExecutionEngine::RegisterRuntimeSymbols() {
{llvm::pointerToJITTargetAddress(sym.second),
llvm::JITSymbolFlags::None}}})));
}
for (const auto &sym : module_symbols_.All()) {
llvm::cantFail(jit_->define(llvm::orc::absoluteSymbols(
{{session->intern(sym.first),
{llvm::pointerToJITTargetAddress(sym.second),
llvm::JITSymbolFlags::None}}})));
}
}

template void ExecutionEngine::Link<CodeGenLLVM>(const ir::Module &module,
bool add_module);
template void ExecutionEngine::Link<CodeGenX86>(const ir::Module &module,
bool add_module);
template void ExecutionEngine::Link<CodeGenCUDA_Host>(const ir::Module &module,
bool add_module);
template void ExecutionEngine::Link<CodeGenLLVM>(const ir::Module &module);
template void ExecutionEngine::Link<CodeGenX86>(const ir::Module &module);
template void ExecutionEngine::Link<CodeGenCUDA_Host>(const ir::Module &module);

} // namespace cinn::backends
15 changes: 6 additions & 9 deletions paddle/cinn/backends/llvm/execution_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,36 +73,33 @@ class ExecutionEngine {
static std::unique_ptr<ExecutionEngine> Create(
const ExecutionOptions &config);

static std::unique_ptr<ExecutionEngine> Create(
const ExecutionOptions &config, RuntimeSymbols &&module_symbols);

void *Lookup(absl::string_view name);

template <typename CodeGenT = CodeGenLLVM>
void Link(const ir::Module &module, bool add_module = true);
void Link(const ir::Module &module);

void ExportObject(const std::string &path);

bool AddModule(std::unique_ptr<llvm::Module> module,
std::unique_ptr<llvm::LLVMContext> context);

void RegisterModuleRuntimeSymbols(RuntimeSymbols &&module_symbols);

bool AddSelfModule();

protected:
explicit ExecutionEngine(bool enable_object_cache,
RuntimeSymbols &&module_symbols)
explicit ExecutionEngine(bool enable_object_cache)
: cache_(std::make_unique<NaiveObjectCache>()),
module_symbols_(std::move(module_symbols)),
ctx(std::make_unique<llvm::LLVMContext>()),
b(std::make_unique<llvm::IRBuilder<>>(*ctx)) {}

void RegisterRuntimeSymbols();
void RegisterGlobalRuntimeSymbols();

bool SetupTargetTriple(llvm::Module *module);

// This may not be a compatible implementation.
friend std::unique_ptr<ExecutionEngine> std::make_unique<ExecutionEngine>(
bool &&, cinn::backends::RuntimeSymbols &&);
bool &&);

private:
mutable std::mutex mu_;
Expand Down
8 changes: 8 additions & 0 deletions paddle/cinn/backends/llvm/runtime_symbol_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ class RuntimeSymbols {
scalar_holder_ = std::move(rhs.scalar_holder_);
}

/// Move-assignment: steals the symbol table and scalar storage from `other`,
/// guarding against self-move. `other` is left in a valid moved-from state.
RuntimeSymbols &operator=(RuntimeSymbols &&other) noexcept {
  if (this == &other) {
    return *this;
  }
  symbols_ = std::move(other.symbols_);
  scalar_holder_ = std::move(other.scalar_holder_);
  return *this;
}

/**
* Register function address.
* @param name Name of the symbol.
Expand Down
5 changes: 1 addition & 4 deletions paddle/cinn/pybind/backends.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ void BindExecutionEngine(py::module *m) {
&ExecutionEngine::Create)),
py::arg("options") = ExecutionOptions())
.def("lookup", lookup)
.def("link",
&ExecutionEngine::Link,
py::arg("module"),
py::arg("add_module") = true);
.def("link", &ExecutionEngine::Link, py::arg("module"));

{
auto lookup = [](Compiler &self, absl::string_view name) {
Expand Down