Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions cmake/external/protobuf.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,15 @@ macro(PROMPT_PROTOBUF_LIB)
ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY})

ADD_LIBRARY(protoc ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY})
ADD_LIBRARY(libprotoc ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET libprotoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY})
ADD_EXECUTABLE(protoc IMPORTED GLOBAL)
SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOBUF_PROTOC_EXECUTABLE})

FOREACH(dep ${protobuf_DEPS})
ADD_DEPENDENCIES(protobuf ${dep})
ADD_DEPENDENCIES(protobuf_lite ${dep})
ADD_DEPENDENCIES(libprotoc ${dep})
ADD_DEPENDENCIES(protoc ${dep})
ENDFOREACH()

Expand Down Expand Up @@ -114,7 +117,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
UPDATE_COMMAND ""
DEPENDS zlib
GIT_REPOSITORY "https://github.com/google/protobuf.git"
GIT_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546"
GIT_TAG "v3.1.0"
CONFIGURE_COMMAND
${CMAKE_COMMAND} ${PROTOBUF_SOURCES_DIR}/src/${TARGET_NAME}/cmake
${OPTIONAL_ARGS}
Expand Down
5 changes: 3 additions & 2 deletions cmake/flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,16 @@ set(COMMON_FLAGS
-fPIC
-fno-omit-frame-pointer
-Wall
-Wextra
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove this check?

-Werror
-Wnon-virtual-dtor
-Wdelete-non-virtual-dtor
-Wno-unused-parameter
-Wno-unused-function
-Wno-error=literal-suffix
-Wno-error=sign-compare
-Wno-error=unused-local-typedefs)
-Wno-error=unused-local-typedefs
-Wno-error=ignored-qualifiers # Warning in protobuf 3 Map.h
-Wno-error=no-enum-compare) # Warning in protobuf 3 Map.h

set(GPU_COMMON_FLAGS
-fPIC
Expand Down
41 changes: 40 additions & 1 deletion cmake/generic.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ function(cc_library TARGET_NAME)
if (cc_library_DEPS)
merge_static_libs(${TARGET_NAME} ${cc_library_DEPS})
else()
message(FATAL "Please specify source file or library in cc_library.")
message(FATAL_ERROR "Please specify source file or library in cc_library.")
endif()
endif(cc_library_SRCS)
endfunction(cc_library)
Expand Down Expand Up @@ -331,3 +331,42 @@ function(go_test TARGET_NAME)
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS})
add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME})
endfunction(go_test)

# go_extern will download extern go project.
# go_extern(target_name extern_source)
# go_extern(go_redis github.com/hoisie/redis)
function(go_extern TARGET_NAME)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we need this function in generic.cmake?

add_custom_target(${TARGET_NAME} env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get ${ARGN})
endfunction(go_extern)


function(generate_protobuf_cpp SRCS HDRS)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the documented convention at the beginning of this header file, what we need is proto_library. But the adding of proto_libary seems should be in a separate PR that fixes #2567.

set(PROTO_FILES ${ARGN})
set(${SRCS})
set(${HDRS})
foreach(FIL ${PROTO_FILES})
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
get_filename_component(FIL_WE ${FIL} NAME_WE)
if(NOT PROTOBUF_GENERATE_CPP_APPEND_PATH)
get_filename_component(FIL_DIR ${FIL} DIRECTORY)
if(FIL_DIR)
set(FIL_WE "${FIL_DIR}/${FIL_WE}")
endif()
endif()

list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc")
list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h")

add_custom_command(
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc"
"${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h"
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
ARGS "--cpp_out=${DLL_EXPORT_DECL}${CMAKE_CURRENT_BINARY_DIR}" "-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL}
DEPENDS ${ABS_FIL} protoc
COMMENT "Running C++ protocol buffer compiler on ${FIL}"
VERBATIM )
endforeach()
set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
set(${HDRS} ${${HDRS}} PARENT_SCOPE)
endfunction()
11 changes: 10 additions & 1 deletion paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
cc_library(ddim SRCS ddim.cc)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)

