Skip to content

[CustomOp] Support output as input argument of kernel func #39353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 82 additions & 43 deletions paddle/fluid/framework/custom_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs) {
VLOG(3) << "Custom Operator: Start run KernelFunc.";
std::vector<paddle::experimental::Tensor> custom_ins;
std::vector<std::vector<paddle::experimental::Tensor>> custom_vec_ins;
// prepare CustomOpKernelContext
paddle::CustomOpKernelContext kernel_ctx;
for (auto& in_name : inputs) {
VLOG(3) << "Custom Operator: input name - " << in_name;
if (detail::IsDuplicableVar(in_name)) {
Expand All @@ -136,7 +136,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
custom_t.set_impl(std::make_shared<pten::DenseTensor>(*x));
custom_vec_in.emplace_back(custom_t);
}
custom_vec_ins.emplace_back(custom_vec_in);
kernel_ctx.EmplaceBackInputs(std::move(custom_vec_in));
} else {
auto* x = ctx.Input<Tensor>(in_name);
PADDLE_ENFORCE_NOT_NULL(x, platform::errors::NotFound(
Expand All @@ -146,33 +146,32 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
"Input tensor (%s) is not initialized.", in_name));
paddle::experimental::Tensor custom_in;
custom_in.set_impl(std::make_shared<pten::DenseTensor>(*x));
custom_ins.emplace_back(custom_in);
kernel_ctx.EmplaceBackInput(std::move(custom_in));
}
}

std::vector<paddle::any> custom_attrs;
for (auto& attr_str : attrs) {
auto attr_name_and_type = detail::ParseAttrStr(attr_str);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
custom_attrs.emplace_back(ctx.Attr<bool>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<bool>(attr_name));
} else if (attr_type_str == "int") {
custom_attrs.emplace_back(ctx.Attr<int>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<int>(attr_name));
} else if (attr_type_str == "float") {
custom_attrs.emplace_back(ctx.Attr<float>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<float>(attr_name));
} else if (attr_type_str == "int64_t") {
custom_attrs.emplace_back(ctx.Attr<int64_t>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<int64_t>(attr_name));
} else if (attr_type_str == "std::string") {
custom_attrs.emplace_back(ctx.Attr<std::string>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<std::string>(attr_name));
} else if (attr_type_str == "std::vector<int>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<int>>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<std::vector<int>>(attr_name));
} else if (attr_type_str == "std::vector<float>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<float>>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<std::vector<float>>(attr_name));
} else if (attr_type_str == "std::vector<int64_t>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<int64_t>>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<std::vector<int64_t>>(attr_name));
} else if (attr_type_str == "std::vector<std::string>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<std::string>>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<std::vector<std::string>>(attr_name));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
Expand All @@ -185,35 +184,75 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
}
}

