Skip to content

[draft] add registry for Op, OpProto and OpAttrChecker #2739

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 7, 2017
1 change: 1 addition & 0 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
Expand Down
119 changes: 119 additions & 0 deletions paddle/framework/attr_checker.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#pragma once

#include <boost/variant.hpp>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/enforce.h"

namespace paddle {
namespace framework {

typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>>
Attribute;
typedef std::unordered_map<std::string, Attribute> AttributeMap;

// check whether a value(attribute) fit a certain limit
template <typename T>
class LargerThanChecker {
public:
LargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(T& value) const {
PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail");
}

private:
T lower_bound_;
};

// we can provide users more common Checker, like 'LessThanChecker',
// 'BetweenChecker'...

template <typename T>
class DefaultValueSetter {
public:
DefaultValueSetter(T default_value) : default_value_(default_value) {}
void operator()(T& value) const { value = default_value_; }

private:
T default_value_;
};

// check whether a certain attribute fit its limits
// an attribute can have more than one limits
template <typename T>
class TypedAttrChecker {
typedef std::function<void(T&)> ValueChecker;

public:
TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {}

TypedAttrChecker& LargerThan(const T& lower_bound) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should provide a base API named

AddChecker(ValueChecker checker) {
  checkers_.push_back(checker);
  return *this;
}

Other API will invoke AddChecker

Copy link
Collaborator Author

@JiayiFeng JiayiFeng Jul 6, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe AddCustomChecker has done the same thing?

TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) {
  value_checkers_.push_back(checker);
  return *this;
}

value_checkers_.push_back(LargerThanChecker<T>(lower_bound));
return *this;
}

// we can add more common limits, like LessThan(), Between()...

TypedAttrChecker& SetDefault(const T& default_value) {
PADDLE_ENFORCE(default_value_setter_.empty(),
"%s can't have more than one default value!", attr_name_);
default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
return *this;
}

// allow users provide their own checker
TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) {
value_checkers_.push_back(checker);
return *this;
}

void operator()(AttributeMap& attr_map) const {
if (!attr_map.count(attr_name_)) {
// user do not set this attr
PADDLE_ENFORCE(!default_value_setter_.empty(),
"Attribute '%s' is required!", attr_name_);
// default_value_setter_ has no more than one element
T val;
(default_value_setter_[0])(val);
attr_map[attr_name_] = val;
}
Attribute& attr = attr_map.at(attr_name_);
T& attr_value = boost::get<T>(attr);
for (const auto& checker : value_checkers_) {
checker(attr_value);
}
}

private:
std::string attr_name_;
std::vector<ValueChecker> value_checkers_;
std::vector<ValueChecker> default_value_setter_;
};

// check whether op's all attributes fit their own limits
class OpAttrChecker {
typedef std::function<void(AttributeMap&)> AttrChecker;

public:
template <typename T>
TypedAttrChecker<T>& AddAttrChecker(const std::string& attr_name) {
attr_checkers_.push_back(TypedAttrChecker<T>(attr_name));
AttrChecker& checker = attr_checkers_.back();
return *(checker.target<TypedAttrChecker<T>>());
}

void Check(AttributeMap& attr_map) const {
for (const auto& checker : attr_checkers_) {
checker(attr_map);
}
}

private:
std::vector<AttrChecker> attr_checkers_;
};

} // namespace framework
} // namespace paddle
253 changes: 253 additions & 0 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
#pragma once

#include "paddle/framework/attr_checker.h"

//#include "paddle/framework/op_base.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"

namespace paddle {
namespace framework {

//==================For test================//
class OpBase {
public:
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
AttributeMap attr_map_;

virtual std::string Run() const = 0;
virtual ~OpBase() {}
};
//=========================================//

// helper class to set attribute type
struct AttrTypeHelper {
template <typename T>
static void SetAttrType(AttrProto* attr);

static Attribute GetAttrValue(const AttrDesc& attr_desc) {
switch (attr_desc.type()) {
case paddle::framework::AttrType::INT: {
return attr_desc.i();
}
case paddle::framework::AttrType::FLOAT: {
return attr_desc.f();
}
case paddle::framework::AttrType::STRING: {
return attr_desc.s();
}
case paddle::framework::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i);
}
return val;
}
case paddle::framework::AttrType::FLOATS: {
std::vector<float> val(attr_desc.floats_size());
for (int i = 0; i < attr_desc.floats_size(); ++i) {
val[i] = attr_desc.floats(i);
}
return val;
}
case paddle::framework::AttrType::STRINGS: {
std::vector<std::string> val(attr_desc.strings_size());
for (int i = 0; i < attr_desc.strings_size(); ++i) {
val[i] = attr_desc.strings(i);
}
return val;
}
}
PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
return boost::blank();
}
};

