Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ struct GroupInfoAttributeStorage : public pir::AttributeStorage {
ParamKey data_;
};

struct JITInfoAttributeStorage : public pir::AttributeStorage {
using ParamKey = cinn::hlir::framework::pir::CUDAJITInfo;
explicit JITInfoAttributeStorage(const ParamKey& key) : data_(key) {}
struct CINNKernelInfoAttributeStorage : public pir::AttributeStorage {
using ParamKey = cinn::hlir::framework::pir::CINNKernelInfo;
explicit CINNKernelInfoAttributeStorage(const ParamKey& key) : data_(key) {}

static JITInfoAttributeStorage* Construct(const ParamKey& key) {
return new JITInfoAttributeStorage(key);
static CINNKernelInfoAttributeStorage* Construct(const ParamKey& key) {
return new CINNKernelInfoAttributeStorage(key);
}

static std::size_t HashValue(const ParamKey& key) {
Expand Down
6 changes: 3 additions & 3 deletions paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ const GroupInfo &GroupInfoAttribute::data() const {
return storage()->GetAsKey();
}

const cinn::hlir::framework::pir::CUDAJITInfo &CUDAJITInfoAttribute::data()
const {
const cinn::hlir::framework::pir::CINNKernelInfo &
CINNKernelInfoAttribute::data() const {
return storage()->GetAsKey();
}
} // namespace dialect
} // namespace cinn

IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::GroupInfoAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::CUDAJITInfoAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::CINNKernelInfoAttribute)
12 changes: 6 additions & 6 deletions paddle/cinn/hlir/dialect/operator/ir/op_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,22 @@ class GroupInfoAttribute : public pir::Attribute {
const GroupInfo& data() const;
};

class CUDAJITInfoAttribute : public pir::Attribute {
class CINNKernelInfoAttribute : public pir::Attribute {
public:
using Attribute::Attribute;

DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(CUDAJITInfoAttribute,
JITInfoAttributeStorage);
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(CINNKernelInfoAttribute,
CINNKernelInfoAttributeStorage);

bool operator<(const CUDAJITInfoAttribute& right) const {
bool operator<(const CINNKernelInfoAttribute& right) const {
return storage() < right.storage();
}

const cinn::hlir::framework::pir::CUDAJITInfo& data() const;
const cinn::hlir::framework::pir::CINNKernelInfo& data() const;
};

} // namespace dialect
} // namespace cinn

IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupInfoAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::CUDAJITInfoAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::CINNKernelInfoAttribute)
10 changes: 5 additions & 5 deletions paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void OperatorDialect::initialize() {
RegisterOp<ConcatOp>();
RegisterOp<SplitOp>();
RegisterAttribute<GroupInfoAttribute>();
RegisterAttribute<CUDAJITInfoAttribute>();
RegisterAttribute<CINNKernelInfoAttribute>();
}

void OperatorDialect::PrintType(pir::Type type, std::ostream &os) const {}
Expand All @@ -57,14 +57,14 @@ void OperatorDialect::PrintAttribute(pir::Attribute attr,
<< "[" << data.fn_name << "]";
}
{ os << "<#AttrNotImplemented>"; }
} else if (attr.isa<CUDAJITInfoAttribute>()) {
auto cuda_jit_info = attr.dyn_cast<CUDAJITInfoAttribute>();
} else if (attr.isa<CINNKernelInfoAttribute>()) {
auto cinn_kernel_info = attr.dyn_cast<CINNKernelInfoAttribute>();

os << "(" << cuda_jit_info.data().fn_ptr;
os << "(" << cinn_kernel_info.data().fn_ptr;
os << ')';
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"cinn dialect only support GrupInfo and CUDAJITInfo"));
"cinn dialect only support GroupInfo and CINNKernelInfo"));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class GroupOpPattern : public pir::OpRewritePattern<cinn::dialect::GroupOp> {
auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group});
std::unordered_map<std::string, ::pir::Attribute> op_attrs{
{cinn::dialect::JitKernelOp::kAttrName,
cinn::dialect::CUDAJITInfoAttribute::get(ctx, fn_ptr_res[0])},
cinn::dialect::CINNKernelInfoAttribute::get(ctx, fn_ptr_res[0])},
};

// Generate jit kernel op input and output
Expand Down
12 changes: 6 additions & 6 deletions paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ void JitKernelOp::VerifySig() {

auto& attributes = this->attributes();

IR_ENFORCE(
attributes.count(kAttrName) > 0 &&
attributes.at(kAttrName).isa<cinn::dialect::CUDAJITInfoAttribute>(),
"Type of attribute: instruction is not right.");
IR_ENFORCE(attributes.count(kAttrName) > 0 &&
attributes.at(kAttrName)
.isa<cinn::dialect::CINNKernelInfoAttribute>(),
"Type of attribute: instruction is not right.");
}

const hlir::framework::pir::CUDAJITInfo& JitKernelOp::cuda_jit_info() {
const hlir::framework::pir::CINNKernelInfo& JitKernelOp::cinn_kernel_info() {
return attributes()
.at(kAttrName)
.dyn_cast<cinn::dialect::CUDAJITInfoAttribute>()
.dyn_cast<cinn::dialect::CINNKernelInfoAttribute>()
.data();
}

Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class JitKernelOp : public ::pir::Op<JitKernelOp> {
static const char* name() { return "cinn_runtime.jit_kernel"; }
// TODO(Aurelius84): Think carefully about what this should contain
static constexpr uint32_t attributes_num = 1;
static constexpr char* kAttrName = "jit_info";
static constexpr char* kAttrName = "kernel_info";
static const char* attributes_name[attributes_num];

static void Build(::pir::Builder& builder, // NOLINT
Expand All @@ -36,7 +36,7 @@ class JitKernelOp : public ::pir::Op<JitKernelOp> {
const ::pir::AttributeMap& attributes,
const std::vector<::pir::Type>& out_types);

const hlir::framework::pir::CUDAJITInfo& cuda_jit_info();
const hlir::framework::pir::CINNKernelInfo& cinn_kernel_info();

void VerifySig();
};
Expand Down
11 changes: 11 additions & 0 deletions paddle/cinn/hlir/framework/pir/compilation_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@ std::unique_ptr<Instruction> CompilationTask::BuildInstruction() {
return instr;
}

// Builds the CINNKernelInfo for the group held by this task's context.
//
// Looks up the compiled kernel in the backend compiler by the group's function
// name, then returns a CINNKernelInfo carrying that function pointer together
// with a copy of the group's int_args_map.
//
// CHECK-fails if the lookup returns a null pointer.
pir::CINNKernelInfo CompilationTask::BuildPirCINNKernelInfo() {
  std::string fn_name = context_->group_->FuncName();
  VLOG(4) << "Lookup kernel name: " << fn_name;
  auto* fn_ptr = context_->backend_compiler_->Lookup(fn_name);
  // Name the kernel in the failure message: the VLOG above is only emitted at
  // verbosity >= 4, so a bare CHECK would give no hint which lookup failed.
  CHECK(fn_ptr) << "Failed to look up kernel function: " << fn_name;
  pir::CINNKernelInfo cinn_kernel_info;
  cinn_kernel_info.fn_ptr = fn_ptr;
  cinn_kernel_info.int_args_map = context_->group_->int_args_map;
  return cinn_kernel_info;
}

} // namespace framework
} // namespace hlir
} // namespace cinn
1 change: 1 addition & 0 deletions paddle/cinn/hlir/framework/pir/compilation_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class CompilationTask {
void Lowering();
void CodegenAndJit();
std::unique_ptr<Instruction> BuildInstruction();
pir::CINNKernelInfo BuildPirCINNKernelInfo();

private:
GroupCompilationContext* context_;
Expand Down
1 change: 1 addition & 0 deletions paddle/cinn/hlir/framework/pir/group.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ struct Group {
std::vector<std::string> output_names;
std::vector<::pir::Value> output_values;
std::string fn_name{""};
std::map<int, CINNKernelInfo::ArgDimIdx> int_args_map;

struct SharedGroupHasher {
size_t operator()(const std::shared_ptr<Group>& group) const noexcept {
Expand Down
44 changes: 34 additions & 10 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/hlir/op/external_api_registry.h"
#include "paddle/cinn/hlir/pe/map_expr_to_ir.h"
#include "paddle/cinn/ir/dim.h"
#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h"
#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
Expand Down Expand Up @@ -215,7 +216,7 @@ OpLowererImpl::BucketLower(const GroupPtr& group,
group_scheduler->Schedule();
cond2func_bodies = group_scheduler->GetIRs();
} else {
cond2func_bodies.emplace_back(ir::Expr(1),
cond2func_bodies.emplace_back(ir::Expr(true),
ir_sch.GetModule().GetExprs()[0]);
}

Expand Down Expand Up @@ -488,17 +489,40 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
for (auto arg : group_func_args) {
args_set.insert(arg.name());
}
for (auto& op : group->ops) {
// collect all output tensor.
for (auto opresult : op->results()) {
if (tensor_map.count(opresult) == 0) {
continue;
}
auto tensor = tensor_map.at(opresult);
if (args_set.count("_" + tensor->name) != 0) {
continue;
}
group->output_values.push_back(opresult);
group_func_arg_tensors->push_back(tensor);
group->output_names.push_back(tensor->name);
group_func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
}
}
}

for (auto& tensor_pair : tensor_map) {
if (args_set.count("_" + tensor_pair.second->name)) {
continue;
std::map<int, CINNKernelInfo::ArgDimIdx> mps;
// update args for dynamic dim
int num_tensor_args = static_cast<int>(group_func_args.size());
int non_tensor_arg_idx = group_func_args.size();
for (int tensor_arg_idx = 0; tensor_arg_idx < num_tensor_args;
tensor_arg_idx++) {
auto tensor_dim = (*group_func_arg_tensors)[tensor_arg_idx]->sym_shape;
int tensor_dim_size = tensor_dim.size();
for (int tensor_arg_dim_idx = 0; tensor_arg_dim_idx < tensor_dim_size;
tensor_arg_dim_idx++) {
if (tensor_dim[tensor_arg_dim_idx]->IsDynamic()) {
group_func_args.emplace_back(ir::_Var_::Make(
tensor_dim[tensor_arg_dim_idx]->GetSymbolName(), common::Int(32)));
group->int_args_map[non_tensor_arg_idx++] = {tensor_arg_idx,
tensor_arg_dim_idx};
}
group_func_arg_tensors->push_back(tensor_pair.second);
// use the underlying tensor name to be consistent with the argument name
// in the lowered function
group->output_names.push_back(tensor_pair.second->name);
group_func_args.emplace_back(tensor_pair.second->buffer,
ir::Argument::IO::kOutput);
}
}

Expand Down
23 changes: 19 additions & 4 deletions paddle/cinn/hlir/framework/pir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,26 @@ namespace framework {

namespace pir {

struct CUDAJITInfo {
struct CINNKernelInfo {
void* fn_ptr;
std::vector<int> block_dims;
std::vector<int> grid_dims;
void* compiler;

struct ArgDimIdx {
int arg_idx;
int dim_idx;
};
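// int_args_map describes where each integer (dtype Int) kernel argument gets
// its value: the key is the position of the int argument in the kernel
// parameter list, and the value names the argument (ArgDimIdx.arg_idx) and
// the dimension of that argument's shape (ArgDimIdx.dim_idx) it is taken from.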
// Examples:
// a func like: foo(tensor A, tensor B, int S1, int S2)
// S1 = A.shape[3]
// S2 = B.shape[2]
// int_args_map will be like
// {
// 2: {0, 3},
// 3: {1, 2}
// }
std::map<int, ArgDimIdx> int_args_map;
};

struct CompatibleInfo {
Expand Down
65 changes: 35 additions & 30 deletions paddle/cinn/hlir/framework/pir_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,44 +42,49 @@ std::unique_ptr<Program> PirCompiler::Build() {
return std::move(Build(groups));
}

std::vector<pir::CUDAJITInfo> PirCompiler::BuildCUDAJITInfo(
std::vector<pir::CINNKernelInfo> PirCompiler::BuildCUDAJITInfo(
const std::vector<pir::GroupPtr>& groups) {
std::vector<pir::CUDAJITInfo> vec_res;
std::vector<pir::CINNKernelInfo> cinn_kernel_info_vecs(groups.size());

auto op_lowerer = CreateOpLowerer<pir::GroupPtr>(target_);

std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
for (int i = 0; i < groups.size(); ++i) {
lowered_funcs.emplace_back(op_lowerer.Lower(groups[i]));
}

for (auto&& lowered_func : lowered_funcs) {
ProcessFunction(lowered_func);
}

compiler_ = backends::Compiler::Create(target_);
auto build_module = m_builder_.Build();
compiler_->Build(build_module, "");

auto instructions = BuildInstructions(groups);
if (FLAGS_cinn_bucket_compile) {
for (int i = 0; i < groups.size(); ++i) {
group_compilation_contexts_.emplace_back(target_, groups[i], scope_);
}
auto worker_fn = [&](int index) {
CompilationTask task(&group_compilation_contexts_[index]);
task();
cinn_kernel_info_vecs[index] = task.BuildPirCINNKernelInfo();
};
utils::parallel_run(
worker_fn, utils::SequenceDispatcher(0, groups.size()), -1);
} else {
auto op_lowerer = CreateOpLowerer<pir::GroupPtr>(target_);

auto fn_ptrs = compiler_->GetFnPtr();
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
for (int i = 0; i < groups.size(); ++i) {
lowered_funcs.emplace_back(op_lowerer.Lower(groups[i]));
}

auto* compilter_ptr = compiler_.release();
for (int idx = 0; idx < groups.size(); ++idx) {
pir::CUDAJITInfo jit_info;
jit_info.fn_ptr = fn_ptrs[idx];
jit_info.compiler = reinterpret_cast<void*>(compilter_ptr);
for (auto&& lowered_func : lowered_funcs) {
ProcessFunction(lowered_func);
}
compiler_ = backends::Compiler::Create(target_);
auto build_module = m_builder_.Build();
compiler_->Build(build_module, "");

lowered_funcs[idx][0]->cuda_axis_info.CopyBlockDimsTo(
&(jit_info.block_dims));
auto fn_ptrs = compiler_->GetFnPtr();

lowered_funcs[idx][0]->cuda_axis_info.CopyGridDimsTo(&(jit_info.grid_dims));
for (int idx = 0; idx < groups.size(); ++idx) {
pir::CINNKernelInfo cinn_kernel_info;
auto fn_name = groups[idx]->FuncName();
auto fn_ptr = compiler_->Lookup(fn_name);
cinn_kernel_info.fn_ptr = fn_ptr;
cinn_kernel_info.int_args_map = groups[idx]->int_args_map;

vec_res.push_back(jit_info);
cinn_kernel_info_vecs[idx] = cinn_kernel_info;
}
}

return vec_res;
return cinn_kernel_info_vecs;
}

std::unique_ptr<Program> PirCompiler::Build(
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/pir_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class PirCompiler final {

std::unique_ptr<Program> Build();

std::vector<pir::CUDAJITInfo> BuildCUDAJITInfo(
std::vector<pir::CINNKernelInfo> BuildCUDAJITInfo(
const std::vector<pir::GroupPtr>& groups);

std::unique_ptr<Program> Build(const std::vector<pir::GroupPtr>& groups);
Expand Down
Loading