nv_test(dim_test SRCS dim_test.cu DEPS ddim)

cc_test(variable_test SRCS variable_test.cc)
# include generated protobuf headers
INCLUDE_DIRECTORIES(${CMAKE_CURRENT_BINARY_DIR})
generate_protobuf_cpp(attr_proto_src attr_proto_hdr attr.proto)
cc_library(attr_proto SRCS ${attr_proto_src} DEPS protobuf)
generate_protobuf_cpp(attr_test_proto_src attr_test_proto_header attr_test.proto)
message(STATUS ${attr_test_proto_src})
cc_library(attr_helper SRCS attr_helper.cc DEPS attr_proto)
cc_test(attr_test SRCS ${attr_test_proto_src} attr_test.cc
DEPS attr_proto attr_helper protobuf)
cc_test(error_test SRCS error_test.cc)
17 changes: 17 additions & 0 deletions paddle/framework/attr.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
syntax="proto3";
package paddle.framework;

message Attribute {
message ListValue {
repeated int32 ints = 1;
repeated float floats = 2;
repeated string strings = 3;
}

oneof value {
ListValue list = 1;
int32 i = 2;
float f = 3;
string s = 4;
}
}
9 changes: 9 additions & 0 deletions paddle/framework/attr_helper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#include "attr_helper.h"

namespace paddle {
namespace framework {

Error AttributeReader::NotFound("Attribute is not found");
Error AttributeReader::TypeMismatch("Type mismatched");
}
}
192 changes: 192 additions & 0 deletions paddle/framework/attr_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
#pragma once
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand what goal all the code here is trying to achieve. It might be easier if there is a design doc like paddle/framework/variable.md for paddle/framework/variable.h.

