-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Feature/add proto attrs #2604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/add proto attrs #2604
Changes from all commits
b00d640
e815fe2
31e0531
deeef35
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -171,7 +171,7 @@ function(cc_library TARGET_NAME) | |
if (cc_library_DEPS) | ||
merge_static_libs(${TARGET_NAME} ${cc_library_DEPS}) | ||
else() | ||
message(FATAL "Please specify source file or library in cc_library.") | ||
message(FATAL_ERROR "Please specify source file or library in cc_library.") | ||
endif() | ||
endif(cc_library_SRCS) | ||
endfunction(cc_library) | ||
|
@@ -331,3 +331,42 @@ function(go_test TARGET_NAME) | |
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS}) | ||
add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}) | ||
endfunction(go_test) | ||
|
||
# go_extern will download extern go project. | ||
# go_extern(target_name extern_source) | ||
# go_extern(go_redis github.com/hoisie/redis) | ||
function(go_extern TARGET_NAME) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why we need this function in generic.cmake? |
||
add_custom_target(${TARGET_NAME} env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get ${ARGN}) | ||
endfunction(go_extern) | ||
|
||
|
||
function(generate_protobuf_cpp SRCS HDRS) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to the documented convention at the beginning of this header file, what we need is |
||
set(PROTO_FILES ${ARGN}) | ||
set(${SRCS}) | ||
set(${HDRS}) | ||
foreach(FIL ${PROTO_FILES}) | ||
get_filename_component(ABS_FIL ${FIL} ABSOLUTE) | ||
get_filename_component(FIL_WE ${FIL} NAME_WE) | ||
if(NOT PROTOBUF_GENERATE_CPP_APPEND_PATH) | ||
get_filename_component(FIL_DIR ${FIL} DIRECTORY) | ||
if(FIL_DIR) | ||
set(FIL_WE "${FIL_DIR}/${FIL_WE}") | ||
endif() | ||
endif() | ||
|
||
list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc") | ||
list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h") | ||
|
||
add_custom_command( | ||
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc" | ||
"${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h" | ||
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} | ||
ARGS "--cpp_out=${DLL_EXPORT_DECL}${CMAKE_CURRENT_BINARY_DIR}" "-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL} | ||
DEPENDS ${ABS_FIL} protoc | ||
COMMENT "Running C++ protocol buffer compiler on ${FIL}" | ||
VERBATIM ) | ||
endforeach() | ||
set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) | ||
set(${SRCS} ${${SRCS}} PARENT_SCOPE) | ||
set(${HDRS} ${${HDRS}} PARENT_SCOPE) | ||
endfunction() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,15 @@ | ||
cc_library(ddim SRCS ddim.cc) | ||
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) | ||
|
||
nv_test(dim_test SRCS dim_test.cu DEPS ddim) | ||
|
||
cc_test(variable_test SRCS variable_test.cc) | ||
# include generated protobuf headers | ||
INCLUDE_DIRECTORIES(${CMAKE_CURRENT_BINARY_DIR}) | ||
generate_protobuf_cpp(attr_proto_src attr_proto_hdr attr.proto) | ||
cc_library(attr_proto SRCS ${attr_proto_src} DEPS protobuf) | ||
generate_protobuf_cpp(attr_test_proto_src attr_test_proto_header attr_test.proto) | ||
message(STATUS ${attr_test_proto_src}) | ||
cc_library(attr_helper SRCS attr_helper.cc DEPS attr_proto) | ||
cc_test(attr_test SRCS ${attr_test_proto_src} attr_test.cc | ||
DEPS attr_proto attr_helper protobuf) | ||
cc_test(error_test SRCS error_test.cc) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
syntax="proto3"; | ||
package paddle.framework; | ||
|
||
message Attribute { | ||
message ListValue { | ||
repeated int32 ints = 1; | ||
repeated float floats = 2; | ||
repeated string strings = 3; | ||
} | ||
|
||
oneof value { | ||
ListValue list = 1; | ||
int32 i = 2; | ||
float f = 3; | ||
string s = 4; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#include "attr_helper.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
Error AttributeReader::NotFound("Attribute is not found"); | ||
Error AttributeReader::TypeMismatch("Type mismatched"); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
#pragma once | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't quite understand what goal all the code here is trying to achieve. It might be easier if there is a design doc like paddle/framework/variable.md for paddle/framework/variable.h. |
||
#include <google/protobuf/map.h> | ||
#include <paddle/framework/error.h> | ||
#include <iterator> | ||
#include <string> | ||
#include <type_traits> | ||
#include "attr.pb.h" | ||
namespace paddle { | ||
namespace framework { | ||
using AttributeMap = google::protobuf::Map<std::string, Attribute>; | ||
|
||
class AttributeReader final { | ||
public: | ||
static Error NotFound; | ||
static Error TypeMismatch; | ||
|
||
explicit AttributeReader(const AttributeMap& attrs) : attrs_(attrs) {} | ||
|
||
template <typename T> | ||
inline Error __must_check Get(const std::string& attributeName, | ||
T* attr) const; | ||
|
||
template <typename T> | ||
inline Error __must_check GetArray(const std::string& attributeName, | ||
std::vector<T>* array) const; | ||
|
||
private: | ||
const AttributeMap& attrs_; | ||
}; | ||
|
||
namespace details { | ||
template <typename Iterator, typename T> | ||
struct SetArrayImpl { | ||
inline Error __must_check operator()(AttributeMap* attrs, | ||
const std::string& attributeName, | ||
Iterator begin, Iterator end, | ||
bool overwrite); | ||
}; | ||
} // namespace details | ||
|
||
class AttributeWriter { | ||
public: | ||
explicit AttributeWriter(AttributeMap* attrs) : attrs_(attrs) {} | ||
|
||
template <typename T> | ||
inline Error __must_check Set(const std::string& attributeName, const T& attr, | ||
bool overwrite = false); | ||
|
||
template <typename Iterator> | ||
inline Error __must_check SetArray(const std::string& attributeName, | ||
Iterator begin, Iterator end, | ||
bool overwrite = false) { | ||
return details::SetArrayImpl< | ||
Iterator, typename std::iterator_traits<Iterator>::value_type>()( | ||
attrs_, attributeName, begin, end, overwrite); | ||
} | ||
|
||
template <typename T, typename Container = std::initializer_list<T>> | ||
inline Error __must_check SetArray(const std::string& attributeName, | ||
Container container, | ||
bool overwrite = false) { | ||
return SetArray(attributeName, container.begin(), container.end(), | ||
overwrite); | ||
} | ||
|
||
private: | ||
AttributeMap* attrs_; | ||
}; | ||
|
||
#define ATTR_READER_IMPL_PLAIN_TYPE(T, CASE, FIELD_NAME) \ | ||
template <> \ | ||
inline Error __must_check AttributeReader::Get<T>( \ | ||
const std::string& attributeName, T* attr) const { \ | ||
auto it = attrs_.find(attributeName); \ | ||
if (it == attrs_.end()) { \ | ||
return NotFound; \ | ||
} \ | ||
if (it->second.value_case() != CASE) { \ | ||
return TypeMismatch; \ | ||
} \ | ||
*attr = it->second.FIELD_NAME(); \ | ||
return Error(); \ | ||
} | ||
|
||
ATTR_READER_IMPL_PLAIN_TYPE(int, Attribute::kI, i); | ||
ATTR_READER_IMPL_PLAIN_TYPE(float, Attribute::kF, f); | ||
ATTR_READER_IMPL_PLAIN_TYPE(std::string, Attribute::kS, s); | ||
|
||
#undef ATTR_READER_IMPL_PLAIN_TYPE | ||
|
||
#define ATTR_READER_IMPL_ARRAY_TYPE(T, FIELD_NAME) \ | ||
template <> \ | ||
inline Error __must_check AttributeReader::GetArray<T>( \ | ||
const std::string& attributeName, std::vector<T>* array) const { \ | ||
if (!array->empty()) { \ | ||
return Error("The output array must be empty."); \ | ||
} \ | ||
\ | ||
auto it = attrs_.find(attributeName); \ | ||
if (it == attrs_.end()) { \ | ||
return NotFound; \ | ||
} \ | ||
\ | ||
auto& lst = it->second.list(); \ | ||
auto& field = lst.FIELD_NAME(); \ | ||
array->reserve(field.size()); \ | ||
std::copy(field.begin(), field.end(), std::back_inserter(*array)); \ | ||
return Error(); \ | ||
} | ||
|
||
ATTR_READER_IMPL_ARRAY_TYPE(float, floats); | ||
ATTR_READER_IMPL_ARRAY_TYPE(int, ints); | ||
ATTR_READER_IMPL_ARRAY_TYPE(std::string, strings); | ||
|
||
#undef ATTR_READER_IMPL_ARRAY_TYPE | ||
|
||
#define ATTR_WRITER_IMPL_PLAIN_TYPE(T, FIELD_NAME) \ | ||
template <> \ | ||
inline Error __must_check AttributeWriter::Set<T>( \ | ||
const std::string& attributeName, const T& attr, bool overwrite) { \ | ||
auto it = attrs_->find(attributeName); \ | ||
if (it != attrs_->end() && !overwrite) { \ | ||
return Error("Attribute %s has been set", attributeName.c_str()); \ | ||
} \ | ||
(*attrs_)[attributeName].set_##FIELD_NAME(attr); \ | ||
return Error(); \ | ||
} | ||
|
||
ATTR_WRITER_IMPL_PLAIN_TYPE(int, i); | ||
ATTR_WRITER_IMPL_PLAIN_TYPE(float, f); | ||
ATTR_WRITER_IMPL_PLAIN_TYPE(std::string, s); | ||
|
||
#undef ATTR_WRITER_IMPL_PLAIN_TYPE | ||
|
||
namespace details { | ||
template <typename T> | ||
inline void AppendToField(google::protobuf::RepeatedField<T>* field, | ||
const T& val) { | ||
field->Add(val); | ||
} | ||
template <typename T> | ||
inline void AppendToField(google::protobuf::RepeatedPtrField<T>* field, | ||
const T& val) { | ||
*(field->Add()) = val; | ||
} | ||
|
||
} // namespace details | ||
|
||
#define ATTR_WRITER_IMPL_ARRAY_TYPE(T, FIELD_NAME) \ | ||
namespace details { \ | ||
\ | ||
template <typename Iterator> \ | ||
struct SetArrayImpl<Iterator, T> { \ | ||
using VALUE_TYPE = typename std::iterator_traits<Iterator>::value_type; \ | ||
inline Error __must_check operator()(AttributeMap* attrs, \ | ||
const std::string& attributeName, \ | ||
Iterator begin, Iterator end, \ | ||
bool overwrite) { \ | ||
static_assert(std::is_same<VALUE_TYPE, T>::value, ""); \ | ||
auto it = attrs->find(attributeName); \ | ||
if (it != attrs->end() && !overwrite) { \ | ||
return Error("Attribute %s has been set", attributeName.c_str()); \ | ||
} \ | ||
\ | ||
if (it != attrs->end() && overwrite) { \ | ||
auto repeatedFieldPtr = \ | ||
it->second.mutable_list()->mutable_##FIELD_NAME(); \ | ||
repeatedFieldPtr->erase(repeatedFieldPtr->begin(), \ | ||
repeatedFieldPtr->end()); \ | ||
} \ | ||
auto lst = (*attrs)[attributeName].mutable_list(); \ | ||
auto elems = lst->mutable_##FIELD_NAME(); \ | ||
auto distance = std::distance(begin, end); \ | ||
if (std::is_integral<decltype(distance)>::value) { \ | ||
elems->Reserve(distance); \ | ||
} \ | ||
for (; begin != end; ++begin) { \ | ||
AppendToField(elems, *begin); \ | ||
} \ | ||
return Error(); \ | ||
} \ | ||
}; \ | ||
} | ||
|
||
ATTR_WRITER_IMPL_ARRAY_TYPE(float, floats); | ||
ATTR_WRITER_IMPL_ARRAY_TYPE(int, ints); | ||
ATTR_WRITER_IMPL_ARRAY_TYPE(std::string, strings); | ||
|
||
#undef ATTR_WRITER_IMPL_ARRAY_TYPE | ||
|
||
} // namespace framework | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why remove this check?