Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmake/external/glog.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ ExternalProject_Add(

ADD_LIBRARY(glog STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES})
ADD_DEPENDENCIES(glog extern_glog)
ADD_DEPENDENCIES(glog extern_glog gflags)
LINK_LIBRARIES(glog gflags)

LIST(APPEND external_project_dependencies glog)
1 change: 1 addition & 0 deletions paddle/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ if(Boost_FOUND)
add_subdirectory(memory)
add_subdirectory(platform)
add_subdirectory(framework)
add_subdirectory(operators)
add_subdirectory(pybind)
endif()

Expand Down
4 changes: 2 additions & 2 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc protobuf device_context)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry place)
cc_library(operator SRCS operator.cc DEPS op_desc device_context)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
Expand Down
81 changes: 69 additions & 12 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <algorithm>
#include <type_traits>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
Expand Down Expand Up @@ -81,8 +82,6 @@ class OpProtoAndCheckerMaker {
return op_checker_->AddAttrChecker<T>(name);
}

void AddType(const std::string& op_type) { proto_->set_type(op_type); }

void AddComment(const std::string& comment) {
*(proto_->mutable_comment()) = comment;
}
Expand All @@ -101,8 +100,11 @@ class OpRegistry {
OpProto& op_proto = protos()[op_type];
OpAttrChecker& op_checker = op_checkers()[op_type];
ProtoMakerType(&op_proto, &op_checker);
PADDLE_ENFORCE(op_proto.IsInitialized(),
"Fail to initialize %s's OpProto !", op_type);
*op_proto.mutable_type() = op_type;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am curious why this can work before without set op_type

PADDLE_ENFORCE(
op_proto.IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_proto.InitializationErrorString());
}

static OperatorBase* CreateOp(const OpDesc& op_desc) {
Expand Down Expand Up @@ -143,18 +145,73 @@ class OpRegistry {
template <typename OpType, typename ProtoMakerType>
class OpRegisterHelper {
public:
OpRegisterHelper(std::string op_type) {
OpRegisterHelper(const char* op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
}
};

#define REGISTER_OP(type, op_class, op_maker_class) \
class op_class##Register { \
private: \
const static OpRegisterHelper<op_class, op_maker_class> reg; \
}; \
const OpRegisterHelper<op_class, op_maker_class> op_class##Register::reg( \
#type)
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)

#define REGISTER_OP(__op_type, __op_class, __op_maker_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##__op_type, \
"REGISTER_OP must be in global namespace"); \
static ::paddle::framework::OpRegisterHelper<__op_class, __op_maker_class> \
__op_register_##__op_type##__(#__op_type); \
int __op_register_##__op_type##_handle__() { return 0; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里加上这个函数是什么作用啊?

int __op_register_##__op_type##_handle__() { return 0; }


#define REGISTER_OP_KERNEL(type, GPU_OR_CPU, PlaceType, KernelType) \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里目前先不考虑对不同数据类型kernel的支持吗?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前既然没有不同数据类型的实现,就先不考虑不同数据类型的问题。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

会有doublefloathalf等数据类型的支持吧,希望从接口下考虑下。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果需要增加数据类型的话,只是注册的时候加一个参数就好了。除了注册之外的接口没有区别。

STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##type##_##GPU_OR_CPU##__, \
"REGISTER_OP_KERNEL must be in global namespace"); \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__; \
int __op_kernel_register_##type##_handle_##GPU_OR_CPU##__() { return 0; }

#define REGISTER_OP_GPU_KERNEL(type, KernelType) \
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType)

#define REGISTER_OP_CPU_KERNEL(type, KernelType) \
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType)

#define USE_OP_WITHOUT_KERNEL(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_without_kernel_##op_type, \
"USE_OP_WITHOUT_KERNEL must be in global namespace"); \
extern int __op_register_##op_type##_handle__(); \
static int __use_op_ptr_##op_type##_without_kernel__ \
__attribute__((unused)) = __op_register_##op_type##_handle__()

#define USE_OP_KERNEL(op_type, DEVICE_TYPE) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
"USE_OP_KERNEL must be in global namespace"); \
extern int __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__(); \
static int __use_op_ptr_##op_type##_##DEVICE_TYPE##_kernel__ \
__attribute__((unused)) = \
__op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__()

#ifdef PADDLE_ONLY_CPU
#define USE_OP(op_type) \
USE_OP_WITHOUT_KERNEL(op_type); \
USE_OP_KERNEL(op_type, CPU);

#else
#define USE_OP(op_type) \
USE_OP_WITHOUT_KERNEL(op_type); \
USE_OP_KERNEL(op_type, CPU); \
USE_OP_KERNEL(op_type, GPU)
#endif

} // namespace framework
} // namespace paddle
24 changes: 8 additions & 16 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>

using namespace paddle::framework;

namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
Expand All @@ -21,13 +19,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("cos");
AddComment("This is cos op");
}
};

REGISTER_OP(cos_sim, CosineOp, CosineOpProtoAndCheckerMaker);

class MyTestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
Expand All @@ -48,15 +43,17 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddType("my_test_op");
AddComment("This is my_test op");
}
};

REGISTER_OP(my_test_op, MyTestOp, MyTestOpProtoAndCheckerMaker);
} // namespace framework
} // namespace paddle

REGISTER_OP(cos_sim, paddle::framework::CosineOp,
paddle::framework::CosineOpProtoAndCheckerMaker);
REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
paddle::framework::MyTestOpProtoAndCheckerMaker);

TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
Expand All @@ -71,7 +68,7 @@ TEST(OpRegistry, CreateOp) {

paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>();
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
float scale_get = op->GetAttr<float>("scale");
Expand Down Expand Up @@ -114,7 +111,7 @@ TEST(OpRegistry, DefaultValue) {

paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>();
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
Expand Down Expand Up @@ -169,13 +166,8 @@ TEST(OpRegistry, CustomChecker) {
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<Scope>();
auto scope = std::make_shared<paddle::framework::Scope>();
op->Run(scope, dev_ctx);
int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
53 changes: 42 additions & 11 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/framework/tensor.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
Expand Down Expand Up @@ -103,6 +104,19 @@ class OpKernel {
virtual ~OpKernel() {}
};

template <typename T>
struct VarToTensor {};

template <>
struct VarToTensor<Tensor*> {
Tensor* operator()(Variable* var) { return var->GetMutable<Tensor>(); }
};

template <>
struct VarToTensor<const Tensor*> {
const Tensor* operator()(Variable* var) { return &var->Get<Tensor>(); }
};

class OperatorWithKernel : public OperatorBase {
public:
struct OpKernelKey {
Expand Down Expand Up @@ -136,19 +150,36 @@ class OperatorWithKernel : public OperatorBase {
AllOpKernels() {
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
return g_all_op_kernels;
}
void InferShape(const std::shared_ptr<Scope>& scope) const final {
std::vector<const Tensor*> ins;
VarNamesToTensors(scope, inputs_, &ins);
std::vector<Tensor*> outs;
VarNamesToTensors(scope, outputs_, &outs);
InferShape(ins, outs);
};

private:
template <typename T>
void VarNamesToTensors(const std::shared_ptr<Scope>& scope,
const std::vector<std::string>& var_names,
std::vector<T>* container) const {
container->reserve(var_names.size());
VarToTensor<T> convert;
for (auto& name : var_names) {
auto var = scope->GetVariable(name);
if (var != nullptr) {
container->push_back(convert(var));
} else {
container->push_back(nullptr);
}
}
}

protected:
virtual void InferShape(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs) const = 0;
};

} // namespace framework
} // namespace paddle

#define REGISTER_OP_KERNEL(type, PlaceType, KernelType) \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__
58 changes: 12 additions & 46 deletions paddle/framework/operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,44 +36,6 @@ class OperatorTest : public OperatorBase {
float x = 0;
};

class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OperatorTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("test_operator");
AddComment("This is test op");
}
};

REGISTER_OP(test_operator, OperatorTest, OperatorTestProtoAndCheckerMaker);

TEST(OperatorBase, all) {
OpDesc op_desc;
op_desc.set_type("test_operator");
*op_desc.mutable_inputs()->Add() = "IN1";
*op_desc.mutable_outputs()->Add() = "OUT1";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
float scale = 3.14;
attr->set_f(scale);

platform::CPUDeviceContext device_context;
auto scope = std::make_shared<Scope>();

OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(op->GetAttr<float>("scale"), scale);
scope->CreateVariable("OUT1");
op->Run(scope, device_context);
std::cout << op->DebugString() << std::endl;
delete op;
}

class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
Expand All @@ -83,14 +45,14 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("test_operator");
AddComment("This is test op");
}
};

class OpWithKernelTest : public OperatorWithKernel {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
protected:
void InferShape(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs) const override {}
};

class CPUKernelTest : public OpKernel {
Expand All @@ -103,10 +65,16 @@ class CPUKernelTest : public OpKernel {
}
};

REGISTER_OP(op_with_kernel, OpWithKernelTest, OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_KERNEL(op_with_kernel, platform::CPUPlace, CPUKernelTest);
} // namespace framework
} // namespace paddle

REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest);

TEST(OpKernel, all) {
using namespace paddle::framework;

OpDesc op_desc;
op_desc.set_type("op_with_kernel");
*op_desc.mutable_inputs()->Add() = "IN1";
Expand All @@ -116,13 +84,11 @@ TEST(OpKernel, all) {
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(3.14);

platform::CPUDeviceContext cpu_device_context;
paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>();

OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context);

delete op;
}
} // namespace framework
} // namespace paddle
6 changes: 6 additions & 0 deletions paddle/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
if(WITH_GPU)
nv_library(add_op SRCS add_op.cc add_op.cu DEPS operator op_registry glog ddim)
else()
cc_library(add_op SRCS add_op.cc DEPS operator op_registry glog ddim)
endif()
cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)
Loading