#include <google/protobuf/map.h>
#include <paddle/framework/error.h>
#include <iterator>
#include <string>
#include <type_traits>
#include "attr.pb.h"
namespace paddle {
namespace framework {
using AttributeMap = google::protobuf::Map<std::string, Attribute>;

class AttributeReader final {
public:
static Error NotFound;
static Error TypeMismatch;

explicit AttributeReader(const AttributeMap& attrs) : attrs_(attrs) {}

template <typename T>
inline Error __must_check Get(const std::string& attributeName,
T* attr) const;

template <typename T>
inline Error __must_check GetArray(const std::string& attributeName,
std::vector<T>* array) const;

private:
const AttributeMap& attrs_;
};

namespace details {
template <typename Iterator, typename T>
struct SetArrayImpl {
inline Error __must_check operator()(AttributeMap* attrs,
const std::string& attributeName,
Iterator begin, Iterator end,
bool overwrite);
};
} // namespace details

class AttributeWriter {
public:
explicit AttributeWriter(AttributeMap* attrs) : attrs_(attrs) {}

template <typename T>
inline Error __must_check Set(const std::string& attributeName, const T& attr,
bool overwrite = false);

template <typename Iterator>
inline Error __must_check SetArray(const std::string& attributeName,
Iterator begin, Iterator end,
bool overwrite = false) {
return details::SetArrayImpl<
Iterator, typename std::iterator_traits<Iterator>::value_type>()(
attrs_, attributeName, begin, end, overwrite);
}

template <typename T, typename Container = std::initializer_list<T>>
inline Error __must_check SetArray(const std::string& attributeName,
Container container,
bool overwrite = false) {
return SetArray(attributeName, container.begin(), container.end(),
overwrite);
}

private:
AttributeMap* attrs_;
};

#define ATTR_READER_IMPL_PLAIN_TYPE(T, CASE, FIELD_NAME) \
template <> \
inline Error __must_check AttributeReader::Get<T>( \
const std::string& attributeName, T* attr) const { \
auto it = attrs_.find(attributeName); \
if (it == attrs_.end()) { \
return NotFound; \
} \
if (it->second.value_case() != CASE) { \
return TypeMismatch; \
} \
*attr = it->second.FIELD_NAME(); \
return Error(); \
}

ATTR_READER_IMPL_PLAIN_TYPE(int, Attribute::kI, i);
ATTR_READER_IMPL_PLAIN_TYPE(float, Attribute::kF, f);
ATTR_READER_IMPL_PLAIN_TYPE(std::string, Attribute::kS, s);

#undef ATTR_READER_IMPL_PLAIN_TYPE

#define ATTR_READER_IMPL_ARRAY_TYPE(T, FIELD_NAME) \
template <> \
inline Error __must_check AttributeReader::GetArray<T>( \
const std::string& attributeName, std::vector<T>* array) const { \
if (!array->empty()) { \
return Error("The output array must be empty."); \
} \
\
auto it = attrs_.find(attributeName); \
if (it == attrs_.end()) { \
return NotFound; \
} \
\
auto& lst = it->second.list(); \
auto& field = lst.FIELD_NAME(); \
array->reserve(field.size()); \
std::copy(field.begin(), field.end(), std::back_inserter(*array)); \
return Error(); \
}

ATTR_READER_IMPL_ARRAY_TYPE(float, floats);
ATTR_READER_IMPL_ARRAY_TYPE(int, ints);
ATTR_READER_IMPL_ARRAY_TYPE(std::string, strings);

#undef ATTR_READER_IMPL_ARRAY_TYPE

#define ATTR_WRITER_IMPL_PLAIN_TYPE(T, FIELD_NAME) \
template <> \
inline Error __must_check AttributeWriter::Set<T>( \
const std::string& attributeName, const T& attr, bool overwrite) { \
auto it = attrs_->find(attributeName); \
if (it != attrs_->end() && !overwrite) { \
return Error("Attribute %s has been set", attributeName.c_str()); \
} \
(*attrs_)[attributeName].set_##FIELD_NAME(attr); \
return Error(); \
}

ATTR_WRITER_IMPL_PLAIN_TYPE(int, i);
ATTR_WRITER_IMPL_PLAIN_TYPE(float, f);
ATTR_WRITER_IMPL_PLAIN_TYPE(std::string, s);

#undef ATTR_WRITER_IMPL_PLAIN_TYPE

namespace details {
template <typename T>
inline void AppendToField(google::protobuf::RepeatedField<T>* field,
const T& val) {
field->Add(val);
}
template <typename T>
inline void AppendToField(google::protobuf::RepeatedPtrField<T>* field,
const T& val) {
*(field->Add()) = val;
}

} // namespace details

#define ATTR_WRITER_IMPL_ARRAY_TYPE(T, FIELD_NAME) \
namespace details { \
\
template <typename Iterator> \
struct SetArrayImpl<Iterator, T> { \
using VALUE_TYPE = typename std::iterator_traits<Iterator>::value_type; \
inline Error __must_check operator()(AttributeMap* attrs, \
const std::string& attributeName, \
Iterator begin, Iterator end, \
bool overwrite) { \
static_assert(std::is_same<VALUE_TYPE, T>::value, ""); \
auto it = attrs->find(attributeName); \
if (it != attrs->end() && !overwrite) { \
return Error("Attribute %s has been set", attributeName.c_str()); \
} \
\
if (it != attrs->end() && overwrite) { \
auto repeatedFieldPtr = \
it->second.mutable_list()->mutable_##FIELD_NAME(); \
repeatedFieldPtr->erase(repeatedFieldPtr->begin(), \
repeatedFieldPtr->end()); \
} \
auto lst = (*attrs)[attributeName].mutable_list(); \
auto elems = lst->mutable_##FIELD_NAME(); \
auto distance = std::distance(begin, end); \
if (std::is_integral<decltype(distance)>::value) { \
elems->Reserve(distance); \
} \
for (; begin != end; ++begin) { \
AppendToField(elems, *begin); \
} \
return Error(); \
} \
}; \
}

ATTR_WRITER_IMPL_ARRAY_TYPE(float, floats);
ATTR_WRITER_IMPL_ARRAY_TYPE(int, ints);
ATTR_WRITER_IMPL_ARRAY_TYPE(std::string, strings);

#undef ATTR_WRITER_IMPL_ARRAY_TYPE

} // namespace framework
} // namespace paddle
Loading