[test] Sparse mega pr #168

Merged: 66 commits, merged on Aug 16, 2017
Commits
- c5f6648 [WIP] Sparse Tensor (#5800) (eric-haibin-lin, Jun 26, 2017)
- db65770 move storage type vector from nnvm to mxnet (#7054) (eric-haibin-lin, Jul 15, 2017)
- e2607da fix failed tests. add back 64bit support for dot (eric-haibin-lin, Jul 17, 2017)
- 978748e Improve copy sparse tensors (#7003) (reminisce, Jul 15, 2017)
- ce0fec8 bug fix for IdentityComputeRsp (eric-haibin-lin, Jul 18, 2017)
- a2b3d3e fix lint (eric-haibin-lin, Jul 18, 2017)
- 27c9ac0 add data partition for libsvm iter (#7027) (eric-haibin-lin, Jul 21, 2017)
- 3a394ea fix ndarray namespace (eric-haibin-lin, Jul 22, 2017)
- cf61a9e remove sparse embedding (#7165) (eric-haibin-lin, Jul 23, 2017)
- fe62976 remove untested gpu operators (#7172) (eric-haibin-lin, Jul 24, 2017)
- 4de0fdd Fix ndarray aux data issue (#7098) (reminisce, Jul 25, 2017)
- a472b61 Support K-dimensional row-sparse tensor (#7179) (eric-haibin-lin, Jul 25, 2017)
- 6a01b6e Improve sparse ndarray error message (#7181) (eric-haibin-lin, Jul 25, 2017)
- 05ddf38 construct row_sparse ndarray for dist-async (eric-haibin-lin, Jun 26, 2017)
- f57fc3c Merge remote-tracking branch 'upstream/master' into dmlc-sparse-squash (eric-haibin-lin, Jul 26, 2017)
- 0ed14d1 fix DotCsrRspRspImpl error message (#7191) (stefanhenneking, Jul 26, 2017)
- f0af872 GPU implementation of cast_storage (dense to csr) (#7081) (stefanhenneking, Jul 27, 2017)
- 6f0719f Sparse square sum (#7206) (reminisce, Jul 27, 2017)
- ec2c4bf Modify and Add documentation for mx.nd.zeros (#7197) (anirudh2290, Jul 27, 2017)
- 88eaac6 Merge remote-tracking branch 'upstream/master' into dmlc-sparse-squash (eric-haibin-lin, Jul 27, 2017)
- 3b94a3c Expose kWriteInplace for imperative execution (fcompute_ex and fstate… (eric-haibin-lin, Jul 28, 2017)
- 55e4763 Operator add_n for row sparse ndarrays (#7244) (reminisce, Aug 1, 2017)
- 7e1647c GPU implementation of cast_storage (dense to rsp) (#7223) (stefanhenneking, Aug 1, 2017)
- 5905ddc merge with dmlc/master (eric-haibin-lin, Aug 2, 2017)
- d8a9aba resolve merge conflict in ndarray.load (eric-haibin-lin, Aug 2, 2017)
- f686174 Improve StatefulOp/FCompute storage fallback (#134) (eric-haibin-lin, Aug 2, 2017)
- d0579c4 update sparse ndarray api (#139) (eric-haibin-lin, Aug 3, 2017)
- 56b5a63 Merge remote-tracking branch 'upstream/master' into sparse (eric-haibin-lin, Aug 3, 2017)
- 325f4db Handle ograd_stype='row_sparse' for square_sum backward (#143) (reminisce, Aug 3, 2017)
- 5866b2b Sparse retain improvement (#138) (reminisce, Aug 5, 2017)
- 9298bfa ignoring variables in SimpleBind that is used on python's sparse bran… (sergeykolychev, Aug 5, 2017)
- 1f07771 add bias term to fm test (#145) (eric-haibin-lin, Aug 5, 2017)
- d511938 merge with upstream/master. resolve conflict in c_api_ndarray.cc (eric-haibin-lin, Aug 5, 2017)
- 6956431 update ndarray.nd, remove `invoke` from excluded members (#137) (eric-haibin-lin, Aug 6, 2017)
- 6c9a350 Merge remote-tracking branch 'upstream/master' into sparse (eric-haibin-lin, Aug 6, 2017)
- 66b7b8a support storage fallback with mutable inputs (#147) (eric-haibin-lin, Aug 6, 2017)
- cf8ddcf Merge branch 'sparse' of https://github.com/eric-haibin-lin/mxnet int… (eric-haibin-lin, Aug 6, 2017)
- 0396c9a Merge remote-tracking branch 'upstream/master' into sparse (eric-haibin-lin, Aug 7, 2017)
- 2dc7dc9 Code changes based on reviews (#144) (eric-haibin-lin, Aug 8, 2017)
- f318c9d small edits according to reviews (#151) (eric-haibin-lin, Aug 8, 2017)
- 85cbc60 Merge remote-tracking branch 'upstream/master' into sparse (eric-haibin-lin, Aug 8, 2017)
- fc1aa6e fix lint (#152) (eric-haibin-lin, Aug 8, 2017)
- 9ba96b9 resolve conflict in ndarray.py and capi (eric-haibin-lin, Aug 8, 2017)
- 6cbdf98 resolve conflicts in license header (eric-haibin-lin, Aug 8, 2017)
- 253ae57 add license to all new files in sparse brnach (#154) (eric-haibin-lin, Aug 9, 2017)
- b2ad302 Allocate temp data on the fly for some casting operations (#149) (cjolivier01, Aug 9, 2017)
- 129148c Merge remote-tracking branch 'upstream/master' into sparse (eric-haibin-lin, Aug 9, 2017)
- d6f987d fix utf8 encoding in sparse ndarray (eric-haibin-lin, Aug 9, 2017)
- 955e97f Merge branch 'sparse' of https://github.com/eric-haibin-lin/mxnet int… (eric-haibin-lin, Aug 9, 2017)
- bc33101 Extending the GPU dot operator (#7226) (stefanhenneking, Aug 9, 2017)
- 8040953 Merge remote-tracking branch 'upstream/master' into sparse (eric-haibin-lin, Aug 9, 2017)
- 2d93d72 Add get_synthetic_dataset function to util (#146) (anirudh2290, Aug 10, 2017)
- 80a590d temporary fix for batch norm storage fallback (#156) (eric-haibin-lin, Aug 10, 2017)
- 92f54d2 support random_uniform/normal/gamma with row_sparse output (#155) (eric-haibin-lin, Aug 10, 2017)
- 17bfa4e Merge remote-tracking branch 'upstream/master' into sparse (eric-haibin-lin, Aug 10, 2017)
- ef3b442 Merge remote-tracking branch 'upstream/master' into sparse (eric-haibin-lin, Aug 10, 2017)
- a44afed Square sum backward support one more case (#161) (reminisce, Aug 10, 2017)
- ceca9b6 Add documentation for sparse ops (#148) (eric-haibin-lin, Aug 11, 2017)
- 1c60a05 A few fixes (#163) (eric-haibin-lin, Aug 11, 2017)
- 04e9129 Merge branch 'sparse' of https://github.com/eric-haibin-lin/mxnet int… (eric-haibin-lin, Aug 12, 2017)
- 8ebc012 merge with upstream/master (eric-haibin-lin, Aug 12, 2017)
- 889a09e Minor fixes sparse ops (#160) (stefanhenneking, Aug 14, 2017)
- 6b0cac1 sparse Adam optimizer (#164) (eric-haibin-lin, Aug 14, 2017)
- eeff444 kvstore.row_sparse_pull for GPU and end-to-end benchmark: CPU vs. mul… (reminisce, Aug 15, 2017)
- 54f698b fix bug in adam update (#167) (eric-haibin-lin, Aug 15, 2017)
- 6fa078e change sparse example from regression to classification (#165) (eric-haibin-lin, Aug 15, 2017)
Files changed (changes from 4 commits)
5 changes: 0 additions & 5 deletions docs/api/python/ndarray.md
@@ -547,11 +547,6 @@ The `contrib.ndarray` module contains many useful experimental APIs for new feat
     :members:
     :special-members:
 
-.. autoclass:: mxnet.ndarray.BaseSparseNDArray
-    :members:
-    :special-members:
-    :exclude-members: __weakref__
-
 .. autoclass:: mxnet.ndarray.CSRNDArray
     :members:
     :special-members:
4 changes: 4 additions & 0 deletions python/mxnet/ndarray/ndarray.py
@@ -1109,6 +1109,10 @@ def backward(self, out_grad=None, retain_graph=False, train_mode=True):
     def tostype(self, stype):
         """Return a copy of the array with chosen storage type.
 
+        See Also
+        ----------
+        :meth:`mxnet.ndarray.cast_storage`.
+
         Returns
         -------
         NDArray, CSRNDArray or RowSparseNDArray
4 changes: 2 additions & 2 deletions python/mxnet/ndarray/sparse_ndarray.py
@@ -846,8 +846,8 @@ def _zeros_sparse_ndarray(stype, shape, ctx=None, dtype=None, aux_types=None, **
     dtype : str or numpy.dtype, optional
         An optional value type (default is `float32`)
     aux_types: list of numpy.dtype, optional
-        An optional type for the aux data for BaseSparseNDArray (default values depends
-        on the storage type)
+        An optional list of types of the aux data for RowSparseNDArray or CSRNDArray
+        (default values depends on the storage type)
 
     Returns
     -------
14 changes: 7 additions & 7 deletions python/mxnet/ndarray/utils.py
@@ -37,10 +37,10 @@ def zeros(shape, ctx=None, dtype=None, stype=None, aux_types=None, **kwargs):
     dtype : str or numpy.dtype, optional
         An optional value type (default is `float32`)
     stype: string, optional
-        The storage type of the empty array, such as 'row_sparse', 'csr', etc
+        The storage type of the empty array, such as 'row_sparse', 'csr', etc.
     aux_types: list of numpy.dtype, optional
-        An optional type for the aux data for the BaseSparseNDArray (default values
-        depends on the storage type)
+        An optional list of types of the aux data for RowSparseNDArray or CSRNDArray
+        (default values depend on the storage type)
 
     Returns
     -------
@@ -73,8 +73,8 @@ def empty(shape, ctx=None, dtype=None, stype=None, aux_types=None):
     stype : str, optional
         An optional storage type (default is `default`).
     aux_types: list of numpy.dtype, optional
-        An optional type for the aux data for the BaseSparseNDArray (default values depends
-        on the storage type)
+        An optional list of types of the aux data for RowSparseNDArray or CSRNDArray
+        (default values depend on the storage type)
 
     Returns
     -------
@@ -111,8 +111,8 @@ def array(source_array, ctx=None, dtype=None, aux_types=None):
         The data type of the output array. The default dtype is ``source_array.dtype``
         if `source_array` is an `NDArray`, `float32` otherwise.
     aux_types: list of numpy.dtype, optional
-        An optional type for the aux data for the BaseSparseNDArray (default values
-        depends on the storage type)
+        An optional list of types of the aux data for RowSparseNDArray or CSRNDArray
+        (default values depend on the storage type)
 
     Returns
     -------
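For readers new to the aux data these docstrings mention: a row_sparse NDArray carries one auxiliary index array (the row indices), while a csr NDArray carries two (indptr and the column indices), and aux_types sets their dtypes. A minimal sketch against the zeros signature shown in this diff; exact keyword behavior at this snapshot of the API is an assumption::

    import numpy as np
    import mxnet as mx

    # row_sparse: one aux array (row indices); csr: two (indptr, column indices)
    rsp = mx.nd.zeros((4, 3), stype='row_sparse', aux_types=[np.int64])
    csr = mx.nd.zeros((4, 3), stype='csr', aux_types=[np.int64, np.int64])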
4 changes: 2 additions & 2 deletions src/operator/batch_norm.cu
@@ -283,7 +283,7 @@ __global__ void BatchNormalizationUpdateOutputKernel(
   }
 
   // Write normalized and update the output
-  const AccReal gamma = weight.numElements() > 0
+  const AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
                         ? ScalarConvert<DType, AccReal>::to(weight[plane])
                         : ScalarConvert<int, AccReal>::to(1);
   const AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
@@ -332,7 +332,7 @@ static __global__ void BatchNormalizationBackwardKernel(
     invstd = VARIANCE_TO_INVSTD(tensors.runningVar[plane], eps);
   }
 
-  const AccReal weightVal = tensors.weight.numElements() > 0 ?
+  const AccReal weightVal = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ?
       ScalarConvert<DType, AccReal>::to(tensors.weight[plane]) : AccReal(1);
   const AccReal norm = AccReal(1) / N;
 
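Both kernel changes enforce fix_gamma semantics on the GPU path: when FIX_GAMMA_FLAG is set, the scale is read as 1 even though a weight tensor is present. A numpy sketch of the intended forward computation, as an illustration only, not the kernel itself::

    import numpy as np

    def batchnorm_out(x, mean, var, gamma, beta, eps=1e-3, fix_gamma=False):
        # fix_gamma forces the scale to 1, mirroring the FIX_GAMMA_FLAG check
        g = 1.0 if fix_gamma else gamma
        return (x - mean) / np.sqrt(var + eps) * g + beta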
60 changes: 40 additions & 20 deletions src/operator/random/sample_op.cu
@@ -28,21 +28,20 @@ namespace op {
 
 // GPU versions of uniform and normal distribution.
 template<>
-void SampleUniform_<gpu>(const nnvm::NodeAttrs& attrs,
-                         const OpContext& ctx,
-                         const std::vector<TBlob>& inputs,
-                         const std::vector<OpReqType>& req,
-                         const std::vector<TBlob>& outputs) {
+void SampleUniformDnsImpl<gpu>(const nnvm::NodeAttrs& attrs,
+                               const OpContext& ctx,
+                               const OpReqType& req,
+                               TBlob* output) {
   using namespace mxnet::op;
   using namespace mshadow::expr;
   typedef gpu xpu;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const SampleUniformParam& param = nnvm::get<SampleUniformParam>(attrs.parsed);
   mshadow::Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
-  if (outputs[0].type_flag_ != mshadow::kFloat32) {
-    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+  if (output->type_flag_ != mshadow::kFloat32) {
+    MSHADOW_REAL_TYPE_SWITCH(output->type_flag_, DType, {
       // Not float32: use workspace and copy to output
-      mshadow::Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+      mshadow::Tensor<xpu, 2, DType> out = output->FlatTo2D<xpu, DType>(s);
       mshadow::Tensor<xpu, 1, float> workspace =
           ctx.requested[1].get_space_typed<xpu, 1, float>
           (mshadow::Shape1(out.shape_.Size()), s);
@@ -51,27 +50,36 @@ void SampleUniform_<gpu>(const nnvm::NodeAttrs& attrs,
     });
   } else {
     // float32: write directly into output
-    mshadow::Tensor<xpu, 2, float> out = outputs[0].FlatTo2D<xpu, float>(s);
+    mshadow::Tensor<xpu, 2, float> out = output->FlatTo2D<xpu, float>(s);
     prnd->SampleUniform(&out, param.low, param.high);
   }
 }
 
 template<>
-void SampleNormal_<gpu>(const nnvm::NodeAttrs& attrs,
-                        const OpContext& ctx,
-                        const std::vector<TBlob>& inputs,
-                        const std::vector<OpReqType>& req,
-                        const std::vector<TBlob>& outputs) {
+void SampleUniform_<gpu>(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<TBlob>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<TBlob>& outputs) {
+  TBlob out = outputs[0];
+  SampleUniformDnsImpl<gpu>(attrs, ctx, req[0], &out);
+}
+
+template<>
+void SampleNormalDnsImpl<gpu>(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const OpReqType& req,
+                              TBlob* output) {
   using namespace mxnet::op;
   using namespace mshadow::expr;
   typedef gpu xpu;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const SampleNormalParam& param = nnvm::get<SampleNormalParam>(attrs.parsed);
   mshadow::Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
-  if (outputs[0].type_flag_ != mshadow::kFloat32) {
-    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+  if (output->type_flag_ != mshadow::kFloat32) {
+    MSHADOW_REAL_TYPE_SWITCH(output->type_flag_, DType, {
       // Not float32: use workspace and copy to output
-      mshadow::Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+      mshadow::Tensor<xpu, 2, DType> out = output->FlatTo2D<xpu, DType>(s);
       mshadow::Tensor<xpu, 1, float> workspace =
           ctx.requested[1].get_space_typed<xpu, 1, float>
           (mshadow::Shape1(out.shape_.Size()), s);
@@ -80,16 +88,28 @@ void SampleNormal_<gpu>(const nnvm::NodeAttrs& attrs,
     });
   } else {
     // float32: write directly into output
-    mshadow::Tensor<xpu, 2, float> out = outputs[0].FlatTo2D<xpu, float>(s);
+    mshadow::Tensor<xpu, 2, float> out = output->FlatTo2D<xpu, float>(s);
     prnd->SampleGaussian(&out, param.loc, param.scale);
   }
 }
 
+template<>
+void SampleNormal_<gpu>(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  TBlob out = outputs[0];
+  SampleNormalDnsImpl<gpu>(attrs, ctx, req[0], &out);
+}
+
 NNVM_REGISTER_OP(random_uniform)
-.set_attr<FCompute>("FCompute<gpu>", SampleUniform_<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", SampleUniform_<gpu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", SampleUniformEx_<gpu>);
 
 NNVM_REGISTER_OP(random_normal)
-.set_attr<FCompute>("FCompute<gpu>", SampleNormal_<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", SampleNormal_<gpu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", SampleNormalEx_<gpu>);
 
 }  // namespace op
 }  // namespace mxnet
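The shape of this refactor: the dense sampling body moves into a *DnsImpl function that takes a single output TBlob, the original FCompute entry point becomes a thin wrapper around it, and the storage-aware FComputeEx entry point registered above can reuse the same dense kernel on the value blob of a sparse output. A Python analogy of that structure; the names and the dict-based row_sparse container are illustrative, not MXNet APIs::

    import numpy as np

    def sample_uniform_dns_impl(output, low, high, rng):
        # dense kernel: fill a pre-allocated buffer in place
        output[...] = rng.uniform(low, high, size=output.shape)

    def sample_uniform(outputs, low, high, rng=np.random):
        # FCompute-style wrapper: dense output only
        sample_uniform_dns_impl(outputs[0], low, high, rng)

    def sample_uniform_ex(out_rsp, low, high, rng=np.random):
        # FComputeEx-style wrapper: a uniform sample is dense with
        # probability 1, so mark every row present, then reuse the
        # dense kernel on the value blob
        out_rsp['indices'] = np.arange(out_rsp['values'].shape[0])
        sample_uniform_dns_impl(out_rsp['values'], low, high, rng)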
33 changes: 33 additions & 0 deletions src/operator/tensor/cast_storage.cc
@@ -32,6 +32,39 @@ namespace op {
 DMLC_REGISTER_PARAMETER(CastStorageParam);
 NNVM_REGISTER_OP(cast_storage)
 .describe(R"code(Casts tensor storage type to the new type.
+
+When an NDArray with default storage type is cast to csr or row_sparse storage,
+the result is compact, which means:
+
+- for csr, zero values will not be retained
+- for row_sparse, row slices of all zeros will not be retained
+
+The storage type of ``cast_storage`` output depends on stype parameter:
+
+- cast_storage(csr, 'default') = default
+- cast_storage(row_sparse, 'default') = default
+- cast_storage(default, 'csr') = csr
+- cast_storage(default, 'row_sparse') = row_sparse
+
+Example::
+
+    dense = [[ 0.,  1.,  0.],
+             [ 2.,  0.,  3.],
+             [ 0.,  0.,  0.],
+             [ 0.,  0.,  0.]]
+
+    # cast to row_sparse storage type
+    rsp = cast_storage(dense, 'row_sparse')
+    rsp.indices = [0, 1]
+    rsp.values = [[ 0.,  1.,  0.],
+                  [ 2.,  0.,  3.]]
+
+    # cast to csr storage type
+    csr = cast_storage(dense, 'csr')
+    csr.indices = [1, 0, 2]
+    csr.values = [ 1.,  2.,  3.]
+    csr.indptr = [0, 1, 3, 3, 3]
+
+)code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
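The indptr/indices/values triple in this example can be cross-checked with scipy, whose CSR layout uses the same convention; a sketch for verification, not MXNet code::

    import numpy as np
    import scipy.sparse as sp

    dense = np.array([[0., 1., 0.],
                      [2., 0., 3.],
                      [0., 0., 0.],
                      [0., 0., 0.]])

    csr = sp.csr_matrix(dense)
    print(csr.indptr)    # [0 1 3 3 3]
    print(csr.indices)   # [1 0 2]
    print(csr.data)      # [1. 2. 3.]

    # row_sparse keeps only the rows that contain at least one nonzero
    rows = np.flatnonzero(dense.any(axis=1))   # [0 1]
    values = dense[rows]                       # [[0. 1. 0.], [2. 0. 3.]]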
8 changes: 8 additions & 0 deletions src/operator/tensor/dot.cc
@@ -49,6 +49,14 @@ NNVM_REGISTER_OP(dot)
     y = reshape([7,6,5,4,3,2,1,0], shape=(2,2,2))
     dot(x,y)[0,0,1,1] = 0
     sum(x[0,0,:]*y[:,1,1]) = 0
+
+The storage type of ``dot`` output depends on storage types of inputs and transpose options:
+
+- dot(csr, default) = default
+- dot(csr.T, default) = row_sparse
+- dot(csr, row_sparse) = default
+- otherwise, ``dot`` generates output with default storage
+
 )doc" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
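The dot(csr.T, default) = row_sparse rule follows from the structure of the product: row i of csr.T times a dense matrix can be nonzero only if column i of the csr input holds a nonzero, so all-zero columns become all-zero output rows. A scipy sketch of that observation, as an illustration rather than the MXNet operator::

    import numpy as np
    import scipy.sparse as sp

    lhs = sp.csr_matrix(np.array([[0., 1., 0.],
                                  [0., 0., 3.]]))   # column 0 is all zeros
    rhs = np.ones((2, 4))

    out = lhs.T.dot(rhs)                            # shape (3, 4)
    print(np.flatnonzero(out.any(axis=1)))          # [1 2]: row 0 stays zero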
9 changes: 8 additions & 1 deletion src/operator/tensor/elemwise_binary_op_basic.cc
@@ -28,7 +28,14 @@ namespace mxnet {
 namespace op {
 MXNET_OPERATOR_REGISTER_BINARY(elemwise_add)
 .add_alias("_add").add_alias("_plus").add_alias("_Plus")
-.describe("Adds arguments element-wise.")
+.describe(R"code(Adds arguments element-wise.
+
+The storage type of ``elemwise_add`` output depends on storage types of inputs
+
+- elemwise_add(row_sparse, row_sparse) = row_sparse
+- otherwise, ``elemwise_add`` generates output with default storage
+
+)code")
 .set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, mshadow::op::plus>)
 .set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_add"})
 .set_attr<FComputeEx>("FComputeEx<cpu>", BinaryComputeEx<cpu, mshadow::op::plus>)
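For the row_sparse + row_sparse case, the result's row set is the union of the inputs' row sets, which is why the output stays row_sparse. A numpy sketch of that merge; rsp_add is a hypothetical helper, not the MXNet kernel::

    import numpy as np

    def rsp_add(idx_a, val_a, idx_b, val_b):
        # output rows = sorted union of the input row indices
        idx = np.union1d(idx_a, idx_b)
        val = np.zeros((idx.size,) + val_a.shape[1:])
        val[np.searchsorted(idx, idx_a)] += val_a
        val[np.searchsorted(idx, idx_b)] += val_b
        return idx, val

    idx, val = rsp_add(np.array([0, 2]), np.ones((2, 3)),
                       np.array([2, 3]), np.ones((2, 3)))
    # idx = [0 2 3]; row 2 sums contributions from both inputs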
6 changes: 6 additions & 0 deletions src/operator/tensor/elemwise_sum.cc
@@ -110,6 +110,12 @@ NNVM_REGISTER_OP(add_n)
    add\_n(a_1, a_2, ..., a_n) = a_1 + a_2 + ... + a_n
 
 ``add_n`` is potentially more efficient than calling ``add`` by `n` times.
+
+The storage type of ``add_n`` output depends on storage types of inputs
+
+- add_n(row_sparse, row_sparse, ..) = row_sparse
+- otherwise, ``add_n`` generates output with default storage
+
 )doc" ADD_FILELINE)
 .set_attr_parser(ParamParser<ElementWiseSumParam>)
 .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
5 changes: 5 additions & 0 deletions src/operator/tensor/sparse_retain.cc
@@ -41,6 +41,11 @@ Example::
   rsp_out.values = [[1, 2], [5, 6]]
   rsp_out.indices = [0, 3]
 
+The storage type of ``sparse_retain`` output depends on storage types of inputs
+
+- sparse_retain(row_sparse, default) = row_sparse
+- otherwise, ``sparse_retain`` is not supported
+
 )code" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
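A numpy rendering of the docstring example; the input (indices [0, 1, 3] with values [[1, 2], [3, 4], [5, 6]]) is not visible in the truncated context above and is an assumption, and this is a sketch rather than the MXNet operator::

    import numpy as np

    rsp_idx = np.array([0, 1, 3])                  # assumed input rows
    rsp_val = np.array([[1, 2], [3, 4], [5, 6]])   # assumed input values
    to_retain = np.array([0, 3])

    mask = np.isin(rsp_idx, to_retain)             # keep only retained rows
    print(rsp_idx[mask])   # [0 3]
    print(rsp_val[mask])   # [[1 2], [5 6]]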
35 changes: 17 additions & 18 deletions src/operator/tensor/square_sum-inl.h
@@ -196,30 +196,26 @@ struct SquareSumRspGradKernel<req, 1> {
 };
 
 /*!
- * This kernel assumes that the ograd and in_data
- * are all rsp and have equal row_idx array.
- * TODO(junwu): make the kernel general to support
- * the cases when ograd and in_data have different
- * row_idx arrays.
+ * Note: This kernel assumes that the ograd and in_data
+ * are all rsp and have equal row_idx array, or
+ * in_data is a full rsp.
 */
 template<int req>
 struct SquareSumRspGradKernel<req, 1, kRowSparseStorage> {
   /*!
-   * \param i index of out_grad_row_idx
+   * \param i index of igrad.data()
    * \param in_grad_row_idx row_idx of the gradient of the op's input
    * \param in_grad gradient of the op's input
   * \param out_grad_row_idx row_idx of the gradient of the op's output
   * \param out_grad gradient of the op's output
-   * \param in_row_idx row idx of the op's input
    * \param in_data op's input
   */
   template<typename IType, typename DType>
   MSHADOW_XINLINE static void Map(int i, IType* in_grad_row_idx, DType* in_grad,
                                   const IType* out_grad_row_idx, const DType* out_grad,
-                                  const IType* in_row_idx, const DType* in_data,
-                                  const int64_t num_cols) {
+                                  const DType* in_data, const int64_t num_cols) {
     const int64_t row = i / num_cols;
-    in_grad_row_idx[row] = in_row_idx[row];
+    in_grad_row_idx[row] = out_grad_row_idx[row];
     KERNEL_ASSIGN(in_grad[i], req, 2*in_data[i]*out_grad[row]);
   }
 };
@@ -341,7 +337,7 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
   const TBlob& igrad_data = igrad->data();
   const TBlob igrad_row_idx = igrad->aux_data(rowsparse::kIdx);
   const TBlob& ograd_data = ograd.data();
-  const TBlob in_data = input.data();
+  const TBlob& in_data = input.data();
   const TBlob in_row_idx = input.aux_data(rowsparse::kIdx);
   if (ograd.storage_type() == kDefaultStorage) {
     if (0 == param.axis[0]) {  // forward is sum per column
@@ -372,16 +368,20 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
                                            " when ograd_stype = kRowSparseStorage";
     CHECK_EQ(ograd.shape().ndim(), 2U);
     const TBlob ograd_row_idx = ograd.aux_data(rowsparse::kIdx);
-    CHECK_EQ(ograd_row_idx.Size(), in_row_idx.Size());
+    CHECK(ograd_row_idx.Size() == in_row_idx.Size() || in_row_idx.Size() == in_data.shape_[0]);
     MSHADOW_IDX_TYPE_SWITCH(igrad_row_idx.type_flag_, IType, {
       if (std::is_same<xpu, cpu>::value) {
         const IType* first1 = ograd_row_idx.dptr<IType>();
         const IType* last1 = first1 + ograd_row_idx.Size();
         const IType* first2 = in_row_idx.dptr<IType>();
-        CHECK(std::equal(first1, last1, first2)) << "SquareSumRspGradImpl only supports"
-                                                    " equal ograd_row_idx and input_row_idx"
-                                                    " when ograd and input are both"
-                                                    " row-sparse";
+        // when ograd_row_idx and in_row_idx have the same size and input is not a full rsp
+        // ograd_row_idx and in_row_idx are expected to have the same elements
+        if (ograd_row_idx.Size() == in_row_idx.Size() && in_row_idx.Size() != in_data.shape_[0]) {
+          CHECK(std::equal(first1, last1, first2)) << "SquareSumRspGradImpl only supports"
+                                                      " equal ograd_row_idx and input_row_idx"
+                                                      " when ograd and input are both"
+                                                      " row-sparse";
+        }
      } else {
        LOG(FATAL) << "SquareSumRspGradImpl has not implemented GPU version when"
                      " ograd and input are both row-sparse";
@@ -391,8 +391,7 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
         Kernel<SquareSumRspGradKernel<req_type, 1, kRowSparseStorage>, xpu>::Launch(
             s, igrad_data.Size(), igrad_row_idx.dptr<IType>(),
             igrad_data.dptr<DType>(), ograd_row_idx.dptr<IType>(),
-            ograd_data.dptr<DType>(), in_row_idx.dptr<IType>(),
-            in_data.dptr<DType>(), num_cols);
+            ograd_data.dptr<DType>(), in_data.dptr<DType>(), num_cols);
       })
     })
   })
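The kernel body reduces to igrad = 2 * in_data * ograd broadcast along rows, with the output row_idx now copied from ograd_row_idx instead of the input's row_idx. A numpy sketch of the aligned-row_idx case, as an illustration only::

    import numpy as np

    ograd_row_idx = np.array([0, 2])
    ograd = np.array([5., 14.])                 # one scalar per stored row
    in_data = np.array([[1., 2.], [3., 1.]])    # values for the same rows

    igrad = 2 * in_data * ograd[:, None]        # the KERNEL_ASSIGN line, vectorized
    igrad_row_idx = ograd_row_idx.copy()        # row_idx taken from ograd, per the fix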
3 changes: 1 addition & 2 deletions tests/python/unittest/test_operator.py
@@ -867,7 +867,6 @@ def check_batchnorm_training(stype):
         rolling_mean = np.random.uniform(size=s)
         rolling_std = np.random.uniform(size=s)
 
-        stype = 'row_sparse'
         data = mx.symbol.Variable('data', stype=stype)
         in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
                        mx.nd.array(beta).tostype(stype)]
@@ -935,7 +934,7 @@ def check_batchnorm_training(stype):
         test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
         check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-    stypes = ['row_sparse', 'csr', 'default']
+    stypes = ['row_sparse', 'default']
     for stype in stypes:
         check_batchnorm_training(stype)