Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions paddle/fluid/extension/include/op_meta_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,26 @@ inline std::string Grad(const std::string& var_name) {
using KernelFunc = std::vector<Tensor> (*)(std::vector<Tensor> inputs,
std::vector<boost::any> attrs);

// Generates a ComputeCallHelper specialization for one attribute type.
// The generated Compute():
//  - casts attrs[attr_idx] from boost::any to `attr_type`,
//  - appends the value to the argument pack and recurses into the next
//    helper with attr_idx + 1 (in_idx is left unchanged),
//  - turns a boost::bad_any_cast into a PD_THROW with a message naming
//    the expected type (via the stringized #attr_type).
// NOTE: keep comments out of the macro body itself — backslash-newline
// splicing happens before comment removal, so a `//` comment on a
// continuation line would swallow the lines that follow it.
#define PD_SPECIALIZE_ComputeCallHelper(attr_type)                  \
  template <typename... Tail>                                       \
  struct ComputeCallHelper<attr_type, Tail...> {                    \
    template <int in_idx, int attr_idx, typename... PreviousArgs>   \
    static Return Compute(std::vector<Tensor> inputs,               \
                          std::vector<boost::any> attrs,            \
                          const PreviousArgs&... pargs) {           \
      try {                                                         \
        attr_type arg = boost::any_cast<attr_type>(attrs[attr_idx]); \
        return ComputeCallHelper<Tail...>::template Compute<in_idx, \
                                                            attr_idx + 1>( \
            inputs, attrs, pargs..., arg);                          \
      } catch (boost::bad_any_cast&) {                              \
        PD_THROW(                                                   \
            "Attribute cast error in custom operator. Expected " #attr_type \
            " value.");                                             \
      }                                                             \
    }                                                               \
  }

template <typename T>
struct TypeTag {};

Expand Down Expand Up @@ -114,26 +134,20 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
}
};

// TODO(chenweihang): add support for attribute input
// int attribute input (not used now)
// NOTE(review): this hand-written specialization is exactly what
// PD_SPECIALIZE_ComputeCallHelper(int) expands to; keeping both in the
// same translation unit would redefine ComputeCallHelper<int, Tail...> —
// confirm this copy is the one removed by the change.
template <typename... Tail>
struct ComputeCallHelper<int, Tail...> {
  // Consumes attrs[attr_idx] as an int and recurses with attr_idx + 1;
  // in_idx (the tensor-input cursor) is passed through unchanged.
  template <int in_idx, int attr_idx, typename... PreviousArgs>
  static Return Compute(std::vector<Tensor> inputs,
                        std::vector<boost::any> attrs,
                        const PreviousArgs&... pargs) {
    try {
      int arg = boost::any_cast<int>(attrs[attr_idx]);
      return ComputeCallHelper<Tail...>::template Compute<in_idx,
                                                          attr_idx + 1>(
          inputs, attrs, pargs..., arg);
    } catch (boost::bad_any_cast&) {
      // any_cast throws when the stored type does not match int exactly.
      PD_THROW(
          "Attribute cast error in custom operator. Expected int value.");
    }
  }
};

