[NewIR]Gen python c apis for new ir #56571

Merged
merged 7 commits on Aug 25, 2023
Changes from 1 commit
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into python-c-gen
0x45f committed Aug 24, 2023
commit 00e6e8cad60733ea11c69b6f31afcfa282e64aa3
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
@@ -24,7 +24,7 @@ set(XPU_XFT_LIB_NAME "libxft.so")
 set(XPU_XPTI_LIB_NAME "libxpti.so")
 
 if(NOT DEFINED XPU_BASE_DATE)
-  set(XPU_BASE_DATE "20230819")
+  set(XPU_BASE_DATE "20230823")
 endif()
 set(XPU_XCCL_BASE_VERSION "1.0.53.6")
 if(NOT DEFINED XPU_XFT_BASE_VERSION)
26 changes: 25 additions & 1 deletion cmake/generic.cmake
@@ -118,6 +118,19 @@ function(find_fluid_modules TARGET_NAME)
   endif()
 endfunction()
 
+# NOTE(Aurelius84): NOT_INFER_MODULES is used to tag modules that should
+# not be considered as DEPS for the inference libs.
+set_property(GLOBAL PROPERTY NOT_INFER_MODULES "")
+
+function(ignore_infer_modules TARGET_NAME)
+  get_property(not_infer_modules GLOBAL PROPERTY NOT_INFER_MODULES)
+  list(FIND not_infer_modules ${TARGET_NAME} is_found)
+  if(is_found EQUAL -1) # NOT FOUND
+    set(not_infer_modules ${not_infer_modules} ${TARGET_NAME})
+    set_property(GLOBAL PROPERTY NOT_INFER_MODULES "${not_infer_modules}")
+  endif()
+endfunction()
+
 set_property(GLOBAL PROPERTY PHI_MODULES "")
 # finding all phi modules is used for the paddle static library
 # when building inference libs
@@ -335,7 +348,15 @@ function(check_coverage_opt TARGET_NAME SRCS)
 endfunction()
 
 function(cc_library TARGET_NAME)
-  set(options STATIC static SHARED shared INTERFACE interface)
+  set(options
+      STATIC
+      static
+      SHARED
+      shared
+      INTERFACE
+      interface
+      NOT_FOR_INFER
+      not_for_infer)
   set(oneValueArgs "")
   set(multiValueArgs SRCS DEPS)
   cmake_parse_arguments(cc_library "${options}" "${oneValueArgs}"
@@ -347,6 +368,9 @@ function(cc_library TARGET_NAME)
         CACHE STRING "output library name for target ${TARGET_NAME}")
   endif()
   if(cc_library_SRCS)
+    if(cc_library_NOT_FOR_INFER OR cc_library_not_for_infer)
+      ignore_infer_modules(${TARGET_NAME})
+    endif()
     if(cc_library_SHARED OR cc_library_shared) # build *.so
       add_library(${TARGET_NAME} SHARED ${cc_library_SRCS})
     elseif(cc_library_INTERFACE OR cc_library_interface)
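For context, a minimal sketch of how the new option is meant to be used at a cc_library call site; the target and source names below are hypothetical, while the one real call site added by this PR is ssa_graph_executor in paddle/fluid/framework/details/CMakeLists.txt further down.

    # Hypothetical target that opts out of inference packaging.
    # NOT_FOR_INFER routes the target into the NOT_INFER_MODULES global
    # property via ignore_infer_modules(), so it is not treated as a DEP
    # of the inference libs.
    cc_library(
      my_debug_helper NOT_FOR_INFER
      SRCS my_debug_helper.cc
      DEPS glog)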
42 changes: 42 additions & 0 deletions paddle/cinn/backends/compiler.cc
@@ -41,6 +41,48 @@ using ir::Module;
 
 static constexpr int DebugLogMaxLen = 30000;
 
+void CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
+    const ir::LoweredFunc& lowered_func, const int gidx) {
+  if (FLAGS_cinn_dump_group_lowered_func.empty() ||
+      lowered_func.get() == nullptr) {
+    return;
+  }
+  std::stringstream content;
+  content << lowered_func;
+  Dump(FLAGS_cinn_dump_group_lowered_func,
+       gidx,
+       "lowered_function.txt",
+       content.str());
+}
+
+void CompilationInfoDumper::DumpSourceCodeByGroupIndex(
+    const std::string& source_code, const int gidx) {
+  if (FLAGS_cinn_dump_group_source_code.empty()) {
+    return;
+  }
+  Dump(FLAGS_cinn_dump_group_source_code, gidx, "source_code.cu", source_code);
+}
+
+void CompilationInfoDumper::DumpPtxCodeByGroupIndex(
+    const std::string& source_ptx, const int gidx) {
+  if (FLAGS_cinn_dump_group_ptx.empty()) {
+    return;
+  }
+  Dump(FLAGS_cinn_dump_group_ptx, gidx, "source_ptx.ptx", source_ptx);
+}
+
+void CompilationInfoDumper::DumpInstructionByGroupIndex(
+    const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
+    const int gidx) {
+  if (FLAGS_cinn_dump_group_instruction.empty() || instr.get() == nullptr) {
+    return;
+  }
+  Dump(FLAGS_cinn_dump_group_instruction,
+       gidx,
+       "instruction.txt",
+       instr->DumpInstruction());
+}
+
 void CompilationInfoDumper::DumpLoweredFunc() {
   if (FLAGS_cinn_dump_group_lowered_func.empty()) {
     return;
18 changes: 14 additions & 4 deletions paddle/cinn/backends/compiler.h
@@ -51,15 +51,25 @@ class CompilationInfoDumper {
     DumpInstruction();
   }
 
+  static void DumpLoweredFuncByGroupIndex(const ir::LoweredFunc& lowered_func,
+                                          const int gidx);
+  static void DumpSourceCodeByGroupIndex(const std::string& source_code,
+                                         const int gidx);
+  static void DumpPtxCodeByGroupIndex(const std::string& source_ptx,
+                                      const int gidx);
+  static void DumpInstructionByGroupIndex(
+      const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
+      const int gidx);
+
  private:
   void DumpLoweredFunc();
   void DumpSourceCode();
   void DumpPtxCode();
   void DumpInstruction();
-  void Dump(const std::string& base_path,
-            const int idx,
-            const std::string& file_name,
-            const std::string& content);
+  static void Dump(const std::string& base_path,
+                   const int idx,
+                   const std::string& file_name,
+                   const std::string& content);
 
   const hlir::framework::CompilationResult& info_;
 };
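The dumper now offers two entry points: the original instance path, whose constructor dumps a whole CompilationResult, and the new per-group statics that ParallelCompiler tasks call as soon as a single group's artifact exists. A hedged sketch of the two call patterns (variable names illustrative, mirroring the fields used in parallel_compiler.cc):

    // Whole-result dump: the constructor invokes the private Dump* methods
    // over every group in `result`.
    backends::CompilationInfoDumper dumper(result);

    // Per-group dump from inside a task, right after group `gidx` is
    // lowered; a no-op unless FLAGS_cinn_dump_group_lowered_func is set.
    backends::CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
        result.lowered_funcs[gidx].front(), gidx);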
3 changes: 0 additions & 3 deletions paddle/cinn/hlir/framework/graph_compiler.cc
@@ -64,9 +64,6 @@ CompilationResult GraphCompiler::Build(CompilationContext* context) {
   parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
   CompilationResult result = (*parallel_compiler_.get())();
 
-  // Dump compilation result
-  backends::CompilationInfoDumper dumper(result);
-
   if (context->stage != CompilationStage::DEFAULT) {
     return result;
   }
27 changes: 15 additions & 12 deletions paddle/cinn/hlir/framework/parallel_compiler.cc
@@ -58,7 +58,7 @@ void ParallelCompiler::SplitTask() {
         context_->graph->fusion_groups.size() ==
             context_->lowered_funcs.size());
   for (int i = 0; i < context_->graph->fusion_groups.size(); ++i) {
-    tasks_.emplace_back(this, context_, i);
+    tasks_.emplace_back(i, this, context_);
   }
 }
 
@@ -114,20 +114,17 @@ void ParallelCompiler::Task::Lowering() {
   if (!context->lowered_funcs.empty()) {
     CHECK_EQ(context->lowered_funcs.size(),
              context->graph->fusion_groups.size());
-  }
-  auto& dtype_dict =
-      context->graph->GetMutableAttrs<absl::flat_hash_map<std::string, Type>>(
-          "inferdtype");
-  auto& shape_dict =
-      context->graph
-          ->GetMutableAttrs<absl::flat_hash_map<std::string, shape_t>>(
-              "infershape");
-
-  OpLowerer op_lowerer(dtype_dict, shape_dict, context->target);
-  if (!context->lowered_funcs.empty()) {
     pcompiler->result_.lowered_funcs[group_id] =
         context->lowered_funcs[group_id];
   } else {
+    auto& dtype_dict =
+        context->graph->GetMutableAttrs<absl::flat_hash_map<std::string, Type>>(
+            "inferdtype");
+    auto& shape_dict =
+        context->graph
+            ->GetMutableAttrs<absl::flat_hash_map<std::string, shape_t>>(
+                "infershape");
+    OpLowerer op_lowerer(dtype_dict, shape_dict, context->target);
     auto& group = context->graph->fusion_groups[group_id];
     VLOG(4) << "Start Lowering Group " << group_id << " at "
             << std::this_thread::get_id() << " :\n"
@@ -138,6 +135,8 @@ void ParallelCompiler::Task::Lowering() {
     CHECK_EQ(lowered_group.size(), 1) << "Lowered Function Is Not Equal 1!";
     pcompiler->result_.lowered_funcs[group_id] = std::move(lowered_group);
   }
+  backends::CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
+      pcompiler->result_.lowered_funcs[group_id].front(), group_id);
 }
 
 void ParallelCompiler::Task::CodegenAndJit() {
@@ -168,6 +167,8 @@ void ParallelCompiler::Task::CodegenAndJit() {
   }
   CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n"
                          << dmodule;
+  backends::CompilationInfoDumper::DumpSourceCodeByGroupIndex(cuda_c,
+                                                              group_id);
   pcompiler->result_.source_codes[group_id] = cuda_c;
 
   cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c);
@@ -176,6 +177,7 @@ void ParallelCompiler::Task::CodegenAndJit() {
   backends::nvrtc::Compiler compiler;
   auto ptx = compiler(cuda_c);
   CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c;
+  backends::CompilationInfoDumper::DumpPtxCodeByGroupIndex(ptx, group_id);
   pcompiler->result_.source_ptxs[group_id] = ptx;
   // load cumodule
   cumodule = std::make_unique<CUDAModule>(ptx,
@@ -217,6 +219,7 @@ void ParallelCompiler::Task::BuildInstruction() {
   instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), group->GetFuncName());
 
   instr->Finalize();
+  backends::CompilationInfoDumper::DumpInstructionByGroupIndex(instr, group_id);
   pcompiler->result_.instructions[group_id] = std::move(instr);
 }
 
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/framework/parallel_compiler.h
@@ -33,8 +33,8 @@ namespace framework {
 class ParallelCompiler {
  public:
   struct Task {
-    Task(ParallelCompiler* compiler, CompilationContext* context, int group_id)
-        : pcompiler(compiler), context(context), group_id(group_id) {}
+    Task(int group_id, ParallelCompiler* compiler, CompilationContext* context)
+        : group_id(group_id), pcompiler(compiler), context(context) {}
     void Lowering();
     void CodegenAndJit();
     void BuildInstruction();
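A plausible reason for swapping both the parameter order and the init-list order (an inference; the PR does not state it): C++ initializes non-static members in declaration order, not init-list order, so listing group_id first keeps the init-list aligned with the declarations and avoids -Wreorder warnings or surprises when one initializer reads another member. A standalone illustration of the rule:

    #include <iostream>

    struct Demo {
      int group_id;  // declared first, therefore initialized first
      int derived;   // initialized second, regardless of init-list order
      // The init-list below matches declaration order; writing it the other
      // way round would trigger -Wreorder, and derived would still be
      // initialized after group_id.
      explicit Demo(int gid) : group_id(gid), derived(group_id * 2) {}
    };

    int main() {
      std::cout << Demo(21).derived << "\n";  // prints 42
      return 0;
    }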
4 changes: 1 addition & 3 deletions paddle/fluid/framework/custom_operator.cc
@@ -946,9 +946,7 @@ static void RegisterOperatorKernel(
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
   auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
   for (const auto& dev_type : device_types) {
-    for (size_t dev_id = 0;
-         dev_id < phi::DeviceManager::GetDeviceCount(dev_type);
-         dev_id++) {
+    for (auto& dev_id : phi::DeviceManager::GetSelectedDeviceList(dev_type)) {
       RegisterOperatorKernelWithPlace(name,
                                       op_kernel_func,
                                       proto::VarType::RAW,
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/CMakeLists.txt
@@ -276,7 +276,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS
     inplace_addto_op_pass
     set_reader_device_info_utils)
 cc_library(
-  ssa_graph_executor
+  ssa_graph_executor NOT_FOR_INFER
   SRCS ssa_graph_executor.cc
   DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
 
6 changes: 3 additions & 3 deletions paddle/fluid/framework/executor_cache.cc
@@ -378,7 +378,7 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
 
     auto op_desc = block->PrependOp();
     op_desc->SetType("data");
-    op_desc->SetAttr("index", 0);
+    op_desc->SetAttr("shape", std::vector<int64_t>());
     // TODO(phlrain) : using tensor dtype
     op_desc->SetAttr("dtype", 0);
     op_desc->SetAttr("place", static_cast<int>(place));
@@ -393,7 +393,7 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
 
     auto op_desc = local_program.MutableBlock(0)->PrependOp();
     op_desc->SetType("data");
-    op_desc->SetAttr("index", 0);
+    op_desc->SetAttr("shape", std::vector<int64_t>());
     // TODO(phlrain) : using tensor dtype
     op_desc->SetAttr("dtype", 0);
     op_desc->SetAttr("place", static_cast<int>(place));
@@ -479,7 +479,7 @@ std::unique_ptr<::ir::Program> ConstructBackwardIrProgram(
   }
   auto op_desc = local_program.MutableBlock(0)->PrependOp();
   op_desc->SetType("data");
-  op_desc->SetAttr("index", 0);
+  op_desc->SetAttr("shape", std::vector<int64_t>());
   // TODO(phlrain) : using tensor dtype
   op_desc->SetAttr("dtype", 0);
   op_desc->SetAttr("place", static_cast<int>(place));
1 change: 1 addition & 0 deletions paddle/fluid/framework/feed_fetch_method.cc
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "glog/logging.h"
 
 PHI_DECLARE_bool(enable_new_ir_in_executor);
+PHI_DECLARE_bool(enable_new_ir_api);
 
 namespace phi {
 class DenseTensor;
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -239,6 +239,8 @@ if(WITH_XPU)
   pass_library(cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
   pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(cast_embedding_trans_ids_to_int32_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
   pass_library(conv1d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(redundant_unsqueeze_squeeze_elimination_pass inference DIR xpu
You are viewing a condensed version of this merge commit.