
add norm operator for sparse ndarray (apache#9479)
* add norm(row_sparse)

* add doc

* add example in doc

* address comments

* rebase

* add kNullOp check
eric-haibin-lin authored and szha committed Jan 21, 2018
1 parent caa6250 commit 4616737
Showing 5 changed files with 121 additions and 10 deletions.
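
For context, a minimal sketch of what this change enables from the Python frontend (values match the operator docstring below; assumes an MXNet build that includes this commit):

```python
import mxnet as mx

x = mx.nd.array([[1, 2], [3, 4]])
print(x.norm().asscalar())        # ~5.47722578, the l2 norm of the flattened array

# The same operator now handles sparse storage types directly
rsp = x.tostype('row_sparse')
csr = x.tostype('csr')
print(rsp.norm().asscalar())      # same value, dispatched through FComputeEx
print(csr.norm().asscalar())      # same value
```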
15 changes: 13 additions & 2 deletions docs/api/python/ndarray/sparse.md
@@ -157,6 +157,7 @@ We summarize the interface for each class in the following sections.
CSRNDArray.sum
CSRNDArray.mean
CSRNDArray.norm
```

### Powers
@@ -237,6 +238,15 @@ We summarize the interface for each class in the following sections.
RowSparseNDArray.zeros_like
```

### Array reduction

```eval_rst
.. autosummary::
:nosignatures:
RowSparseNDArray.norm
```

### Array rounding

```eval_rst
@@ -414,6 +424,7 @@ We summarize the interface for each class in the following sections.
sum
mean
norm
```

### Rounding
@@ -492,10 +503,10 @@ We summarize the interface for each class in the following sections.
```eval_rst
.. autoclass:: mxnet.ndarray.sparse.CSRNDArray
:members: shape, context, dtype, stype, data, indices, indptr, copy, copyto, as_in_context, asscipy, asnumpy, asscalar, astype, tostype, slice, wait_to_read, zeros_like, __neg__, sum, mean, square, __getitem__, __setitem__, check_format
:members: shape, context, dtype, stype, data, indices, indptr, copy, copyto, as_in_context, asscipy, asnumpy, asscalar, astype, tostype, slice, wait_to_read, zeros_like, __neg__, sum, mean, norm, square, __getitem__, __setitem__, check_format
.. autoclass:: mxnet.ndarray.sparse.RowSparseNDArray
:members: shape, context, dtype, stype, data, indices, copy, copyto, as_in_context, asnumpy, asscalar, astype, tostype, wait_to_read, zeros_like, round, rint, fix, floor, ceil, trunc, sin, tan, arcsin, arctan, degrees, radians, sinh, tanh, arcsinh, arctanh, expm1, log1p, sqrt, square, __negative__, __getitem__, __setitem__, check_format, retain, clip, sign
:members: shape, context, dtype, stype, data, indices, copy, copyto, as_in_context, asnumpy, asscalar, astype, tostype, wait_to_read, zeros_like, round, rint, fix, floor, ceil, trunc, sin, tan, arcsin, arctan, degrees, radians, sinh, tanh, arcsinh, arctanh, expm1, log1p, sqrt, square, __negative__, norm, __getitem__, __setitem__, check_format, retain, clip, sign
.. automodule:: mxnet.ndarray.sparse
:members:
89 changes: 82 additions & 7 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -816,20 +816,95 @@ struct ReduceGrad {
}
};

inline bool L2NormStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const int in_stype = in_attrs->at(0);
int& out_stype = out_attrs->at(0);
bool dispatched = false;
if (!dispatched && in_stype == kDefaultStorage) {
// dns -> dns
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFCompute);
}
if (!dispatched && (in_stype == kCSRStorage || in_stype == kRowSparseStorage)) {
// csr/rsp -> dns
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFComputeEx);
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
return dispatched;
}

template<typename xpu>
void L2NormComputeImpl(mshadow::Stream<xpu> *s,
const TBlob& input,
const OpReqType req,
const TBlob& output) {
if (req == kNullOp) return;
MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req, Req, {
mshadow::Tensor<xpu, 1, DType> out = output.get<xpu, 1, DType>(s);
mshadow::Tensor<xpu, 1, DType> in = input.get_with_shape<xpu, 1, DType>(
mshadow::Shape1(input.shape_.Size()), s);
mshadow::VectorDot(out, in, in);
DType* out_data = output.dptr<DType>();
using namespace mxnet_op;
Kernel<op_with_req<mshadow_op::square_root, Req>, xpu>::Launch(
s, output.Size(), out_data, out_data);
});
});
}


template<typename xpu>
void L2NormCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
mshadow::Tensor<xpu, 1, DType> out = outputs[0].get<xpu, 1, DType>(s);
mshadow::Tensor<xpu, 1, DType> in = inputs[0].get_with_shape<xpu, 1, DType>(
mshadow::Shape1(inputs[0].shape_.Size()), s);
mshadow::VectorDot(out, in, in);
ASSIGN_DISPATCH(out, req[0], mshadow::expr::F<mxnet::op::mshadow_op::square_root>(out));
});
L2NormComputeImpl(s, inputs[0], req[0], outputs[0]);
}

template<typename xpu>
void L2NormComputeSparseImpl(mshadow::Stream<xpu> *s,
const NDArray& input,
const OpReqType req,
const TBlob& output) {
if (req == kNullOp) return;
// input is zeros
if (!input.storage_initialized()) {
// Add zeros. No op.
if (req == kAddTo) return;
Fill<false>(s, output, req, 0);
} else {
L2NormComputeImpl(s, input.data(), req, output);
}
}

template<typename xpu>
void L2NormComputeEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
const NDArrayStorageType in_stype = inputs[0].storage_type();
if (in_stype == kCSRStorage || in_stype == kRowSparseStorage) {
L2NormComputeSparseImpl(s, inputs[0], req[0], outputs[0].data());
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}

/*! \brief index element from array along axes */
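
The storage-type inference above always yields a dense output, and the sparse kernel falls back to writing zeros when the input holds no stored values; a small sketch of that behavior (hypothetical shapes, assumes the commit is built in):

```python
import mxnet as mx

shape = (100, 100)
rsp = mx.nd.sparse.zeros('row_sparse', shape)   # storage_initialized() is false
print(rsp.norm().stype)                         # 'default' -- the output is always dense
print(rsp.norm().asscalar())                    # 0.0, via the Fill path in L2NormComputeSparseImpl

csr = mx.nd.sparse.zeros('csr', shape)
print(csr.norm().asscalar())                    # 0.0 as well
```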
11 changes: 11 additions & 0 deletions src/operator/tensor/broadcast_reduce_op_value.cc
@@ -245,6 +245,7 @@ NNVM_REGISTER_OP(_broadcast_backward)
});

NNVM_REGISTER_OP(norm)
MXNET_ADD_SPARSE_OP_ALIAS(norm)
.describe(R"code(Flattens the input array and then computes the l2 norm.
Examples::
@@ -254,6 +255,14 @@ Examples::
norm(x) = [5.47722578]
rsp = x.cast_storage('row_sparse')
norm(rsp) = [5.47722578]
csr = x.cast_storage('csr')
norm(csr) = [5.47722578]
)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
@@ -268,7 +277,9 @@ Examples::
return true;
})
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferStorageType>("FInferStorageType", L2NormStorageType)
.set_attr<FCompute>("FCompute<cpu>", L2NormCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", L2NormComputeEx<cpu>)
.add_argument("data", "NDArray-or-Symbol", "Source input");

} // namespace op
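
With MXNET_ADD_SPARSE_OP_ALIAS(norm) registered above, the operator should also be reachable as a free function in the sparse namespace, alongside the NDArray method; a brief sketch (the mx.nd.sparse.norm name is assumed from the usual alias convention):

```python
import mxnet as mx

csr = mx.nd.array([[1, 2], [3, 4]]).tostype('csr')

a = mx.nd.norm(csr)          # generic operator, dispatched via FComputeEx<cpu>
b = mx.nd.sparse.norm(csr)   # sparse alias registered by this commit
print(a.asscalar(), b.asscalar())   # both ~5.47722578
```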
3 changes: 2 additions & 1 deletion src/operator/tensor/broadcast_reduce_op_value.cu
@@ -78,7 +78,8 @@ NNVM_REGISTER_OP(_broadcast_backward)
.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::sum>);

NNVM_REGISTER_OP(norm)
.set_attr<FCompute>("FCompute<gpu>", L2NormCompute<gpu>);
.set_attr<FCompute>("FCompute<gpu>", L2NormCompute<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", L2NormComputeEx<gpu>);

} // namespace op
} // namespace mxnet
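
The .cu registration mirrors the CPU one, so the same frontend call picks the GPU kernel when the data lives on a GPU context; a short sketch (assumes a CUDA-enabled build with at least one visible GPU):

```python
import mxnet as mx

ctx = mx.gpu(0)
rsp = mx.nd.array([[1, 2], [3, 4]], ctx=ctx).tostype('row_sparse')
print(rsp.norm().asscalar())   # dispatched to FComputeEx<gpu>
```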
13 changes: 13 additions & 0 deletions tests/python/unittest/test_sparse_ndarray.py
@@ -829,6 +829,19 @@ def test_sparse_nd_check_format():
a = mx.nd.sparse.row_sparse_array((data_list, indices_list), shape=shape)
assertRaises(mx.base.MXNetError, a.check_format)

def test_sparse_nd_norm():
def check_sparse_nd_norm(stype, shape, density):
data, _ = rand_sparse_ndarray(shape, stype, density)
norm = data.norm()
expected_norm = np.linalg.norm(data.asnumpy())
assert_almost_equal(norm.asnumpy(), expected_norm)

shape = (5, 5)
stypes = ['row_sparse', 'csr']
densities = [0, 0.5]
for stype in stypes:
for density in densities:
check_sparse_nd_norm(stype, shape, density)

if __name__ == '__main__':
import nose