// Instantiate a ComputeCallHelper specialization for every attribute type
// a custom operator may declare; each expansion casts the boost::any held
// in attrs[attr_idx] to the listed type and forwards it down the recursion.
PD_SPECIALIZE_ComputeCallHelper(bool);
PD_SPECIALIZE_ComputeCallHelper(int);
PD_SPECIALIZE_ComputeCallHelper(float);
PD_SPECIALIZE_ComputeCallHelper(int64_t);
PD_SPECIALIZE_ComputeCallHelper(std::string);
PD_SPECIALIZE_ComputeCallHelper(std::vector<int>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<float>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>);
// TODO(chenweihang): support other attribute types if needed.
// Why are the remaining attribute types not supported here?
// - boost::blank, std::vector<bool> and std::vector<double>
//   are not used in op
// - BlockDesc* and std::vector<BlockDesc*> are used in framework
// end: base template (the TypeTag specialization below terminates recursion)
template <typename T>
struct ComputeCallHelper<TypeTag<T>> {
Expand Down Expand Up @@ -245,10 +259,23 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
class PD_DLL_DECL OpMetaInfo {
public:
explicit OpMetaInfo(const std::string& op_name) : name_(op_name) {}

// format: {"<name1>", "<name2>", ...}
OpMetaInfo& Inputs(std::vector<std::string>&& inputs);

// format: {"<name1>", "<name2>", ...}
OpMetaInfo& Outputs(std::vector<std::string>&& outputs);

// format: {"<name1>:<type1>", "<name2>:<type2>", ...}
OpMetaInfo& Attrs(std::vector<std::string>&& attrs);

// format: PD_KERNEL(...)
OpMetaInfo& SetKernelFn(KernelFunc&& func);

// format: PD_INFER_SHAPE(...)
OpMetaInfo& SetInferShapeFn(InferShapeFunc&& func);

// format: PD_INFER_DTYPE(...)
OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func);

private:
Expand Down Expand Up @@ -297,6 +324,7 @@ class PD_DLL_DECL OpMetaInfoBuilder {
explicit OpMetaInfoBuilder(std::string&& name);
OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs);
OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs);
OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs);
OpMetaInfoBuilder& SetKernelFn(KernelFunc func);
OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func);
OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/extension/src/op_meta_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ OpMetaInfo& OpMetaInfo::Outputs(std::vector<std::string>&& outputs) {
outputs_ = std::forward<std::vector<std::string>>(outputs);
return *this;
}
// Records the attribute declarations ("<name>:<type>" strings) for this op
// and returns *this so calls can be chained builder-style.
OpMetaInfo& OpMetaInfo::Attrs(std::vector<std::string>&& attrs) {
  attrs_ = std::move(attrs);
  return *this;
}
OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) {
kernel_fn_ = std::forward<KernelFunc>(func);
return *this;
Expand Down Expand Up @@ -78,6 +82,11 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Outputs(
return *this;
}

// Forwards the attribute declaration list to the wrapped OpMetaInfo and
// returns *this to keep the fluent builder chain going.
OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector<std::string>&& attrs) {
  info_ptr_->Attrs(std::move(attrs));
  return *this;
}

OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
info_ptr_->SetKernelFn(std::forward<KernelFunc>(func));
return *this;
Expand Down
132 changes: 118 additions & 14 deletions paddle/fluid/framework/custom_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ inline bool IsMemberOf(const std::vector<std::string>& vec,
return std::find(vec.cbegin(), vec.cend(), name) != vec.cend();
}

// Splits an attribute declaration of the form `<name>:<type>` into its two
// trimmed components and returns them as {name, type}. Enforces that a
// colon separator is present; raises InvalidArgument otherwise.
std::vector<std::string> ParseAttrStr(const std::string& attr) {
  const auto colon_pos = attr.find_first_of(":");
  PADDLE_ENFORCE_NE(colon_pos, std::string::npos,
                    platform::errors::InvalidArgument(
                        "Invalid attribute string format. Attribute string "
                        "format is `<name>:<type>`."));

  std::vector<std::string> name_and_type;
  name_and_type.reserve(2);
  // element 0: attribute name (everything before the colon)
  name_and_type.emplace_back(string::trim_spaces(attr.substr(0, colon_pos)));
  // element 1: attribute type string (everything after the colon)
  name_and_type.emplace_back(string::trim_spaces(attr.substr(colon_pos + 1)));

  VLOG(1) << "attr name: " << name_and_type[0]
          << ", attr type str: " << name_and_type[1];

  return name_and_type;
}

} // namespace detail

