Operators for sum(csr, axis=0) and sum(csr, axis=1) (#8174)
* Add Infer storage for sparse slice operator

* Remove unused files

* Indentation fix and add gpu test for fallback

* Change sum builtin to py_sum

* Add sum_axis(csr,axis=0)=dense and sum(csr,axis=1)=dense operator

* Documentation changes for sparse

* Add fallback unittest for keepdims and exclude

* PR review based changes

* Fix CHECK_NE

* Change in_stype to int

* Using const int instead of int

* Initialize mid with the start
anirudh2290 authored and piiswrong committed Oct 13, 2017
1 parent 3b36217 commit 46ec178
Showing 5 changed files with 307 additions and 3 deletions.
4 changes: 3 additions & 1 deletion python/mxnet/ndarray/sparse.py
@@ -22,8 +22,10 @@
from __future__ import division
try:
    from __builtin__ import slice as py_slice
    from __builtin__ import sum as py_sum
except ImportError:
    from builtins import slice as py_slice
    from builtins import sum as py_sum

import ctypes
import warnings
@@ -94,7 +96,7 @@ def _new_alloc_handle(stype, shape, ctx, delay_alloc, dtype, aux_types, aux_shap
aux_type_ids = [int(_DTYPE_NP_TO_MX[np.dtype(aux_t).type]) for aux_t in aux_types]
aux_shapes = [(0,) for aux_t in aux_types] if aux_shapes is None else aux_shapes
aux_shape_lens = [len(aux_shape) for aux_shape in aux_shapes]
aux_shapes = sum(aux_shapes, ())
aux_shapes = py_sum(aux_shapes, ())
num_aux = mx_uint(len(aux_types))
check_call(_LIB.MXNDArrayCreateSparseEx(
ctypes.c_int(int(_STORAGE_TYPE_STR_TO_ID[stype])),
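
The builtin sum is imported as py_sum so that _new_alloc_handle can still flatten the aux shape tuples even when a `sum` operator is exposed in this namespace (the reason is inferred from the "Change sum builtin to py_sum" commit message, not stated in the diff). A minimal sketch of what py_sum(aux_shapes, ()) computes, with made-up shapes:

    # Flatten a list of per-aux-array shape tuples into one flat tuple;
    # the flat tuple is what MXNDArrayCreateSparseEx receives together
    # with aux_shape_lens, which records how long each shape is.
    aux_shapes = [(4,), (3,)]                      # two 1-D aux arrays (illustrative)
    aux_shape_lens = [len(s) for s in aux_shapes]  # [1, 1]
    flat = sum(aux_shapes, ())                     # py_sum in the patched module
    assert flat == (4, 3)
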
259 changes: 259 additions & 0 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -339,6 +339,49 @@ inline void BroadcastReduceShapeCompact(const TShape& big, const TShape& small,
}
}

inline bool SumOpForwardInferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
const int in_stype = in_attrs->at(0);
int& out_stype = out_attrs->at(0);
bool dispatched = false;
// sum(csr) is only supported on CPU for now. TODO: remove this check once GPU support is added
const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
const auto dispatch_ex =
invalid_ctx ? DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx;
if (!dispatched && in_stype == kDefaultStorage) {
// When the input is dense, the output storage is set to dense and the op is
// dispatched to the dense operator
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFCompute);
}

if (!dispatched && in_stype == kCSRStorage &&
(param.axis[0] == 0 || param.axis[0] == 1) && !param.keepdims &&
!param.exclude) {
// If the input is csr, axis is 0 or 1, and neither keepdims nor exclude is
// set, dispatch to the sparse operator; the output storage is set to dense
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
dispatch_ex);
}

if (!dispatched) {
// If the input is csr but keepdims or exclude is set, or we are summing
// along an axis other than 0 or 1, fall back to the dense operator
dispatch_fallback(out_attrs, dispatch_mode);
}
if (*dispatch_mode == DispatchMode::kFComputeFallback) {
LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
}

return true;
}
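
For reference, the decision logic above can be summarized in plain Python. This is only an illustrative sketch of the branch structure; the string constants stand in for the real storage-type and dispatch-mode enums and are not MXNet API:

    def infer_sum_dispatch(in_stype, axis, keepdims, exclude, on_cpu):
        # dense input -> dense output via the regular FCompute path
        if in_stype == 'default':
            return 'default', 'fcompute'
        # csr input, axis 0 or 1, no keepdims/exclude -> dense output;
        # FComputeEx on CPU, fallback on any other device
        if in_stype == 'csr' and axis in (0, 1) and not keepdims and not exclude:
            return 'default', 'fcompute_ex' if on_cpu else 'fallback'
        # everything else falls back to the dense operator
        return 'default', 'fallback'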

template<typename xpu, typename reducer>
void SearchAxisCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -411,6 +454,222 @@ void ReduceAxesCompute(const nnvm::NodeAttrs& attrs,
ReduceAxesComputeImpl<xpu, reducer, normalize>(attrs, ctx, inputs, req, outputs, small);
}

template <int req, int axis>
struct SumCsrKernel;

/* \brief Sum along axis 0 of a CSR matrix.
 * The columns are divided equally among the available threads. Each thread
 * gets a subset of columns and iterates through all rows for that subset.
 * In each iteration it does a binary search for the first column of the
 * subset between in_idx[in_indptr[row]] and in_idx[in_indptr[row + 1]].
 * Once an index equal to (or close to) the first column is found, it does a
 * linear search for the remaining indices and adds their data to the
 * intermediate sum. After iterating through all rows, the sum along the axis
 * is complete for the subset of columns.
 */
template <int req>
struct SumCsrKernel<req, 0> {
template <typename RType, typename IType, typename DType>
MSHADOW_XINLINE static void Map(int j, DType* out_data,
const RType* in_indptr, const IType* in_idx,
const DType* in_data,
DType* sum,
DType* residual,
RType num_rows,
IType num_cols,
const nnvm::dim_t seg_len) {
const IType seg_start = j * seg_len;
if (seg_start >= num_cols) return;
const IType seg_end = std::min(seg_start + seg_len, num_cols);

for (RType row = 0; row < num_rows; ++row) {
// row specific seg starts
IType row_seg_start = seg_start;
IType row_seg_end = seg_end;

// Cache starting and ending indptr values for the row
IType row_indptr_start = in_indptr[row];
IType row_indptr_end = in_indptr[row + 1] - 1;
if (row_indptr_start == (row_indptr_end + 1)) continue;

// If row_seg_start is less than the first index for the row, move the
// row_seg_start forward
while (row_seg_start < in_idx[row_indptr_start] &&
row_seg_start < row_seg_end) {
row_seg_start++;
}

// If row_seg_start is greater than last index for the row, move on to
// the next row
if (row_seg_start > in_idx[row_indptr_end]) continue;

// Do a binary search for row_seg_start between in_idx[row_indptr_start]
// and in_idx[row_indptr_end]
IType start = row_indptr_start;
IType end = row_indptr_end;

// Initialize mid with the first index of the row
IType mid = start;
while (start <= end) {
mid = start + (end - start) / 2;
if (in_idx[mid] == row_seg_start) {
break;
} else if (in_idx[mid] < row_seg_start) {
start = mid + 1;
} else {
end = mid - 1;
}
}

// At this point in_idx[mid] is equal to or close to row_seg_start
// Safety check to make sure mid stays within the row's nonzero range
if (mid < row_indptr_start || mid > row_indptr_end)
mid = row_indptr_start;


// Linear search over the nonzeros in the column subset between
// row_seg_start and row_seg_end
for (IType col = row_seg_start;
col < row_seg_end && mid <= row_indptr_end;) {
if (col == in_idx[mid]) {
mshadow::red::sum::Reduce(sum[col], in_data[mid],
residual[col]);
mid++;
col++;
} else if (in_idx[mid] < col) {
mid++;
} else {
col++;
}
}
}

for (IType col = seg_start; col < seg_end; col++) {
KERNEL_ASSIGN(out_data[col], req, sum[col]);
}
}
};
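
The same per-thread, per-column-segment algorithm in plain Python, as a readability aid only (the helper name is made up, and the compensated summation through the residual buffer is omitted):

    import bisect

    def sum_csr_axis0_segment(indptr, idx, data, out, seg_start, seg_end):
        """Accumulate column sums for columns [seg_start, seg_end) of a CSR matrix."""
        for row in range(len(indptr) - 1):
            lo, hi = indptr[row], indptr[row + 1]
            if lo == hi:
                continue  # empty row
            # binary search for the first nonzero column >= seg_start in this row
            mid = bisect.bisect_left(idx, seg_start, lo, hi)
            # linear scan over the row's nonzeros that fall inside the segment
            while mid < hi and idx[mid] < seg_end:
                out[idx[mid]] += data[mid]
                mid += 1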

template <int req>
struct SumCsrKernel<req, 1> {
template <typename RType, typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
const RType* in_indptr,
const DType* in_data) {
DType sum, residual;
mshadow::red::sum::SetInitValue(sum, residual);
for (RType k = in_indptr[i]; k < in_indptr[i + 1]; k++) {
mshadow::red::sum::Reduce(sum, in_data[k], residual);
}
KERNEL_ASSIGN(out_data[i], req, sum);
}
};
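
The axis=1 kernel is much simpler: output element i is just the sum of the stored values of row i, delimited by consecutive indptr entries. A one-line Python equivalent (ignoring the compensated summation via the residual):

    def sum_csr_row(indptr, data, i):
        # Sum of row i of a CSR matrix: all stored values in [indptr[i], indptr[i+1]).
        return sum(data[indptr[i]:indptr[i + 1]])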

template <typename xpu>
void SumCsrImpl(const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu>* s, const OpContext& ctx,
const NDArray& input, const OpReqType req, NDArray* output) {
if (req == kNullOp) return;
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
CHECK_EQ(param.axis.ndim(), 1U) << "sum(csr) only supports axis 0 or 1";
CHECK(param.axis[0] == 0 || param.axis[0] == 1)
<< "sum(csr) only supports axis 0 or 1";
CHECK(!param.keepdims) << "keepdims not supported for sparse";
CHECK(!param.exclude) << "exclude not supported for sparse";
int64_t out_data_size = 0;
if (param.axis[0] == 0) {
out_data_size = input.shape()[1];
} else {
out_data_size = input.shape()[0];
}
// only dense output storage type is supported
CHECK_EQ(output->storage_type(), kDefaultStorage);

using namespace mshadow;
using namespace mxnet_op;
using namespace csr;
using nnvm::dim_t;

if (req == kWriteTo || req == kWriteInplace) {
MSHADOW_TYPE_SWITCH(output->data().type_flag_, DType, {
Kernel<set_zero, xpu>::Launch(s, out_data_size,
output->data().dptr<DType>());
})
}

if (!input.storage_initialized()) {
return;
}

if (0 == param.axis[0]) {
MSHADOW_IDX_TYPE_SWITCH(input.aux_type(kIndPtr), RType, {
MSHADOW_IDX_TYPE_SWITCH(input.aux_type(kIdx), IType, {
MSHADOW_TYPE_SWITCH(input.dtype(), DType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
const RType* in_indptr = input.aux_data(kIndPtr).dptr<RType>();
const IType* in_idx = input.aux_data(kIdx).dptr<IType>();
const DType* in_data = input.data().dptr<DType>();
const RType num_rows = input.shape()[0];
const IType num_cols = input.shape()[1];
dim_t num_threads = mxnet_op::get_num_threads<xpu>(16);
dim_t seg_len = (out_data_size + num_threads - 1) / num_threads;
mshadow::Tensor<xpu, 1, DType> workspace =
ctx.requested[0].get_space_typed<xpu, 1, DType>(
Shape1(2 * out_data_size), s);
mshadow::Tensor<xpu, 1, DType> sum(
reinterpret_cast<DType*>(workspace.dptr_),
Shape1(out_data_size));
mshadow::Tensor<xpu, 1, DType> residual(
reinterpret_cast<DType*>(workspace.dptr_ +
out_data_size),
Shape1(out_data_size));

Kernel<set_zero, xpu>::Launch(s, out_data_size, sum.dptr_);
Kernel<set_zero, xpu>::Launch(s, out_data_size, residual.dptr_);
Kernel<SumCsrKernel<req_type, 0>, xpu>::Launch(
s, num_threads, output->data().dptr<DType>(), in_indptr, in_idx,
in_data, sum.dptr_, residual.dptr_, num_rows, num_cols,
seg_len);
});
});
});
});
} else if (1 == param.axis[0]) {
MSHADOW_IDX_TYPE_SWITCH(input.aux_type(kIndPtr), RType, {
MSHADOW_TYPE_SWITCH(input.dtype(), DType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
const RType* in_indptr = input.aux_data(kIndPtr).dptr<RType>();
const DType* in_data = input.data().dptr<DType>();
Kernel<SumCsrKernel<req_type, 1>, xpu>::Launch(
s, out_data_size, output->data().dptr<DType>(), in_indptr,
in_data);
});
});
});
}
}
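
Note that the axis=0 path requests one temp workspace of 2 * out_data_size elements and splits it into a `sum` half and a `residual` half; the residual feeds the three-argument mshadow sum reducer, which appears to implement Kahan-style compensated summation. A minimal sketch of that idea (not the mshadow implementation itself):

    def compensated_add(total, value, residual):
        # Kahan summation step: 'residual' carries the low-order bits that
        # would otherwise be lost when adding 'value' into 'total'.
        y = value - residual
        t = total + y
        residual = (t - total) - y
        return t, residual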

template <typename xpu, typename reducer, bool normalize = false>
void SumOpForwardEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
const NDArrayStorageType istype = inputs[0].storage_type();
if (istype == kCSRStorage) {
CHECK_EQ(inputs[0].shape().ndim(), 2U)
<< "sum(csr) op only supports 2D ndarray as input";
NDArray output = outputs[0];
SumCsrImpl(attrs, s, ctx, inputs[0], req[0], &output);
} else {
LOG(FATAL) << "Not implemented: "
<< operator_string(attrs, ctx, inputs, req, outputs);
}
}

// works when shape inference of output is given
template<typename xpu, typename OP, bool normalize = false>
void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs,
17 changes: 17 additions & 0 deletions src/operator/tensor/broadcast_reduce_op_value.cc
@@ -45,12 +45,15 @@ Defined in )code";
}

MXNET_OPERATOR_REGISTER_REDUCE(sum)
MXNET_ADD_SPARSE_OP_ALIAS(sum)
.add_alias("sum_axis")
.describe(R"code(Computes the sum of array elements over given axes.
.. Note::
`sum` and `sum_axis` are equivalent.
For an ndarray of csr storage type, summation along axis 0 and axis 1 is supported.
Setting keepdims or exclude to True will cause a fallback to the dense operator.
Example::
@@ -66,8 +69,22 @@ Example::
sum(data, axis=[1,2])
[ 12. 19. 27.]
data = [[1,2,0],
[3,0,1],
[4,1,0]]
csr = cast_storage(data, 'csr')
sum(csr, axis=0)
[ 8. 2. 2.]
sum(csr, axis=1)
[ 3. 4. 5.]
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", ReduceAxesCompute<cpu, mshadow::red::sum>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SumOpForwardEx<cpu, mshadow::red::sum>)
.set_attr<FInferStorageType>("FInferStorageType", SumOpForwardInferStorageType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
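
With the registration above, the sparse path is reachable directly from the Python frontend. A small usage sketch (the axis=0/axis=1 results come from the docstring example; the keepdims call is shown only to illustrate the documented fallback to the dense operator):

    import mxnet as mx

    data = mx.nd.array([[1, 2, 0], [3, 0, 1], [4, 1, 0]])
    csr = data.tostype('csr')

    mx.nd.sum(csr, axis=0)                 # dense NDArray [ 8.  2.  2.]
    mx.nd.sum(csr, axis=1)                 # dense NDArray [ 3.  4.  5.]
    mx.nd.sum(csr, axis=0, keepdims=True)  # falls back to the dense operator
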
4 changes: 2 additions & 2 deletions src/operator/tensor/cast_storage.cc
@@ -55,13 +55,13 @@ Example::
[ 0., 0., 0.]]
# cast to row_sparse storage type
rsp = cast_storage(default, 'row_sparse')
rsp = cast_storage(dense, 'row_sparse')
rsp.indices = [0, 1]
rsp.values = [[ 0., 1., 0.],
[ 2., 0., 3.]]
# cast to csr storage type
csr = cast_storage(default, 'csr')
csr = cast_storage(dense, 'csr')
csr.indices = [1, 0, 2]
csr.values = [ 1., 2., 3.]
csr.indptr = [0, 1, 3, 3, 3]
26 changes: 26 additions & 0 deletions tests/python/unittest/test_sparse_operator.py
@@ -1304,6 +1304,32 @@ def check_sparse_nd_zeros_like(stype, shape):
check_sparse_nd_zeros_like('row_sparse', shape)
check_sparse_nd_zeros_like('csr', shape)

def test_sparse_sum_axis():
    def test_variations():
        dim0 = 30
        dim1 = 100
        axes = [0, 1]
        densities = [0, 0.5, 1]
        for density in densities:
            shape = rand_shape_2d(dim0, dim1)
            csr_array = rand_ndarray(shape=shape, stype='csr', density=density)
            dns = csr_array.tostype('default')
            for axis in axes:
                ret = mx.nd.sum(csr_array, axis=axis)
                assert ret.stype == 'default'
                ret_expected = mx.nd.sum(dns, axis=axis)
                assert_almost_equal(ret.asnumpy(), ret_expected.asnumpy())

    def test_fallback(axis=0, keepdims=True, exclude=True):
        dim0 = 30
        dim1 = 100
        shape = rand_shape_2d(dim0, dim1)
        csr_array = rand_ndarray(shape=shape, stype='csr', density=0.01)
        ret = mx.nd.sum(csr_array, axis=axis, keepdims=keepdims,
                        exclude=exclude)

    test_variations()
    test_fallback(axis=0, keepdims=True, exclude=True)

def test_sparse_square_sum():
if default_context().device_type == 'cpu':