VLOG(3) << "Custom Operator: Run ComputeFunc.";
try {
auto outs = func(custom_ins, custom_vec_ins, custom_attrs);
VLOG(3) << "Custom Operator: push outputs into CustomOpKernelContext.";
// cache the target tensor pointers
std::vector<Tensor*> true_out_ptrs;
for (size_t i = 0; i < outputs.size(); ++i) {
auto out_name = outputs[i];
if (detail::IsDuplicableVar(out_name)) {
PADDLE_ENFORCE(i == 0UL && outputs.size() == 1UL,
platform::errors::PreconditionNotMet(
"If custom operator's outputs contains `paddle::Vec("
")` type, "
"it only can hold one output."));
auto vec_out = ctx.MultiOutput<Tensor>(out_name);
PADDLE_ENFORCE_NE(vec_out.empty(), true,
platform::errors::NotFound(
"Output vector<tensor> (%s) is empty.", out_name));
std::vector<paddle::experimental::Tensor> custom_vec_out;
for (size_t j = 0; j < vec_out.size(); ++j) {
auto* out = vec_out[j];
PADDLE_ENFORCE_NOT_NULL(
out,
platform::errors::NotFound(
"The %d-th tensor in output vector<tensor> (%s) is nullptr.", j,
out_name));
true_out_ptrs.emplace_back(out);
paddle::experimental::Tensor custom_t;
// here we can only copy the output tensor into the context
custom_t.set_impl(std::make_shared<pten::DenseTensor>(*out));
custom_vec_out.emplace_back(custom_t);
}
kernel_ctx.EmplaceBackOutputs(std::move(custom_vec_out));
} else {
auto* out = ctx.Output<Tensor>(out_name);
PADDLE_ENFORCE_NOT_NULL(
out, platform::errors::NotFound("Output tensor (%s) is nullptr.",
out_name));
true_out_ptrs.emplace_back(out);
paddle::experimental::Tensor custom_out;
// here we can only copy the output tensor into the context
custom_out.set_impl(std::make_shared<pten::DenseTensor>(*out));
kernel_ctx.EmplaceBackOutput(std::move(custom_out));
}
}

VLOG(3) << "Custom Operator: Share outputs into ExecutionContext.";
for (size_t i = 0; i < outputs.size(); ++i) {
auto out_name = outputs[i];
if (detail::IsDuplicableVar(out_name)) {
PADDLE_ENFORCE(i == 0UL && outputs.size() == 1UL,
platform::errors::PreconditionNotMet(
"If custom operator's outputs contains `paddle::Vec("
")` type, "
"it only can hold one output."));
auto vec_true_outs = ctx.MultiOutput<Tensor>(out_name);
PADDLE_ENFORCE_EQ(
vec_true_outs.size(), outs.size(),
platform::errors::InvalidArgument(
"The number of element in custom operator outputs is wrong, "
"expected contains %d Tensors, but actually contains %d "
"Tensors.",
vec_true_outs.size(), outs.size()));
for (size_t j = 0; j < vec_true_outs.size(); ++j) {
*vec_true_outs.at(j) =
*std::dynamic_pointer_cast<pten::DenseTensor>(outs.at(j).impl());
}
} else {
auto* true_out = ctx.Output<Tensor>(out_name);
*true_out =
*std::dynamic_pointer_cast<pten::DenseTensor>(outs.at(i).impl());
try {
VLOG(3) << "Custom Operator: Run ComputeFunc.";
func(&kernel_ctx);

// sync output tensor data into original output
auto* calc_outs = kernel_ctx.AllMutableOutput();
PADDLE_ENFORCE_EQ(
true_out_ptrs.size(), calc_outs->size(),
platform::errors::InvalidArgument(
"The number of element in custom operator outputs is wrong, "
"expected contains %d Tensors, but actually contains %d "
"Tensors.",
true_out_ptrs.size(), calc_outs->size()));
for (size_t i = 0; i < true_out_ptrs.size(); ++i) {
auto* true_out = true_out_ptrs.at(i);
auto calc_out =
std::dynamic_pointer_cast<pten::DenseTensor>(calc_outs->at(i).impl());
// assign meta info
auto* true_out_meta = pten::DenseTensorUtils::GetMutableMeta(true_out);
true_out_meta->dims = calc_out->dims();
true_out_meta->dtype = calc_out->dtype();
true_out_meta->layout = calc_out->layout();
// lod and offset no need to be reset
// reset holder if needed
if (true_out->Holder() != calc_out->Holder()) {
true_out->ResetHolder(calc_out->Holder());
}
}
} catch (platform::EnforceNotMet& exception) {
Expand Down Expand Up @@ -609,7 +648,7 @@ void RegisterOperatorWithMetaInfo(
auto op_name = OpMetaInfoHelper::GetOpName(base_op_meta);

if (OpInfoMap::Instance().Has(op_name)) {
LOG(WARNING) << "Operator (" << op_name << ")has been registered.";
LOG(WARNING) << "Operator (" << op_name << ") has been registered.";
return;
}

Expand Down
Loading