template <>
void AttrTypeHelper::SetAttrType<int>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INT);
}

template <>
void AttrTypeHelper::SetAttrType<float>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOAT);
}

template <>
void AttrTypeHelper::SetAttrType<std::string>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRING);
}

template <>
void AttrTypeHelper::SetAttrType<std::vector<int>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INTS);
}

template <>
void AttrTypeHelper::SetAttrType<std::vector<float>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOATS);
}

template <>
void AttrTypeHelper::SetAttrType<std::vector<std::string>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRINGS);
}

// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}

protected:
void AddInput(const std::string& name, const std::string& comment) {
auto input = proto_->mutable_inputs()->Add();
*(input->mutable_name()) = name;
*(input->mutable_comment()) = comment;
}

void AddOutput(const std::string& name, const std::string& comment) {
auto output = proto_->mutable_outputs()->Add();
*(output->mutable_name()) = name;
*(output->mutable_comment()) = comment;
}

template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment) {
auto attr = proto_->mutable_attrs()->Add();
*(attr->mutable_name()) = name;
*(attr->mutable_comment()) = comment;
AttrTypeHelper::SetAttrType<T>(attr);
return op_checker_->AddAttrChecker<T>(name);
}

void AddType(const std::string& op_type) { proto_->set_type(op_type); }

void AddComment(const std::string& comment) {
*(proto_->mutable_comment()) = comment;
}

OpProto* proto_;
OpAttrChecker* op_checker_;
};

class OpRegistry {
typedef std::function<OpBase*()> OpCreator;

public:
template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) {
creators_[op_type] = []() { return new OpType; };
OpProto& op_proto = protos_[op_type];
OpAttrChecker& op_checker = op_checkers_[op_type];
ProtoMakerType(&op_proto, &op_checker);
PADDLE_ENFORCE(op_proto.IsInitialized() == true,
"Fail to initialize %s's OpProto !", op_type);
}

static OpBase* CreateOp(const OpDesc& op_desc) {
std::string op_type = op_desc.type();
OpBase* op = (creators_.at(op_type))();
(op->inputs_).resize(op_desc.inputs_size());
for (int i = 0; i < op_desc.inputs_size(); ++i) {
(op->inputs_)[i] = op_desc.inputs(i);
}
(op->outputs_).resize(op_desc.outputs_size());
for (int i = 0; i < op_desc.outputs_size(); ++i) {
(op->outputs_)[i] = op_desc.outputs(i);
}
for (int i = 0; i < op_desc.attrs_size(); ++i) {
const AttrDesc& ith_attr = op_desc.attrs(i);
std::string name = ith_attr.name();
(op->attr_map_)[name] = AttrTypeHelper::GetAttrValue(ith_attr);
}
const OpAttrChecker& op_checker = op_checkers_.at(op_type);
op_checker.Check(op->attr_map_);
return op;
}

private:
static std::unordered_map<std::string, OpCreator> creators_;
static std::unordered_map<std::string, OpProto> protos_;
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
};

std::unordered_map<std::string, std::function<OpBase*()>> OpRegistry::creators_;
std::unordered_map<std::string, OpProto> OpRegistry::protos_;
std::unordered_map<std::string, OpAttrChecker> OpRegistry::op_checkers_;

template <typename OpType, typename ProtoMakerType>
class OpRegisterHelper {
public:
OpRegisterHelper(std::string op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
}
};

#define REGISTER_OP(__op_class, __op_maker_class, __op_type) \
class __op_class##Register { \
private: \
const static OpRegisterHelper<__op_class, __op_maker_class> reg; \
}; \
const OpRegisterHelper<__op_class, __op_maker_class> \
__op_class##Register::reg(#__op_type);

// Demos

class CosineOp : public OpBase {
public:
virtual std::string Run() const {
std::string msg = "CosineOp runs! scale = " +
std::to_string(boost::get<float>(attr_map_.at("scale")));
return msg;
}
};

class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("cos");
AddComment("This is cos op");
}
};

REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)

class MyTestOp : public OpBase {
public:
virtual std::string Run() const {
std::string msg =
"MyTestOp runs! test_attr = " +
std::to_string(boost::get<int>(attr_map_.at("test_attr")));
return msg;
}
};

class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddType("my_test_op");
AddComment("This is my_test op");
}
};

REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)

} // namespace framework
} // namespace paddle
Loading