Skip to content

Commit

Permalink
added function for loading content of nd_array files from a buffer (a…
Browse files Browse the repository at this point in the history
…pache#9883)

* added function for loading content of nd_array files

* changed function name and added check for NULL

* removed no lint

* fixed whitespace

* corrected the casting

* added python wrapper for buffer loading

* added unit tests for loading from buffer

* whitespace fixes

* fix for python 3

* fixed test for py3

* python 3 problems

* fixed test

* switched to using temp files

* better use of temp files

* hopefully fixed permission issue

* removed specified directory

* hopefully this will work with windows

* fixed indentation

* check in to relaunch tests

Python 3 windows failed for no obvious reason, deleted some whitespace to relaunch

* switched to using temporary directory class

* removed unneeded imports

* moved imports to 1 location
  • Loading branch information
dabraude authored and cjolivier01 committed Feb 26, 2018
1 parent db24ac1 commit a352d1e
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 3 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,4 @@ List of Contributors
* [Tao Hu](https://github.com/dongzhuoyao)
* [Sorokin Evgeniy](https://github.com/TheTweak)
* [dwSun](https://github.com/dwSun/)
* [David Braude](https://github.com/dabraude/)
22 changes: 22 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,28 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names);

/*!
* \brief Load list / dictionary of narrays from file content loaded into memory.
* This will load a list of ndarrays in a similar
* manner to MXNDArrayLoad, however, it loads from
* buffer containing the contents of a file, rather than
* from a specified file.
* \param ndarray_buffer pointer to the start of the ndarray file content
* \param size size of the file
* \param out_size number of narray loaded.
* \param out_arr head of the returning narray handles.
* \param out_name_size size of output name arrray.
* \param out_names the names of returning NDArrays, can be NULL
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
size_t size,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names);

/*!
* \brief Perform a synchronize copy from a continugous CPU memory region.
*
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .op import *
from .ndarray import *
# pylint: enable=wildcard-import
from .utils import load, save, zeros, empty, array
from .utils import load, load_frombuffer, save, zeros, empty, array
from .sparse import _ndarray_cls
from .ndarray import _GRAD_REQ_MAP

Expand Down
39 changes: 38 additions & 1 deletion python/mxnet/ndarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
except ImportError:
spsp = None

__all__ = ['zeros', 'empty', 'array', 'load', 'save']
__all__ = ['zeros', 'empty', 'array', 'load', 'load_frombuffer', 'save']


def zeros(shape, ctx=None, dtype=None, stype=None, **kwargs):
Expand Down Expand Up @@ -182,6 +182,43 @@ def load(fname):
for i in range(out_size.value))


def load_frombuffer(buf):
"""Loads an array dictionary or list from a buffer
See more details in ``save``.
Parameters
----------
buf : str
Buffer containing contents of a file as a string or bytes.
Returns
-------
list of NDArray, RowSparseNDArray or CSRNDArray, or \
dict of str to NDArray, RowSparseNDArray or CSRNDArray
Loaded data.
"""
if not isinstance(buf, string_types + tuple([bytes])):
raise TypeError('buf required to be a string or bytes')
out_size = mx_uint()
out_name_size = mx_uint()
handles = ctypes.POINTER(NDArrayHandle)()
names = ctypes.POINTER(ctypes.c_char_p)()
check_call(_LIB.MXNDArrayLoadFromBuffer(buf,
mx_uint(len(buf)),
ctypes.byref(out_size),
ctypes.byref(handles),
ctypes.byref(out_name_size),
ctypes.byref(names)))
if out_name_size.value == 0:
return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)]
else:
assert out_name_size.value == out_size.value
return dict(
(py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i])))
for i in range(out_size.value))


def save(fname, data):
"""Saves a list of arrays or a dict of str->array to file.
Expand Down
34 changes: 34 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,40 @@ int MXNDArrayLoad(const char* fname,
API_END();
}

int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
size_t size,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
ret->ret_vec_str.clear();
API_BEGIN();
CHECK_NOTNULL(ndarray_buffer);
std::vector<NDArray> data;
std::vector<std::string> &names = ret->ret_vec_str;
{
std::unique_ptr<dmlc::MemoryFixedSizeStream> fi(new dmlc::MemoryFixedSizeStream(
const_cast<void*>(ndarray_buffer), size));
mxnet::NDArray::Load(fi.get(), &data, &names);
}
ret->ret_handles.resize(data.size());
for (size_t i = 0; i < data.size(); ++i) {
NDArray *ptr = new NDArray();
*ptr = data[i];
ret->ret_handles[i] = ptr;
}
ret->ret_vec_charp.resize(names.size());
for (size_t i = 0; i < names.size(); ++i) {
ret->ret_vec_charp[i] = names[i].c_str();
}
*out_size = static_cast<mx_uint>(data.size());
*out_arr = dmlc::BeginPtr(ret->ret_handles);
*out_name_size = static_cast<mx_uint>(names.size());
*out_names = dmlc::BeginPtr(ret->ret_vec_charp);
API_END();
}

int MXNDArrayFree(NDArrayHandle handle) {
API_BEGIN();
delete static_cast<NDArray*>(handle);
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
import mxnet as mx
import numpy as np
import random
import shutil
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, '../common/'))
sys.path.insert(0, os.path.join(curr_path, '../../../python'))

import models
from contextlib import contextmanager
from nose.tools import make_decorator
import tempfile

def assertRaises(expected_exception, func, *args, **kwargs):
try:
Expand Down Expand Up @@ -225,3 +227,17 @@ def setup_module():
# the 'with_seed()' decoration. Inform the user of this once here at the module level.
if os.getenv('MXNET_TEST_SEED') is not None:
logger.warn('*** test-level seed set: all "@with_seed()" tests run deterministically ***')

try:
from tempfile import TemporaryDirectory
except:
# really simple implementation of TemporaryDirectory
class TemporaryDirectory(object):
def __init__(self, suffix='', prefix='', dir=''):
self._dirname = tempfile.mkdtemp(suffix, prefix, dir)

def __enter__(self):
return self._dirname

def __exit__(self, exc_type, exc_value, traceback):
shutil.rmtree(self._dirname)
48 changes: 47 additions & 1 deletion tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pickle as pkl
import unittest
from nose.tools import raises
from common import setup_module, with_seed
from common import setup_module, with_seed, assertRaises, TemporaryDirectory
from mxnet.test_utils import almost_equal
from mxnet.test_utils import assert_almost_equal
from mxnet.test_utils import default_context
Expand Down Expand Up @@ -291,6 +291,52 @@ def test_ndarray_legacy_load():
assert same(data[i].asnumpy(), legacy_data[i].asnumpy())


@with_seed()
def test_buffer_load():
nrepeat = 10
with TemporaryDirectory(prefix='test_buffer_load_') as tmpdir:
for repeat in range(nrepeat):
# test load_buffer as list
data = []
for i in range(10):
data.append(random_ndarray(np.random.randint(1, 5)))
fname = os.path.join(tmpdir, 'list_{0}.param'.format(repeat))
mx.nd.save(fname, data)
with open(fname, 'rb') as dfile:
buf_data = dfile.read()
data2 = mx.nd.load_frombuffer(buf_data)
assert len(data) == len(data2)
for x, y in zip(data, data2):
assert np.sum(x.asnumpy() != y.asnumpy()) == 0
# test garbage values
assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_data[:-10])
# test load_buffer as dict
dmap = {'ndarray xx %s' % i : x for i, x in enumerate(data)}
fname = os.path.join(tmpdir, 'dict_{0}.param'.format(repeat))
mx.nd.save(fname, dmap)
with open(fname, 'rb') as dfile:
buf_dmap = dfile.read()
dmap2 = mx.nd.load_frombuffer(buf_dmap)
assert len(dmap2) == len(dmap)
for k, x in dmap.items():
y = dmap2[k]
assert np.sum(x.asnumpy() != y.asnumpy()) == 0
# test garbage values
assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_dmap[:-10])

# we expect the single ndarray to be converted into a list containing the ndarray
single_ndarray = data[0]
fname = os.path.join(tmpdir, 'single_{0}.param'.format(repeat))
mx.nd.save(fname, single_ndarray)
with open(fname, 'rb') as dfile:
buf_single_ndarray = dfile.read()
single_ndarray_loaded = mx.nd.load_frombuffer(buf_single_ndarray)
assert len(single_ndarray_loaded) == 1
single_ndarray_loaded = single_ndarray_loaded[0]
assert np.sum(single_ndarray.asnumpy() != single_ndarray_loaded.asnumpy()) == 0
# test garbage values
assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_single_ndarray[:-10])

@with_seed()
def test_ndarray_slice():
shape = (10,)
Expand Down

0 comments on commit a352d1e

Please sign in to comment.