Skip to content

Commit

Permalink
introduce CSRNDarray and rowsparseNDarray to python frontend api
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin committed Jun 1, 2017
1 parent cbfa792 commit 258c74b
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 90 deletions.
13 changes: 2 additions & 11 deletions python/mxnet/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .base import mx_uint, NDArrayHandle, ExecutorHandle
from .base import check_call, c_array, py_str
from .ndarray import NDArray
from .sparse_ndarray import SparseNDArray, _STORAGE_TYPE_STR_TO_ID
from .sparse_ndarray import _ndarray_cls
from . import ndarray as nd

# those functions are not used here, we just import them to keep backward compatibility
Expand Down Expand Up @@ -92,16 +92,7 @@ def _get_outputs(self):
check_call(_LIB.MXExecutorOutputs(self.handle,
ctypes.byref(out_size), ctypes.byref(handles)))
num_output = out_size.value
outputs = []
for i in range(num_output):
storage_type = ctypes.c_int(0)
check_call(_LIB.MXNDArrayGetStorageType(ctypes.cast(handles[i], NDArrayHandle),
ctypes.byref(storage_type)))
assert(storage_type != _STORAGE_TYPE_STR_TO_ID['undefined'])
output = NDArray(NDArrayHandle(handles[i])) \
if storage_type.value == _STORAGE_TYPE_STR_TO_ID['default'] \
else SparseNDArray(NDArrayHandle(handles[i]))
outputs.append(output)
outputs = [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(num_output)]
return outputs

def forward(self, is_train=False, **kwargs):
Expand Down
160 changes: 87 additions & 73 deletions python/mxnet/sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,15 @@
# pylint: disable=unused-import
try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
#TODO remove some import?
from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke
from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class
elif _sys.version_info >= (3, 0):
from ._cy3.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke
from ._cy3.ndarray import NDArrayBase, _set_ndarray_class
else:
from ._cy2.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke
from ._cy2.ndarray import NDArrayBase, _set_ndarray_class
except ImportError:
if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke
from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class

