-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add a sample op, add_op
#2827
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a sample op, add_op
#2827
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
#pragma once | ||
|
||
#include <algorithm> | ||
#include <type_traits> | ||
#include "paddle/framework/attr_checker.h" | ||
#include "paddle/framework/op_desc.pb.h" | ||
#include "paddle/framework/op_proto.pb.h" | ||
|
@@ -81,8 +82,6 @@ class OpProtoAndCheckerMaker { | |
return op_checker_->AddAttrChecker<T>(name); | ||
} | ||
|
||
void AddType(const std::string& op_type) { proto_->set_type(op_type); } | ||
|
||
void AddComment(const std::string& comment) { | ||
*(proto_->mutable_comment()) = comment; | ||
} | ||
|
@@ -101,8 +100,11 @@ class OpRegistry { | |
OpProto& op_proto = protos()[op_type]; | ||
OpAttrChecker& op_checker = op_checkers()[op_type]; | ||
ProtoMakerType(&op_proto, &op_checker); | ||
PADDLE_ENFORCE(op_proto.IsInitialized(), | ||
"Fail to initialize %s's OpProto !", op_type); | ||
*op_proto.mutable_type() = op_type; | ||
PADDLE_ENFORCE( | ||
op_proto.IsInitialized(), | ||
"Fail to initialize %s's OpProto, because %s is not initialized", | ||
op_type, op_proto.InitializationErrorString()); | ||
} | ||
|
||
static OperatorBase* CreateOp(const OpDesc& op_desc) { | ||
|
@@ -143,18 +145,73 @@ class OpRegistry { | |
template <typename OpType, typename ProtoMakerType> | ||
class OpRegisterHelper { | ||
public: | ||
OpRegisterHelper(std::string op_type) { | ||
OpRegisterHelper(const char* op_type) { | ||
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type); | ||
} | ||
}; | ||
|
||
#define REGISTER_OP(type, op_class, op_maker_class) \ | ||
class op_class##Register { \ | ||
private: \ | ||
const static OpRegisterHelper<op_class, op_maker_class> reg; \ | ||
}; \ | ||
const OpRegisterHelper<op_class, op_maker_class> op_class##Register::reg( \ | ||
#type) | ||
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ | ||
struct __test_global_namespace_##uniq_name##__ {}; \ | ||
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ | ||
__test_global_namespace_##uniq_name##__>::value, \ | ||
msg) | ||
|
||
#define REGISTER_OP(__op_type, __op_class, __op_maker_class) \ | ||
STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##__op_type, \ | ||
"REGISTER_OP must be in global namespace"); \ | ||
static ::paddle::framework::OpRegisterHelper<__op_class, __op_maker_class> \ | ||
__op_register_##__op_type##__(#__op_type); \ | ||
int __op_register_##__op_type##_handle__() { return 0; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里加上这个函数是什么作用啊?
|
||
|
||
#define REGISTER_OP_KERNEL(type, GPU_OR_CPU, PlaceType, KernelType) \ | ||
|
||
STATIC_ASSERT_GLOBAL_NAMESPACE( \ | ||
__reg_op_kernel_##type##_##GPU_OR_CPU##__, \ | ||
"REGISTER_OP_KERNEL must be in global namespace"); \ | ||
struct __op_kernel_register__##type##__ { \ | ||
__op_kernel_register__##type##__() { \ | ||
::paddle::framework::OperatorWithKernel::OpKernelKey key; \ | ||
key.place_ = PlaceType(); \ | ||
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \ | ||
.reset(new KernelType()); \ | ||
} \ | ||
}; \ | ||
static __op_kernel_register__##type##__ __reg_kernel_##type##__; \ | ||
int __op_kernel_register_##type##_handle_##GPU_OR_CPU##__() { return 0; } | ||
|
||
#define REGISTER_OP_GPU_KERNEL(type, KernelType) \ | ||
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType) | ||
|
||
#define REGISTER_OP_CPU_KERNEL(type, KernelType) \ | ||
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType) | ||
|
||
#define USE_OP_WITHOUT_KERNEL(op_type) \ | ||
STATIC_ASSERT_GLOBAL_NAMESPACE( \ | ||
__use_op_without_kernel_##op_type, \ | ||
"USE_OP_WITHOUT_KERNEL must be in global namespace"); \ | ||
extern int __op_register_##op_type##_handle__(); \ | ||
static int __use_op_ptr_##op_type##_without_kernel__ \ | ||
__attribute__((unused)) = __op_register_##op_type##_handle__() | ||
|
||
#define USE_OP_KERNEL(op_type, DEVICE_TYPE) \ | ||
STATIC_ASSERT_GLOBAL_NAMESPACE( \ | ||
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \ | ||
"USE_OP_KERNEL must be in global namespace"); \ | ||
extern int __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__(); \ | ||
static int __use_op_ptr_##op_type##_##DEVICE_TYPE##_kernel__ \ | ||
__attribute__((unused)) = \ | ||
__op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__() | ||
|
||
#ifdef PADDLE_ONLY_CPU | ||
#define USE_OP(op_type) \ | ||
USE_OP_WITHOUT_KERNEL(op_type); \ | ||
USE_OP_KERNEL(op_type, CPU); | ||
|
||
#else | ||
#define USE_OP(op_type) \ | ||
USE_OP_WITHOUT_KERNEL(op_type); \ | ||
USE_OP_KERNEL(op_type, CPU); \ | ||
USE_OP_KERNEL(op_type, GPU) | ||
#endif | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
if(WITH_GPU) | ||
nv_library(add_op SRCS add_op.cc add_op.cu DEPS operator op_registry glog ddim) | ||
else() | ||
cc_library(add_op SRCS add_op.cc DEPS operator op_registry glog ddim) | ||
endif() | ||
cc_test(add_op_test SRCS add_op_test.cc DEPS add_op) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am curious why this can work before without set op_type