Skip to content

Commit

Permalink
[MXNET-128] added load from buffer functions (apache#10261)
Browse files Browse the repository at this point in the history
* rebased on master

* Update ndarray.h

* clairified comment

Mainly done to get to retest
  • Loading branch information
dabraude authored and cjolivier01 committed Apr 3, 2018
1 parent 47d0b58 commit a157d17
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
26 changes: 26 additions & 0 deletions cpp-package/include/mxnet-cpp/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,32 @@ class NDArray {
*/
static std::vector<NDArray> LoadToList(const std::string &file_name);
/*!
* \brief Load NDArrays from buffer.
* \param buffer Pointer to buffer. (ie contents of param file)
* \param size Size of buffer
* \param array_list a list of NDArrays returned, do not fill the list if
* nullptr is given.
* \param array_map a map from names to NDArrays returned, do not fill the map
* if nullptr is given or no names is stored in binary file.
*/
static void LoadFromBuffer(const void *buffer, size_t size,
std::vector<NDArray> *array_list = nullptr,
std::map<std::string, NDArray> *array_map = nullptr);
/*!
* \brief Load map of NDArrays from buffer.
* \param buffer Pointer to buffer. (ie contents of param file)
* \param size Size of buffer
* \return a list of NDArrays.
*/
static std::map<std::string, NDArray> LoadFromBufferToMap(const void *buffer, size_t size);
/*!
* \brief Load list of NDArrays from buffer.
* \param buffer Pointer to buffer. (ie contents of param file)
* \param size Size of buffer
* \return a map from names to NDArrays.
*/
static std::vector<NDArray> LoadFromBufferToList(const void *buffer, size_t size);
/*!
* \brief save a map of string->NDArray to binary file.
* \param file_name name of the binary file.
* \param array_map a map from names to NDArrays.
Expand Down
55 changes: 55 additions & 0 deletions cpp-package/include/mxnet-cpp/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ inline void NDArray::Load(const std::string &file_name,
&out_names),
0);
if (array_list != nullptr) {
array_list->reserve(out_size);
for (mx_uint i = 0; i < out_size; ++i) {
array_list->push_back(NDArray(out_arr[i]));
}
Expand Down Expand Up @@ -291,6 +292,60 @@ inline std::vector<NDArray> NDArray::LoadToList(const std::string &file_name) {
CHECK_EQ(MXNDArrayLoad(file_name.c_str(), &out_size, &out_arr, &out_name_size,
&out_names),
0);
array_list.reserve(out_size);
for (mx_uint i = 0; i < out_size; ++i) {
array_list.push_back(NDArray(out_arr[i]));
}
return array_list;
}
inline void NDArray::LoadFromBuffer(const void *buffer, size_t size,
std::vector<NDArray> *array_list,
std::map<std::string, NDArray> *array_map) {
mx_uint out_size, out_name_size;
NDArrayHandle *out_arr;
const char **out_names;
CHECK_EQ(MXNDArrayLoadFromBuffer(buffer, size, &out_size, &out_arr, &out_name_size,
&out_names),
0);
if (array_list != nullptr) {
array_list->reserve(out_size);
for (mx_uint i = 0; i < out_size; ++i) {
array_list->push_back(NDArray(out_arr[i]));
}
}
if (array_map != nullptr && out_name_size > 0) {
CHECK_EQ(out_name_size, out_size);
for (mx_uint i = 0; i < out_size; ++i) {
(*array_map)[out_names[i]] = NDArray(out_arr[i]);
}
}
}
inline std::map<std::string, NDArray> NDArray::LoadFromBufferToMap(
const void *buffer, size_t size) {
std::map<std::string, NDArray> array_map;
mx_uint out_size, out_name_size;
NDArrayHandle *out_arr;
const char **out_names;
CHECK_EQ(MXNDArrayLoadFromBuffer(buffer, size, &out_size, &out_arr, &out_name_size,
&out_names),
0);
if (out_name_size > 0) {
CHECK_EQ(out_name_size, out_size);
for (mx_uint i = 0; i < out_size; ++i) {
array_map[out_names[i]] = NDArray(out_arr[i]);
}
}
return array_map;
}
inline std::vector<NDArray> NDArray::LoadFromBufferToList(const void *buffer, size_t size) {
std::vector<NDArray> array_list;
mx_uint out_size, out_name_size;
NDArrayHandle *out_arr;
const char **out_names;
CHECK_EQ(MXNDArrayLoadFromBuffer(buffer, size, &out_size, &out_arr, &out_name_size,
&out_names),
0);
array_list.reserve(out_size);
for (mx_uint i = 0; i < out_size; ++i) {
array_list.push_back(NDArray(out_arr[i]));
}
Expand Down

0 comments on commit a157d17

Please sign in to comment.