-
Notifications
You must be signed in to change notification settings - Fork 5.7k
[draft] add registry for Op, OpProto and OpAttrChecker #2739
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
dd7fcb4
init op_registry.h
JiayiFeng 4b991f1
dev op_registry.h
JiayiFeng c4cc5e0
add 'attr_checker.h', which is a draft of op attribute checker.
JiayiFeng dd2479e
rename some macro parameters
JiayiFeng 2c37e68
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiayiFeng c89fbac
1. Use `Attribute` and `AttributeMap` instead of `OpDesc`. `Attribute…
JiayiFeng a9923fd
rename DefaultChecker to DefaultValueSetter
JiayiFeng 4eff5b6
Finish op_registry
JiayiFeng fd46805
Add demo and test of custome checker
JiayiFeng c3c82e6
fix merge conflict
JiayiFeng 28e7cb3
fix merge conflict
JiayiFeng File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
#pragma once | ||
|
||
#include <boost/variant.hpp> | ||
#include <functional> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <vector> | ||
#include "paddle/framework/enforce.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>, | ||
std::vector<float>, std::vector<std::string>> | ||
Attribute; | ||
typedef std::unordered_map<std::string, Attribute> AttributeMap; | ||
|
||
// check whether a value(attribute) fit a certain limit | ||
template <typename T> | ||
class LargerThanChecker { | ||
public: | ||
LargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} | ||
void operator()(T& value) const { | ||
PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail"); | ||
} | ||
|
||
private: | ||
T lower_bound_; | ||
}; | ||
|
||
// we can provide users more common Checker, like 'LessThanChecker', | ||
// 'BetweenChecker'... | ||
|
||
template <typename T> | ||
class DefaultValueSetter { | ||
public: | ||
DefaultValueSetter(T default_value) : default_value_(default_value) {} | ||
void operator()(T& value) const { value = default_value_; } | ||
|
||
private: | ||
T default_value_; | ||
}; | ||
|
||
// check whether a certain attribute fit its limits | ||
// an attribute can have more than one limits | ||
template <typename T> | ||
class TypedAttrChecker { | ||
typedef std::function<void(T&)> ValueChecker; | ||
|
||
public: | ||
TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {} | ||
|
||
TypedAttrChecker& LargerThan(const T& lower_bound) { | ||
value_checkers_.push_back(LargerThanChecker<T>(lower_bound)); | ||
return *this; | ||
} | ||
|
||
// we can add more common limits, like LessThan(), Between()... | ||
|
||
TypedAttrChecker& SetDefault(const T& default_value) { | ||
PADDLE_ENFORCE(default_value_setter_.empty(), | ||
"%s can't have more than one default value!", attr_name_); | ||
default_value_setter_.push_back(DefaultValueSetter<T>(default_value)); | ||
return *this; | ||
} | ||
|
||
// allow users provide their own checker | ||
TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) { | ||
value_checkers_.push_back(checker); | ||
return *this; | ||
} | ||
|
||
void operator()(AttributeMap& attr_map) const { | ||
if (!attr_map.count(attr_name_)) { | ||
// user do not set this attr | ||
PADDLE_ENFORCE(!default_value_setter_.empty(), | ||
"Attribute '%s' is required!", attr_name_); | ||
// default_value_setter_ has no more than one element | ||
T val; | ||
(default_value_setter_[0])(val); | ||
attr_map[attr_name_] = val; | ||
} | ||
Attribute& attr = attr_map.at(attr_name_); | ||
T& attr_value = boost::get<T>(attr); | ||
for (const auto& checker : value_checkers_) { | ||
checker(attr_value); | ||
} | ||
} | ||
|
||
private: | ||
std::string attr_name_; | ||
std::vector<ValueChecker> value_checkers_; | ||
std::vector<ValueChecker> default_value_setter_; | ||
}; | ||
|
||
// check whether op's all attributes fit their own limits | ||
class OpAttrChecker { | ||
typedef std::function<void(AttributeMap&)> AttrChecker; | ||
|
||
public: | ||
template <typename T> | ||
TypedAttrChecker<T>& AddAttrChecker(const std::string& attr_name) { | ||
attr_checkers_.push_back(TypedAttrChecker<T>(attr_name)); | ||
AttrChecker& checker = attr_checkers_.back(); | ||
return *(checker.target<TypedAttrChecker<T>>()); | ||
} | ||
|
||
void Check(AttributeMap& attr_map) const { | ||
for (const auto& checker : attr_checkers_) { | ||
checker(attr_map); | ||
} | ||
} | ||
|
||
private: | ||
std::vector<AttrChecker> attr_checkers_; | ||
}; | ||
|
||
} // namespace framework | ||
} // namespace paddle |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
#pragma once | ||
|
||
#include "paddle/framework/attr_checker.h" | ||
|
||
//#include "paddle/framework/op_base.h" | ||
#include "paddle/framework/op_desc.pb.h" | ||
#include "paddle/framework/op_proto.pb.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
//==================For test================// | ||
class OpBase { | ||
public: | ||
std::vector<std::string> inputs_; | ||
std::vector<std::string> outputs_; | ||
AttributeMap attr_map_; | ||
|
||
virtual std::string Run() const = 0; | ||
virtual ~OpBase() {} | ||
}; | ||
//=========================================// | ||
|
||
// helper class to set attribute type | ||
struct AttrTypeHelper { | ||
template <typename T> | ||
static void SetAttrType(AttrProto* attr); | ||
|
||
static Attribute GetAttrValue(const AttrDesc& attr_desc) { | ||
switch (attr_desc.type()) { | ||
case paddle::framework::AttrType::INT: { | ||
return attr_desc.i(); | ||
} | ||
case paddle::framework::AttrType::FLOAT: { | ||
return attr_desc.f(); | ||
} | ||
case paddle::framework::AttrType::STRING: { | ||
return attr_desc.s(); | ||
} | ||
case paddle::framework::AttrType::INTS: { | ||
std::vector<int> val(attr_desc.ints_size()); | ||
for (int i = 0; i < attr_desc.ints_size(); ++i) { | ||
val[i] = attr_desc.ints(i); | ||
} | ||
return val; | ||
} | ||
case paddle::framework::AttrType::FLOATS: { | ||
std::vector<float> val(attr_desc.floats_size()); | ||
for (int i = 0; i < attr_desc.floats_size(); ++i) { | ||
val[i] = attr_desc.floats(i); | ||
} | ||
return val; | ||
} | ||
case paddle::framework::AttrType::STRINGS: { | ||
std::vector<std::string> val(attr_desc.strings_size()); | ||
for (int i = 0; i < attr_desc.strings_size(); ++i) { | ||
val[i] = attr_desc.strings(i); | ||
} | ||
return val; | ||
} | ||
} | ||
PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); | ||
return boost::blank(); | ||
} | ||
}; | ||
|
||
template <> | ||
void AttrTypeHelper::SetAttrType<int>(AttrProto* attr) { | ||
attr->set_type(paddle::framework::AttrType::INT); | ||
} | ||
|
||
template <> | ||
void AttrTypeHelper::SetAttrType<float>(AttrProto* attr) { | ||
attr->set_type(paddle::framework::AttrType::FLOAT); | ||
} | ||
|
||
template <> | ||
void AttrTypeHelper::SetAttrType<std::string>(AttrProto* attr) { | ||
attr->set_type(paddle::framework::AttrType::STRING); | ||
} | ||
|
||
template <> | ||
void AttrTypeHelper::SetAttrType<std::vector<int>>(AttrProto* attr) { | ||
attr->set_type(paddle::framework::AttrType::INTS); | ||
} | ||
|
||
template <> | ||
void AttrTypeHelper::SetAttrType<std::vector<float>>(AttrProto* attr) { | ||
attr->set_type(paddle::framework::AttrType::FLOATS); | ||
} | ||
|
||
template <> | ||
void AttrTypeHelper::SetAttrType<std::vector<std::string>>(AttrProto* attr) { | ||
attr->set_type(paddle::framework::AttrType::STRINGS); | ||
} | ||
|
||
// this class not only make proto but also init attribute checkers. | ||
class OpProtoAndCheckerMaker { | ||
public: | ||
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
: proto_(proto), op_checker_(op_checker) {} | ||
|
||
protected: | ||
void AddInput(const std::string& name, const std::string& comment) { | ||
auto input = proto_->mutable_inputs()->Add(); | ||
*(input->mutable_name()) = name; | ||
*(input->mutable_comment()) = comment; | ||
} | ||
|
||
void AddOutput(const std::string& name, const std::string& comment) { | ||
auto output = proto_->mutable_outputs()->Add(); | ||
*(output->mutable_name()) = name; | ||
*(output->mutable_comment()) = comment; | ||
} | ||
|
||
template <typename T> | ||
TypedAttrChecker<T>& AddAttr(const std::string& name, | ||
const std::string& comment) { | ||
auto attr = proto_->mutable_attrs()->Add(); | ||
*(attr->mutable_name()) = name; | ||
*(attr->mutable_comment()) = comment; | ||
AttrTypeHelper::SetAttrType<T>(attr); | ||
return op_checker_->AddAttrChecker<T>(name); | ||
} | ||
|
||
void AddType(const std::string& op_type) { proto_->set_type(op_type); } | ||
|
||
void AddComment(const std::string& comment) { | ||
*(proto_->mutable_comment()) = comment; | ||
} | ||
|
||
OpProto* proto_; | ||
OpAttrChecker* op_checker_; | ||
}; | ||
|
||
class OpRegistry { | ||
typedef std::function<OpBase*()> OpCreator; | ||
|
||
public: | ||
template <typename OpType, typename ProtoMakerType> | ||
static void RegisterOp(const std::string& op_type) { | ||
creators_[op_type] = []() { return new OpType; }; | ||
OpProto& op_proto = protos_[op_type]; | ||
OpAttrChecker& op_checker = op_checkers_[op_type]; | ||
ProtoMakerType(&op_proto, &op_checker); | ||
PADDLE_ENFORCE(op_proto.IsInitialized() == true, | ||
"Fail to initialize %s's OpProto !", op_type); | ||
} | ||
|
||
static OpBase* CreateOp(const OpDesc& op_desc) { | ||
std::string op_type = op_desc.type(); | ||
OpBase* op = (creators_.at(op_type))(); | ||
(op->inputs_).resize(op_desc.inputs_size()); | ||
for (int i = 0; i < op_desc.inputs_size(); ++i) { | ||
(op->inputs_)[i] = op_desc.inputs(i); | ||
} | ||
(op->outputs_).resize(op_desc.outputs_size()); | ||
for (int i = 0; i < op_desc.outputs_size(); ++i) { | ||
(op->outputs_)[i] = op_desc.outputs(i); | ||
} | ||
for (int i = 0; i < op_desc.attrs_size(); ++i) { | ||
const AttrDesc& ith_attr = op_desc.attrs(i); | ||
std::string name = ith_attr.name(); | ||
(op->attr_map_)[name] = AttrTypeHelper::GetAttrValue(ith_attr); | ||
} | ||
const OpAttrChecker& op_checker = op_checkers_.at(op_type); | ||
op_checker.Check(op->attr_map_); | ||
return op; | ||
} | ||
|
||
private: | ||
static std::unordered_map<std::string, OpCreator> creators_; | ||
static std::unordered_map<std::string, OpProto> protos_; | ||
static std::unordered_map<std::string, OpAttrChecker> op_checkers_; | ||
}; | ||
|
||
std::unordered_map<std::string, std::function<OpBase*()>> OpRegistry::creators_; | ||
std::unordered_map<std::string, OpProto> OpRegistry::protos_; | ||
std::unordered_map<std::string, OpAttrChecker> OpRegistry::op_checkers_; | ||
|
||
template <typename OpType, typename ProtoMakerType> | ||
class OpRegisterHelper { | ||
public: | ||
OpRegisterHelper(std::string op_type) { | ||
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type); | ||
} | ||
}; | ||
|
||
#define REGISTER_OP(__op_class, __op_maker_class, __op_type) \ | ||
class __op_class##Register { \ | ||
private: \ | ||
const static OpRegisterHelper<__op_class, __op_maker_class> reg; \ | ||
}; \ | ||
const OpRegisterHelper<__op_class, __op_maker_class> \ | ||
__op_class##Register::reg(#__op_type); | ||
|
||
// Demos | ||
|
||
class CosineOp : public OpBase { | ||
public: | ||
virtual std::string Run() const { | ||
std::string msg = "CosineOp runs! scale = " + | ||
std::to_string(boost::get<float>(attr_map_.at("scale"))); | ||
return msg; | ||
} | ||
}; | ||
|
||
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { | ||
public: | ||
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("input", "input of cosine op"); | ||
AddOutput("output", "output of cosine op"); | ||
AddAttr<float>("scale", "scale of cosine op") | ||
.SetDefault(1.0) | ||
.LargerThan(0.0); | ||
AddType("cos"); | ||
AddComment("This is cos op"); | ||
} | ||
}; | ||
|
||
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) | ||
|
||
class MyTestOp : public OpBase { | ||
public: | ||
virtual std::string Run() const { | ||
std::string msg = | ||
"MyTestOp runs! test_attr = " + | ||
std::to_string(boost::get<int>(attr_map_.at("test_attr"))); | ||
return msg; | ||
} | ||
}; | ||
|
||
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { | ||
public: | ||
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("input", "input of cosine op"); | ||
AddOutput("output", "output of cosine op"); | ||
auto my_checker = [](int i) { | ||
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); | ||
}; | ||
AddAttr<int>("test_attr", "a simple test attribute") | ||
.AddCustomChecker(my_checker); | ||
AddType("my_test_op"); | ||
AddComment("This is my_test op"); | ||
} | ||
}; | ||
|
||
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should provide a base API named
Other API will invoke
AddChecker
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe
AddCustomChecker
has done the same thing?