////////////////// Kernel Define ////////////////////
Expand All @@ -81,7 +99,8 @@ inline bool IsMemberOf(const std::vector<std::string>& vec,
static void RunKernelFunc(const framework::ExecutionContext& ctx,
const paddle::KernelFunc& func,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs) {
VLOG(1) << "Custom Operator: Start run KernelFunc.";
std::vector<paddle::Tensor> custom_ins;
for (auto& in_name : inputs) {
Expand All @@ -98,10 +117,43 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
custom_ins.emplace_back(custom_in);
}

std::vector<boost::any> attrs;
std::vector<boost::any> custom_attrs;
for (auto& attr_str : attrs) {
auto attr_name_and_type = detail::ParseAttrStr(attr_str);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
custom_attrs.emplace_back(ctx.Attr<bool>(attr_name));
} else if (attr_type_str == "int") {
custom_attrs.emplace_back(ctx.Attr<int>(attr_name));
} else if (attr_type_str == "float") {
custom_attrs.emplace_back(ctx.Attr<float>(attr_name));
} else if (attr_type_str == "int64_t") {
custom_attrs.emplace_back(ctx.Attr<int64_t>(attr_name));
} else if (attr_type_str == "std::string") {
custom_attrs.emplace_back(ctx.Attr<std::string>(attr_name));
} else if (attr_type_str == "std::vector<int>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<int>>(attr_name));
} else if (attr_type_str == "std::vector<float>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<float>>(attr_name));
} else if (attr_type_str == "std::vector<int64_t>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<int64_t>>(attr_name));
} else if (attr_type_str == "std::vector<std::string>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<std::string>>(attr_name));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>, "
"`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched.",
attr_type_str));
}
}

VLOG(1) << "Run ComputeFunc.";
auto outs = func(custom_ins, attrs);
auto outs = func(custom_ins, custom_attrs);

VLOG(1) << "Custom Operator: Share outputs into ExecutionContext.";
for (size_t i = 0; i < outputs.size(); ++i) {
Expand Down Expand Up @@ -164,7 +216,51 @@ class CustomOpMaker : public OpProtoAndCheckerMaker {
for (auto& out_name : outputs_) {
AddOutput(out_name, "The output " + out_name + "of Custom Operator.");
}
// TODO(chenweihang): support attrs in later PR
for (auto& attr : attrs_) {
auto attr_name_and_type = detail::ParseAttrStr(attr);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
AddAttr<bool>(attr_name, "custom operator bool attribute.")
.SetDefault(false);
} else if (attr_type_str == "int") {
AddAttr<int>(attr_name, "custom operator int attribute.").SetDefault(1);
} else if (attr_type_str == "float") {
AddAttr<float>(attr_name, "custom operator float attribute.")
.SetDefault(1.0f);
} else if (attr_type_str == "int64_t") {
AddAttr<int64_t>(attr_name, "custom operator int64_t attribute.")
.SetDefault(1);
} else if (attr_type_str == "std::string") {
AddAttr<std::string>(attr_name, "custom operator int attribute.")
.SetDefault("");
} else if (attr_type_str == "std::vector<int>") {
AddAttr<std::vector<int>>(attr_name,
"custom operator std::vector<int> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<float>") {
AddAttr<std::vector<float>>(
attr_name, "custom operator std::vector<float> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<int64_t>") {
AddAttr<std::vector<int64_t>>(
attr_name, "custom operator std::vector<int64_t> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<std::string>") {
AddAttr<std::vector<std::string>>(
attr_name, "custom operator std::vector<std::string> attribute.")
.SetDefault({});
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>, "
"`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched.",
attr_type_str));
}
}
AddComment(R"DOC(
Custom Operator.

Expand Down Expand Up @@ -227,7 +323,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
VLOG(1) << "Custom Operator: GradOpDescMaker - output: " << out_name;
grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
}
// TODO(chenweihang): support attrs in later PR
grad_op->SetAttrMap(this->Attrs());
}

private:
Expand Down Expand Up @@ -287,7 +383,7 @@ class CustomGradOpMaker<imperative::OpBase>
VLOG(1) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
}
// TODO(chenweihang): support attrs in later PR
grad_op->SetAttrMap(this->Attrs());
}

