-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[CINN]support backend compiler to link multiple modules #65916
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -232,16 +232,24 @@ void Compiler::Build(const Module& module, | |
const bool end) { | ||
auto PatternMatch = adt::match{ | ||
[&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; }, | ||
[&](common::X86Arch) { CompileX86Module(module, end); }, | ||
[&](common::X86Arch) { CompileX86Module(module); }, | ||
[&](common::ARMArch) { CINN_NOT_IMPLEMENTED; }, | ||
[&](common::NVGPUArch) { CompileCudaModule(module, code, end); }, | ||
[&](common::HygonDCUArchHIP) { CompileHipModule(module, code, end); }}; | ||
return std::visit(PatternMatch, target_.arch.variant()); | ||
[&](common::NVGPUArch) { CompileCudaModule(module, code); }, | ||
[&](common::HygonDCUArchHIP) { CompileHipModule(module, code); }}; | ||
std::visit(PatternMatch, target_.arch.variant()); | ||
if (end) { | ||
RegisterDeviceModuleSymbol(); | ||
engine_->AddSelfModule(); | ||
} | ||
Comment on lines
+240
to
+243
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. end是什么含义? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里沿袭了之前扩展x86时候的参数,语义是已经链接完所有的子Module,可以整合起来。 |
||
} | ||
|
||
void Compiler::AppendCX86(const Module& module) { | ||
void Compiler::AppendCX86(const Module& module, const bool end) { | ||
VLOG(3) << "Start Compiler::BuildCX86" << module; | ||
CompileX86Module(module, true); | ||
CompileX86Module(module); | ||
if (end) { | ||
RegisterDeviceModuleSymbol(); | ||
engine_->AddSelfModule(); | ||
} | ||
VLOG(3) << "Over Compiler::BuildCX86"; | ||
} | ||
|
||
|
@@ -296,9 +304,46 @@ std::string GetFileContent(const std::string& path) { | |
} | ||
} // namespace | ||
|
||
void Compiler::RegisterDeviceModuleSymbol() { | ||
auto PatternMatch = | ||
adt::match{[&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; }, | ||
[&](common::X86Arch) { return; }, | ||
[&](common::ARMArch) { return; }, | ||
[&](common::NVGPUArch) { RegisterCudaModuleSymbol(); }, | ||
[&](common::HygonDCUArchHIP) { CINN_NOT_IMPLEMENTED; }}; | ||
return std::visit(PatternMatch, target_.arch.variant()); | ||
Comment on lines
+308
to
+314
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 改一下代码,让它好看点: return target_.arch.Match(
[&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; },
[&](common::X86Arch) { return; },
[&](common::ARMArch) { return; },
[&](common::NVGPUArch) { RegisterCudaModuleSymbol(); },
[&](common::HygonDCUArchHIP) { CINN_NOT_IMPLEMENTED; }); There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 收到,将在下一个step的PR修改这里的写法 |
||
} | ||
|
||
void Compiler::RegisterCudaModuleSymbol() { | ||
#ifdef CINN_WITH_CUDA | ||
nvrtc::Compiler compiler; | ||
std::string source_code = | ||
CodeGenCUDA_Dev::GetSourceHeader() + device_fn_code_; | ||
auto ptx = compiler(source_code); | ||
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" | ||
<< source_code; | ||
using runtime::cuda::CUDAModule; | ||
cuda_module_.reset(new CUDAModule(ptx, | ||
compiler.compile_to_cubin() | ||
? CUDAModule::Kind::CUBIN | ||
: CUDAModule::Kind::PTX)); | ||
|
||
RuntimeSymbols symbols; | ||
for (const auto& kernel_fn_name : device_fn_name_) { | ||
auto fn_kernel = cuda_module_->GetFunction(kernel_fn_name); | ||
CHECK(fn_kernel) << "Fail to get CUfunction kernel_fn_name"; | ||
fn_ptr_.push_back(reinterpret_cast<void*>(fn_kernel)); | ||
symbols.RegisterVar(kernel_fn_name + "_ptr_", | ||
reinterpret_cast<void*>(fn_kernel)); | ||
} | ||
engine_->RegisterModuleRuntimeSymbols(std::move(symbols)); | ||
#else | ||
CINN_NOT_IMPLEMENTED | ||
#endif | ||
} | ||
|
||
void Compiler::CompileCudaModule(const Module& module, | ||
const std::string& code, | ||
bool add_module) { | ||
const std::string& code) { | ||
#ifdef CINN_WITH_CUDA | ||
auto _host_module_device_module_ = | ||
SplitDeviceAndHostModule(module); // NOLINT | ||
|
@@ -324,46 +369,26 @@ void Compiler::CompileCudaModule(const Module& module, | |
<< device_module; | ||
VLOG(3) << "[CUDA] C:\n" << source_code; | ||
SourceCodePrint::GetInstance()->write(source_code); | ||
using runtime::cuda::CUDAModule; | ||
device_fn_code_ += source_code; | ||
|
||
nvrtc::Compiler compiler; | ||
auto ptx = compiler(source_code); | ||
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" | ||
<< source_code; | ||
cuda_module_.reset(new CUDAModule(ptx, | ||
compiler.compile_to_cubin() | ||
? CUDAModule::Kind::CUBIN | ||
: CUDAModule::Kind::PTX)); | ||
|
||
RuntimeSymbols symbols; | ||
for (auto& fn : device_module.functions()) { | ||
std::string kernel_fn_name = fn->name; | ||
auto fn_kernel = cuda_module_->GetFunction(kernel_fn_name); | ||
CHECK(fn_kernel); | ||
|
||
fn_ptr_.push_back(reinterpret_cast<void*>(fn_kernel)); | ||
|
||
symbols.RegisterVar(kernel_fn_name + "_ptr_", | ||
reinterpret_cast<void*>(fn_kernel)); | ||
device_fn_name_.emplace_back(kernel_fn_name); | ||
} | ||
|
||
engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols)); | ||
engine_->Link<CodeGenCUDA_Host>(host_module, add_module); | ||
engine_->Link<CodeGenCUDA_Host>(host_module); | ||
|
||
#else | ||
CINN_NOT_IMPLEMENTED | ||
#endif | ||
} | ||
|
||
void Compiler::CompileHipModule(const Module& module, | ||
const std::string& code, | ||
bool add_module) { | ||
void Compiler::CompileHipModule(const Module& module, const std::string& code) { | ||
PADDLE_THROW( | ||
phi::errors::Unimplemented("CINN todo: new hardware HygonDCUArchHIP")); | ||
} | ||
|
||
void Compiler::CompileX86Module(const Module& module, bool add_module) { | ||
engine_->Link<CodeGenX86>(module, add_module); | ||
void Compiler::CompileX86Module(const Module& module) { | ||
engine_->Link<CodeGenX86>(module); | ||
} | ||
|
||
void Compiler::ExportObject(const std::string& path) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -106,7 +106,7 @@ class Compiler final { | |
void Build(const ir::Module& module, | ||
const std::string& code = "", | ||
const bool end = true); | ||
void AppendCX86(const ir::Module& module); | ||
void AppendCX86(const ir::Module& module, const bool end = true); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 最好不要使用bool值做参数,语义太弱,扩展性太差。 |
||
|
||
void ExportObject(const std::string& path); | ||
|
||
|
@@ -123,15 +123,17 @@ class Compiler final { | |
std::vector<void*> GetFnPtr() const { return fn_ptr_; } | ||
|
||
private: | ||
// do not register device symbol until end=true for build function | ||
void RegisterDeviceModuleSymbol(); | ||
|
||
void RegisterCudaModuleSymbol(); | ||
|
||
void CompileCudaModule(const ir::Module& module, | ||
const std::string& code = "", | ||
bool add_module = true); | ||
const std::string& code = ""); | ||
|
||
void CompileHipModule(const ir::Module& module, | ||
const std::string& code = "", | ||
bool add_module = true); | ||
void CompileHipModule(const ir::Module& module, const std::string& code = ""); | ||
|
||
void CompileX86Module(const ir::Module& module, bool add_module = true); | ||
void CompileX86Module(const ir::Module& module); | ||
|
||
explicit Compiler(const Target& target) | ||
: target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {} | ||
|
@@ -143,6 +145,9 @@ class Compiler final { | |
std::unique_ptr<ExecutionEngine> engine_; | ||
|
||
std::vector<void*> fn_ptr_; | ||
// only heterogeneous systems need to record device func and module | ||
std::vector<std::string> device_fn_name_; | ||
std::string device_fn_code_; | ||
#ifdef CINN_WITH_CUDA | ||
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_; | ||
#endif | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
顺手把这里的代码改成target_.arch.Match(...)的形式。