# pylint: enable=unused-import
_STORAGE_AUX_TYPES = {
Expand Down Expand Up @@ -80,7 +79,8 @@ def _new_alloc_handle(storage_type, shape, ctx, delay_alloc, dtype, aux_types, a

class SparseNDArray(NDArray):
"""An array object representing a multidimensional, homogeneous array of
fixed-size items, stored in sparse format.
fixed-size items, stored in sparse format. See CSRNDArray and RowSparseNDArray
for more details.
"""

Expand Down Expand Up @@ -228,7 +228,8 @@ def _sync_copyfrom(self, source_array):

def _slice(self, start, stop):
"""Returns a read-only SparseNDArray slice that shares memory with current one.
To create a writable slice, please use ``mx.nd.slice`` instead.
To create a writable slice, please use ``mx.nd.slice`` instead. Currently only
`csr` storage type is supported.
Parameters
----------
Expand Down Expand Up @@ -263,7 +264,7 @@ def _slice(self, start, stop):
stop = mx_uint(stop) if stop else mx_uint(self.shape[0])

check_call(_LIB.MXNDArraySliceEx(self.handle, start, stop, handle))
ret = SparseNDArray(handle=handle, writable=False)
ret = _ndarray_cls(handle=handle, writable=False)
return ret

def _at(self, idx):
Expand All @@ -281,14 +282,14 @@ def _aux_type(self, i):
Returns
-------
numpy.dtype
This NDArray's data type.
This SparseNDArray's aux data type.
"""
aux_type = ctypes.c_int()
check_call(_LIB.MXNDArrayGetAuxType(self.handle, i, ctypes.byref(aux_type)))
return _DTYPE_MX_TO_NP[aux_type.value]

@property
def _values(self):
def values(self):
"""The values array of the SparseNDArray. This is a read-only view of the values array.
They reveal internal implementation details and should be used with care.
Expand All @@ -299,38 +300,6 @@ def _values(self):
"""
return self._data()

@property
def _indices(self):
"""The indices array of the SparseNDArray. This is a read-only view of the indices array.
They reveal internal implementation details and should be used with care.
Returns
-------
NDArray
This SparseNDArray's indices array.
"""
stype = self.storage_type
if stype == 'row_sparse':
return self._aux_data(0)
elif stype == 'csr':
return self._aux_data(1)
raise Exception("unknown storage type " + stype)

@property
def _indptr(self):
"""The indptr array of the SparseNDArray with `csr` storage type.
This is a read-only view of the indptr array.
They reveal internal implementation details and should be used with care.
Returns
-------
NDArray
This SparseNDArray's indptr array.
"""
stype = self.storage_type
if stype == 'csr':
return self._aux_data(0)
raise Exception("unknown storage type " + stype)

@property
def _num_aux(self):
Expand Down Expand Up @@ -389,7 +358,7 @@ def copyto(self, other):
return
return _internal._copyto(self, out=other)
elif isinstance(other, Context):
hret = SparseNDArray(_new_alloc_handle(self.storage_type, self.shape, other,
hret = _ndarray_cls(_new_alloc_handle(self.storage_type, self.shape, other,
True, self.dtype, self.aux_types))
return _internal._copyto(self, out=hret)
else:
Expand All @@ -414,6 +383,66 @@ def _data(self, writable=False):
check_call(_LIB.MXNDArrayGetDataNDArray(self.handle, ctypes.byref(hdl)))
return NDArray(hdl, writable)

class CSRNDArray(SparseNDArray):
"""A CSRNDArray represents a NDArray as three separate arrays: `values`,
`indptr` and `indices`. It uses the standard CSR representation where the column indices for
row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored
in values[indptr[i]:indptr[i+1]].
"""

@property
def indices(self):
"""The indices array of the SparseNDArray. This is a read-only view of the indices array.
They reveal internal implementation details and should be used with care.
Returns
-------
NDArray
This SparseNDArray's indices array.
"""
return self._aux_data(1)

@property
def indptr(self):
"""The indptr array of the SparseNDArray with `csr` storage type.
This is a read-only view of the indptr array.
They reveal internal implementation details and should be used with care.
Returns
-------
NDArray
This SparseNDArray's indptr array.
"""
return self._aux_data(0)

class RowSparseNDArray(SparseNDArray):
"""A RowSparseNDArray is typically used to represent a subset of a larger
NDArray with `default` of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0. The values
in indices are the indices in the first dimension of the slices that have been extracted from
the larger NDArray. The indices are expected to be sorted in ascending order.
The corresponding NDArray ``dense`` with `default` storage represented by a ``rsp``
RowSparseNDArray
``dense[rsp.indices[i], :, :, :, ...] = rsp.values[i, :, :, :, ...]``
RowSparseNDArray is used principally in the definition of gradients for operations
that have sparse gradients (e.g. SparseEmbedding).
"""

@property
def indices(self):
"""The indices array of the SparseNDArray. This is a read-only view of the indices array.
They reveal internal implementation details and should be used with care.
Returns
-------
NDArray
This SparseNDArray's indices array.
"""
return self._aux_data(0)

def _prepare_src_array(src, dtype, default_dtype):
if isinstance(src, NDArray):
dtype = src.dtype if dtype is None else dtype
Expand All @@ -429,11 +458,6 @@ def _prepare_src_array(src, dtype, default_dtype):
def csr(values, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, indices_type=None):
"""Creates a 2D array with compressed sparse row format.
A SparseNDArray with `csr` storage represents a NDArray as three separate arrays: `values`,
`indptr` and `indices`. It uses the standard CSR representation where the column indices for
row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored
in values[indptr[i]:indptr[i+1]].
Parameters
----------
values: array_like
Expand All @@ -458,8 +482,8 @@ def csr(values, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None,
Returns
-------
SparseNDArray
An `SparseNDArray` with the `csr` storage representation.
CSRNDArray
A `CSRNDArray` with the `csr` storage representation.
"""
storage_type = 'csr'
# context
Expand All @@ -480,7 +504,7 @@ def csr(values, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None,
assert(indptr.ndim == 1)
assert(indices.ndim == 1)
assert(len(shape) == 2)
result = SparseNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype,
result = CSRNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype,
[indptr_type, indices_type], aux_shapes))
# assign indptr, indices and values
values_ref = result._data(True)
Expand All @@ -494,19 +518,6 @@ def csr(values, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None,
def row_sparse(values, indices, shape, ctx=None, dtype=None, indices_type=None):
"""Creates a row sparse array with a set of tensor slices at given indices.
A SparseNDArray with `row_sparse` storage is typically used to represent a subset of a larger
NDArray with `default` of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0. The values
in indices are the indices in the first dimension of the slices that have been extracted from
the larger NDArray. The indices are expected to be sorted in ascending order.
The corresponding NDArray ``dense`` with `default` represented by a ``rsp``
SparseNDArray with `row_sparse` storage has
``dense[rsp.indices[i], :, :, :, ...] = rsp.values[i, :, :, :, ...]``
`row_sparse` SparseNDArray is used principally in the definition of gradients for operations
that have sparse gradients (e.g. SparseEmbedding).
Parameters
----------
values: array_like
Expand All @@ -525,8 +536,8 @@ def row_sparse(values, indices, shape, ctx=None, dtype=None, indices_type=None):
Returns
-------
SparseNDArray
An `SparseNDArray` with the `row_sparse` storage representation.
RowSparseNDArray
An `RowSparseNDArray` with the `row_sparse` storage representation.
"""
storage_type = 'row_sparse'
# context
Expand All @@ -541,7 +552,7 @@ def row_sparse(values, indices, shape, ctx=None, dtype=None, indices_type=None):
# verify shapes
assert(values.ndim == len(shape))
assert(indices.ndim == 1)
result = SparseNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype,
result = RowSparseNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype,
[indices_type], [indices.shape]))
# assign indices and values
values_ref = result._data(True)
Expand All @@ -555,7 +566,7 @@ def to_dense(source):
Returns
-------
SparseNDArray
NDArray
The dense array with default storage
"""
return ndarray.cast_storage(source, storage_type='default')
Expand Down Expand Up @@ -597,16 +608,19 @@ def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None):
else:
raise Exception("unknown storage type")
assert(len(aux_types) == len(_STORAGE_AUX_TYPES[storage_type]))
out = SparseNDArray(_new_alloc_handle(storage_type, shape, ctx, True, dtype, aux_types))
out = _ndarray_cls(_new_alloc_handle(storage_type, shape, ctx, True, dtype, aux_types))
return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out)

def _ndarray_cls(handle, writable=True):
stype = _storage_type(handle)
# TODO(haibin) in the long run, we want to have CSRNDArray and RowSparseNDArray which
# inherit from SparseNDArray
if stype == 'default':
return NDArray(handle, writable)
return SparseNDArray(handle, writable)
return NDArray(handle, writable=writable)
elif stype == 'csr':
return CSRNDArray(handle, writable=writable)
elif stype == 'row_sparse':
return RowSparseNDArray(handle, writable=writable)
else:
raise Exception("unknown storage type")

# pylint: enable=too-many-locals, invalid-name
def _init_ndarray_module(ndarray_class, root_namespace):
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ def check_sparse_nd_prop_rsp():
shape = rand_shape_2d()
nd, (v, idx) = rand_sparse_ndarray(shape, storage_type)
assert(nd._num_aux == 1)
assert(nd._indices.dtype == np.int32)
assert(nd.indices.dtype == np.int32)
assert(nd.storage_type == 'row_sparse')
assert_almost_equal(nd._indices.asnumpy(), idx)
assert_almost_equal(nd.indices.asnumpy(), idx)

def test_sparse_nd_basic():
def check_rsp_creation(values, indices, shape):
Expand All @@ -98,13 +98,13 @@ def check_rsp_creation(values, indices, shape):
dns[3] = mx.nd.array(values[1])
assert_almost_equal(rsp.asnumpy(), dns.asnumpy())
indices = mx.nd.array(indices).asnumpy()
assert_almost_equal(rsp._indices.asnumpy(), indices)
assert_almost_equal(rsp.indices.asnumpy(), indices)

def check_csr_creation(shape):
csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr')
assert_almost_equal(csr._indptr.asnumpy(), indptr)
assert_almost_equal(csr._indices.asnumpy(), indices)
assert_almost_equal(csr._values.asnumpy(), values)
assert_almost_equal(csr.indptr.asnumpy(), indptr)
assert_almost_equal(csr.indices.asnumpy(), indices)
assert_almost_equal(csr.values.asnumpy(), values)

shape = (4,2)
values = np.random.rand(2,2)
Expand Down

0 comments on commit 258c74b

Please sign in to comment.