private:
Expand All @@ -303,31 +399,36 @@ void RegisterOperatorKernelWithPlace(const std::string& name,
const proto::VarType::Type type,
const PlaceType& place,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs) {
OpKernelType key(type,
CustomTensorUtils::ConvertEnumPlaceToInnerPlace(place));
VLOG(1) << "Custom Operator: op kernel key: " << key;
OperatorWithKernel::AllOpKernels()[name][key] =
[kernel_func, inputs, outputs](const framework::ExecutionContext& ctx) {
[kernel_func, inputs, outputs,
attrs](const framework::ExecutionContext& ctx) {
VLOG(1) << "Custom Operator: run custom kernel func in lambda.";
RunKernelFunc(ctx, kernel_func, inputs, outputs);
RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs);
};
}

// Registers `kernel_func` as the kernel for the custom operator `name` on
// every place the build supports (CPU always; GPU when compiled with CUDA).
// `inputs`/`outputs`/`attrs` are the declared variable and attribute names
// the runtime wrapper needs to pull values out of the ExecutionContext.
// FIX(review): the pasted span carried diff residue — both the old
// `... inputs, outputs);` and new `... inputs, outputs, attrs);` call
// continuations — which is not valid C++; this is the merged form.
void RegisterOperatorKernel(const std::string& name,
                            const paddle::KernelFunc& kernel_func,
                            const std::vector<std::string>& inputs,
                            const std::vector<std::string>& outputs,
                            const std::vector<std::string>& attrs) {
  VLOG(1) << "Custom Operator: op name in kernel: " << name;
  // NOTE [ Dummy Op Kernel Key ]
  // TODO(chenweihang): Because execute engine need get device context based
  // op_kernel_key.place_, so we should register kernel for each
  // device. But this is not entirely correct, if user only give a cpu kernel,
  // but call api in gpu device, it will cause error.
  RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW,
                                  PlaceType::kCPU, inputs, outputs, attrs);
#ifdef PADDLE_WITH_CUDA
  RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW,
                                  PlaceType::kGPU, inputs, outputs, attrs);
#endif
}

void RegisterOperatorWithMetaInfo(
Expand All @@ -350,6 +451,8 @@ void RegisterOperatorWithMetaInfo(
<< string::join_strings(op_inputs, ',');
VLOG(1) << "Custom Operator: forward, op outputs: "
<< string::join_strings(op_outputs, ',');
VLOG(1) << "Custom Operator: forward, op attrs: "
<< string::join_strings(op_attrs, ',');

// Op
info.creator_ = [](const std::string& op_name, const VariableNameMap& inputs,
Expand Down Expand Up @@ -426,7 +529,7 @@ void RegisterOperatorWithMetaInfo(
};

// Kernel func
RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs);
RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs);

// If grad op or double grad op exists
std::string cur_op_name = op_name;
Expand All @@ -436,6 +539,7 @@ void RegisterOperatorWithMetaInfo(
auto& grad_op_name = OpMetaInfoHelper::GetOpName(cur_grad_op);
auto& grad_op_inputs = OpMetaInfoHelper::GetInputs(cur_grad_op);
auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op);
auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op);
auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op);

VLOG(1) << "Custom Operator: backward, op name: " << grad_op_name;
Expand Down Expand Up @@ -489,7 +593,7 @@ void RegisterOperatorWithMetaInfo(

// Kernel func
RegisterOperatorKernel(grad_op_name, grad_kernel_fn, grad_op_inputs,
grad_op_outputs);
grad_op_outputs, grad_op_attrs);

// update current info
OpInfoMap::Instance().Insert(cur_op_name, info);
Expand Down
7 changes: 5 additions & 2 deletions python/paddle/fluid/tests/custom_op/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@ py_test(test_sysconfig SRCS test_sysconfig.py)

# 'test_dispatch' compile .cc file
py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 180)
set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 120)

py_test(test_multi_out_jit SRCS test_multi_out_jit.py)
set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 180)
set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120)

py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120)

if(NOT LINUX)
return()
Expand Down
Loading