polish jit instruction (PaddlePaddle#58148)
phlrain authored Oct 18, 2023
1 parent 1a31840 commit 83ede11
Showing 7 changed files with 37 additions and 50 deletions.
8 changes: 4 additions & 4 deletions paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h
@@ -77,12 +77,12 @@ struct GroupInfoAttributeStorage : public pir::AttributeStorage {
   ParamKey data_;
 };
 
-struct CUDAJITInfoAttributeStorage : public pir::AttributeStorage {
+struct JITInfoAttributeStorage : public pir::AttributeStorage {
   using ParamKey = cinn::hlir::framework::newir::CUDAJITInfo;
-  explicit CUDAJITInfoAttributeStorage(const ParamKey& key) : data_(key) {}
+  explicit JITInfoAttributeStorage(const ParamKey& key) : data_(key) {}
 
-  static CUDAJITInfoAttributeStorage* Construct(const ParamKey& key) {
-    return new CUDAJITInfoAttributeStorage(key);
+  static JITInfoAttributeStorage* Construct(const ParamKey& key) {
+    return new JITInfoAttributeStorage(key);
   }
 
   static std::size_t HashValue(const ParamKey& key) {
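
For context: pir attributes follow a flyweight design, where the user-facing attribute class is a thin handle and a separate storage class owns the payload, exposing Construct and HashValue hooks so the context can intern it. That split is why only the storage class is renamed here while the CUDAJITInfoAttribute handle (next file) keeps its name. A minimal standalone sketch of the idea, with simplified types that are not Paddle's actual pir API:

#include <cstddef>
#include <functional>
#include <string>

// Simplified illustration of the storage/handle split used above.
// The storage class owns the payload and supplies the construction and
// hashing hooks that an interning context would call.
struct JITInfoStorage {
  using ParamKey = std::string;  // stand-in for newir::CUDAJITInfo
  explicit JITInfoStorage(const ParamKey& key) : data_(key) {}

  static JITInfoStorage* Construct(const ParamKey& key) {
    return new JITInfoStorage(key);
  }
  static std::size_t HashValue(const ParamKey& key) {
    return std::hash<ParamKey>{}(key);
  }

  ParamKey data_;
};

// The attribute is only a handle over interned storage, so renaming the
// storage class leaves the attribute's public name untouched.
class JITInfoAttribute {
 public:
  explicit JITInfoAttribute(const JITInfoStorage* storage)
      : storage_(storage) {}
  const JITInfoStorage::ParamKey& data() const { return storage_->data_; }

 private:
  const JITInfoStorage* storage_;  // not owned; lives in the context
};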
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/op_attribute.h
@@ -38,7 +38,7 @@ class CUDAJITInfoAttribute : public pir::Attribute {
   using Attribute::Attribute;
 
   DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(CUDAJITInfoAttribute,
-                                    CUDAJITInfoAttributeStorage);
+                                    JITInfoAttributeStorage);
 
   bool operator<(const CUDAJITInfoAttribute& right) const {
     return storage() < right.storage();
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h
@@ -41,7 +41,7 @@ class JitKernelOp : public ::pir::Op<JitKernelOp> {
   static const char* name() { return "cinn_runtime.jit_kernel"; }
   // TODO(Aurelius84): Think deeply what should contains
   static constexpr uint32_t attributes_num = 1;
-  static constexpr char* kAttrName = "cuda_jit_info";
+  static constexpr char* kAttrName = "jit_info";
   static const char* attributes_name[attributes_num];
 
   const hlir::framework::newir::CUDAJITInfo& cuda_jit_info();
11 changes: 6 additions & 5 deletions paddle/cinn/hlir/framework/new_ir_compiler.cc
@@ -63,14 +63,15 @@ std::vector<newir::CUDAJITInfo> NewIRCompiler::BuildCUDAJITInfo(
   auto fn_ptrs = compiler_->GetFnPtr();
 
   for (int idx = 0; idx < groups.size(); ++idx) {
-    newir::CUDAJITInfo node;
-    node.fn_ptr = fn_ptrs[idx];
+    newir::CUDAJITInfo jit_info;
+    jit_info.fn_ptr = fn_ptrs[idx];
 
-    lowered_funcs[idx][0]->cuda_axis_info.CopyBlockDimsTo(&(node.block_dims));
+    lowered_funcs[idx][0]->cuda_axis_info.CopyBlockDimsTo(
+        &(jit_info.block_dims));
 
-    lowered_funcs[idx][0]->cuda_axis_info.CopyGridDimsTo(&(node.grid_dims));
+    lowered_funcs[idx][0]->cuda_axis_info.CopyGridDimsTo(&(jit_info.grid_dims));
 
-    vec_res.push_back(node);
+    vec_res.push_back(jit_info);
   }
 
   return vec_res;
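
For reference, newir::CUDAJITInfo acts here as a plain launch descriptor: a JIT-compiled function pointer plus grid and block dimensions. A rough sketch of the fields this loop populates, inferred from usage in this commit rather than taken from the actual definition:

#include <vector>

// Inferred shape of the launch descriptor built by BuildCUDAJITInfo.
// Field types are assumptions based on how the struct is used in this
// diff (fn_ptr is later cast to CUfunction; dims are indexed [0..2]).
struct CUDAJITInfo {
  void* fn_ptr = nullptr;       // JIT-compiled kernel entry point
  std::vector<int> grid_dims;   // filled by CopyGridDimsTo
  std::vector<int> block_dims;  // filled by CopyBlockDimsTo
};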
59 changes: 23 additions & 36 deletions paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc
@@ -31,37 +31,33 @@ class CinnJitInstruction::FnPtrImpl {
  public:
   explicit FnPtrImpl(const CUDAJITInfo& cuda_jit_info)
       : cuda_jit_info_(cuda_jit_info) {}
   // TODO(Aurelus84): Support to specify name2podargs and stream arguments.
   void Run(const std::vector<phi::DenseTensor*>& kernel_args, void* stream) {
-    auto fn = static_cast<CUfunction>(cuda_jit_info_.fn_ptr);
-
-    auto stream1 = static_cast<CUstream>(stream);
-
-    pass_arg.clear();
-    vec_temp_.resize(kernel_args.size());
+    func_args_.clear();
+    ptr_storage_.resize(kernel_args.size());
     for (size_t i = 0; i < kernel_args.size(); ++i) {
-      vec_temp_[i] = kernel_args[i]->data();
-      pass_arg.push_back(vec_temp_.data() + i);
+      ptr_storage_[i] = kernel_args[i]->data();
+      func_args_.push_back(ptr_storage_.data() + i);
     }
 
-    CUDA_DRIVER_CALL(cuLaunchKernel(fn,
-                                    cuda_jit_info_.grid_dims[0],
-                                    cuda_jit_info_.grid_dims[1],
-                                    cuda_jit_info_.grid_dims[2],
-                                    cuda_jit_info_.block_dims[0],
-                                    cuda_jit_info_.block_dims[1],
-                                    cuda_jit_info_.block_dims[2],
-                                    0,  // share memory
-                                    stream1,
-                                    pass_arg.data(),
-                                    nullptr))
+    CUDA_DRIVER_CALL(
+        cuLaunchKernel(static_cast<CUfunction>(cuda_jit_info_.fn_ptr),
+                       cuda_jit_info_.grid_dims[0],
+                       cuda_jit_info_.grid_dims[1],
+                       cuda_jit_info_.grid_dims[2],
+                       cuda_jit_info_.block_dims[0],
+                       cuda_jit_info_.block_dims[1],
+                       cuda_jit_info_.block_dims[2],
+                       0,  // share memory
+                       static_cast<CUstream>(stream),
+                       func_args_.data(),
+                       nullptr))
   }
 
  private:
   CUDAJITInfo cuda_jit_info_;
 
-  std::vector<void*> vec_temp_;
-  std::vector<void*> pass_arg;
+  std::vector<void*> ptr_storage_;
+  std::vector<void*> func_args_;
 };
 
 CinnJitInstruction::CinnJitInstruction(
@@ -70,9 +66,6 @@ CinnJitInstruction::CinnJitInstruction(
     ::pir::Operation* op,
     const ValueExecutionInfo& value_exec_info)
     : InstructionBase(id, place) {
-  // TODO(Aurelius84): We shall simplify members of JitKernelOp to make it
-  // only hold related function ptrs. Impl is the real runtime data structure
-  // responsible to construct hlir::framework::Instruction.
   auto jit_kernel_op = op->dyn_cast<cinn::dialect::JitKernelOp>();
   fn_ptr_impl_ = std::make_shared<FnPtrImpl>(jit_kernel_op.cuda_jit_info());
   op_ = op;
@@ -90,7 +83,7 @@ CinnJitInstruction::CinnJitInstruction(
                     ->Var(var_name)
                     ->GetMutable<phi::DenseTensor>();
 
-    tensor_args.push_back(tensor);
+    tensor_args_.push_back(tensor);
   }
 
   dev_ctx_ = phi::DeviceContextPool::Instance().Get(place_);
@@ -103,7 +96,7 @@ CinnJitInstruction::CinnJitInstruction(
                     ->Var(var_name)
                     ->GetMutable<phi::DenseTensor>();
 
-    tensor_args.push_back(tensor);
+    tensor_args_.push_back(tensor);
 
     out_tensor_ = tensor;
 
@@ -116,23 +109,17 @@ CinnJitInstruction::CinnJitInstruction(
 }
 
 void CinnJitInstruction::Run() {
-  // VLOG(6) << "Run cinn jit_kernel_op : " << Name();
-  // Get kernel input
   auto gpu_ctx = static_cast<phi::GPUContext*>(dev_ctx_);
 
-  // gpu_ctx->Wait();
   auto stream = gpu_ctx->stream();
-  for (size_t i = 0; i < tensor_args.size(); ++i) {
-    gpu_ctx->Alloc(tensor_args[i], phi::DataType::FLOAT32);
+  for (size_t i = 0; i < tensor_args_.size(); ++i) {
+    gpu_ctx->Alloc(tensor_args_[i], tensor_args_[i]->dtype());
   }
 
-  fn_ptr_impl_->Run(tensor_args, static_cast<void*>(stream));
+  fn_ptr_impl_->Run(tensor_args_, static_cast<void*>(stream));
 }
 
 const std::string& CinnJitInstruction::Name() const {
-  // TODO(Aurelius84): Consider the case for instrucitons constaning
-  // multipule function ptrs and function names.
-  // return impl_->pointer()->function_name();
   static const std::string name = "cinn_jit";
   return name;
 }
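
The two renamed vectors in FnPtrImpl exist because cuLaunchKernel receives its arguments as void** kernelParams, where each entry points at the storage of one argument rather than holding the argument itself. ptr_storage_ therefore keeps the raw device pointers alive, and func_args_ holds a pointer to each of those slots. A minimal standalone sketch of the same marshalling, with a hypothetical LaunchWithBufferArgs helper and assumed 1-D launch dimensions:

#include <cuda.h>

#include <cstddef>
#include <vector>

// Launch a JIT-compiled kernel whose parameters are N device buffers.
// `fn` and `buffers` stand in for what the compiler hands the instruction
// above; error handling is reduced to returning the bare status code.
CUresult LaunchWithBufferArgs(CUfunction fn,
                              const std::vector<void*>& buffers,
                              CUstream stream,
                              unsigned grid_x,
                              unsigned block_x) {
  // cuLaunchKernel dereferences each kernelParams[i] to read argument i,
  // so every device pointer needs stable storage...
  std::vector<void*> ptr_storage(buffers);
  // ...and a second array whose entries point at those slots.
  std::vector<void*> func_args;
  func_args.reserve(ptr_storage.size());
  for (std::size_t i = 0; i < ptr_storage.size(); ++i) {
    func_args.push_back(&ptr_storage[i]);
  }

  return cuLaunchKernel(fn,
                        grid_x, 1, 1,      // grid dims
                        block_x, 1, 1,     // block dims
                        0,                 // dynamic shared memory bytes
                        stream,
                        func_args.data(),  // void** kernelParams
                        nullptr);          // extra
}

Keeping both arrays as members, as the instruction does, avoids reallocating them on every launch.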
2 changes: 1 addition & 1 deletion paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h
@@ -51,7 +51,7 @@ class CinnJitInstruction : public InstructionBase {
 
   phi::DenseTensor* out_tensor_;
 
-  std::vector<phi::DenseTensor*> tensor_args;
+  std::vector<phi::DenseTensor*> tensor_args_;
 
   ::pir::Operation* op_{nullptr};  // not owned
 };
3 changes: 1 addition & 2 deletions test/cpp/pir/cinn/jit_instruction_test.cc
@@ -78,12 +78,11 @@ namespace framework {
 TEST(CinnJitInstruction, Run) {
   // Step 1: Construct pir::Program
   std::unique_ptr<::pir::Program> program = BuildProgram();
-  // EXPECT_EQ(program->block()->size(), 2u);
+  EXPECT_EQ(program->block()->size(), 7u);
 
   // Step 2: Compiler New pir::Program into Runtime Program
   auto target = cinn::common::DefaultNVGPUTarget();
   auto scope = cinn::hlir::framework::BuildScope(target, *program);
-  // ASSERT_EQ(scope->var_names().size(), 2);
 
   std::vector<cinn::hlir::framework::NewIRCompiler*> compiler_list;
 
