Skip to content

fix cinn graph may hasn't input problem #40814

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion paddle/fluid/operators/cinn/cinn_instruction_run_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun");
// The cinn-graph may hasn't input for CINN now support fill_constant,
// and its all inputs may generated by fill_constant instead of by fetch.
// OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun");
OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs,
"CinnInstructionRun");
const CinnCompiledObject& compiled_object =
Expand All @@ -43,6 +45,53 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel {
});
ctx->SetOutputsDim(kOutputs, output_dims);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// Why we need override GetExpectedKernelType?
// A cinn-graph may has no inpute var, if we use the base function,
// it will check wheter input tensors is initialized. Here we rewrite
// the function so that we can infer kernel type by output date type.
if (ctx.InputSize(kX)) {
// if the instruction has input, infer kernel type by input date type:
return OperatorWithKernel::GetExpectedKernelType(ctx);
}

// Else infer kernel type by output date type:
// The `OutputVar` will check wheter the kOutputs iff has one output var
const framework::Variable* var = ctx.OutputVar(kOutputs);
PADDLE_ENFORCE_NE(
var, nullptr,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op's Output Variable should not empty."));

const framework::Tensor* tensor = nullptr;
if (var->IsType<framework::Tensor>()) {
tensor = &var->Get<framework::Tensor>();
} else if (var->IsType<framework::LoDTensor>()) {
tensor = &var->Get<framework::LoDTensor>();
} else if (var->IsType<phi::SelectedRows>()) {
tensor = &(var->Get<phi::SelectedRows>().value());
} else if (var->IsType<framework::LoDTensorArray>()) {
auto t_arr = &var->Get<framework::LoDTensorArray>();
PADDLE_ENFORCE_EQ(t_arr->size(), 1UL,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op should just has One "
"Output when Input empty."));
tensor = &(t_arr->front());
}

PADDLE_ENFORCE_NE(
tensor, nullptr,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op's Output Tensor should not empty."));

VLOG(4) << "The tensor [" << ctx.OutputName(kOutputs) << "]'s dtype is "
<< paddle::framework::DataType2String(tensor->dtype());
auto output_type = paddle::framework::TransToProtoVarType(tensor->dtype());
return framework::OpKernelType(output_type, ctx.device_context());
}
};

class CinnInstructionRunOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down
9 changes: 6 additions & 3 deletions paddle/fluid/operators/cinn/cinn_launch_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,12 @@ class CinnLaunchOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs(kX) || ctx->HasInputs(kNoNeedBufferX),
"Input", string::format_string("%s|%s", kX, kNoNeedBufferX),
"CinnLaunchOp");
// The cinn-graph may hasn't input for CINN now support fill_constant,
// and its all inputs may generated by fill_constant instead of by fetch.
// OP_INOUT_CHECK(ctx->HasInputs(kX) || ctx->HasInputs(kNoNeedBufferX),
// "Input", string::format_string("%s|%s", kX,
// kNoNeedBufferX),
// "CinnLaunchOp");
OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs,
"CinnLaunchOp");
}
Expand Down