Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add support for initilazer with rowsparse output
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin committed Aug 9, 2017
1 parent 253ae57 commit 8da42c2
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 16 deletions.
9 changes: 6 additions & 3 deletions src/operator/random/sample_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ Example::
[ 0.54488319, 0.84725171]]
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", SampleUniform_<cpu>);
.set_attr<FCompute>("FCompute<cpu>", SampleUniform_<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SampleUniformEx_<cpu>);

// Add "normal" alias for backward compatibility
MXNET_OPERATOR_REGISTER_SAMPLE(random_normal, SampleNormalParam)
Expand All @@ -78,7 +79,8 @@ Example::
random_normal(loc=0, scale=1, shape=(2,2)) = [[ 1.89171135, -1.16881478],
[-1.23474145, 1.55807114]]
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", SampleNormal_<cpu>);
.set_attr<FCompute>("FCompute<cpu>", SampleNormal_<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SampleNormalEx_<cpu>);

MXNET_OPERATOR_REGISTER_SAMPLE(random_gamma, SampleGammaParam)
.add_alias("_sample_gamma")
Expand All @@ -91,7 +93,8 @@ Example::
random_gamma(alpha=9, beta=0.5, shape=(2,2)) = [[ 7.10486984, 3.37695289],
[ 3.91697288, 3.65933681]]
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", SampleGamma_<cpu>);
.set_attr<FCompute>("FCompute<cpu>", SampleGamma_<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SampleGammaEx_<cpu>);

MXNET_OPERATOR_REGISTER_SAMPLE(random_exponential, SampleExponentialParam)
.add_alias("_sample_exponential")
Expand Down
109 changes: 96 additions & 13 deletions src/operator/random/sample_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,29 +232,75 @@ struct SampleGenNegBinomialParam : public dmlc::Parameter<SampleGenNegBinomialPa
}
};

using FSampleCompute = std::function<void (const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const OpReqType& req,
TBlob* outputs)>;

template<typename xpu>
void SampleUniform_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
void SampleComputeEx_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs,
FSampleCompute fcomp) {
NDArray output = outputs[0];
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
if (output.storage_type() == kRowSparseStorage) {
// indices
nnvm::dim_t nnr = output.shape()[0];
output.CheckAndAlloc({mshadow::Shape1(nnr)});
PopulateFullIdxRspImpl(s, &output);
// data
TBlob out_blob = output.data();
fcomp(attrs, ctx, req[0], &out_blob);
} else {
LOG(FATAL) << "Unexpected storage type for SampleComputeEx_: "
<< output.storage_type();
}
}

template<typename xpu>
void SampleUniformDnsImpl(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const OpReqType& req,
TBlob* output) {
using namespace mxnet::op;
using namespace mshadow::expr;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const SampleUniformParam& param = nnvm::get<SampleUniformParam>(attrs.parsed);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MSHADOW_REAL_TYPE_SWITCH(output->type_flag_, DType, {
mshadow::Random<xpu, DType> *prnd = ctx.requested[0].get_random<xpu, DType>(s);
mshadow::Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
mshadow::Tensor<xpu, 2, DType> out = output->FlatTo2D<xpu, DType>(s);
prnd->SampleUniform(&out, param.low, param.high);
});
}

template<typename xpu>
void SampleNormal_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
void SampleUniform_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
TBlob out = outputs[0];
SampleUniformDnsImpl<xpu>(attrs, ctx, req[0], &out);
}


template<typename xpu>
void SampleUniformEx_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
SampleComputeEx_<xpu>(attrs, ctx, inputs, req, outputs, SampleUniformDnsImpl<xpu>);
}

template<typename xpu>
void SampleNormalDnsImpl(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const OpReqType& req,
TBlob* outputs) {
using namespace mxnet::op;
using namespace mshadow::expr;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
Expand All @@ -268,11 +314,29 @@ void SampleNormal_(const nnvm::NodeAttrs& attrs,
}

template<typename xpu>
void SampleGamma_(const nnvm::NodeAttrs& attrs,
void SampleNormal_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
TBlob out = outputs[0];
SampleNormalDnsImpl<xpu>(attrs, ctx, req[0], &out);
}

template<typename xpu>
void SampleNormalEx_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
SampleComputeEx_<xpu>(attrs, ctx, inputs, req, outputs, SampleNormalDnsImpl<xpu>);
}

template<typename xpu>
void SampleGammaDnsImpl(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const OpReqType& req,
TBlob* outputs) {
using namespace mxnet::op;
using namespace mshadow::expr;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
Expand All @@ -286,6 +350,25 @@ void SampleGamma_(const nnvm::NodeAttrs& attrs,
});
}

template<typename xpu>
void SampleGamma_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
TBlob out = outputs[0];
SampleGammaDnsImpl<xpu>(attrs, ctx, req[0], &out);
}

template<typename xpu>
void SampleGammaEx_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
SampleComputeEx_<xpu>(attrs, ctx, inputs, req, outputs, SampleGammaDnsImpl<xpu>);
}

template<typename xpu>
void SampleExponential_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down
20 changes: 20 additions & 0 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,26 @@ inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
});
}

struct PopulateFullIdxRspKernel {
template<typename IType>
MSHADOW_XINLINE static void Map(int i, IType* out) {
KERNEL_ASSIGN(out[i], kWriteTo, i);
}
};

// Fill full indices NDArray with zeros by updating the aux shape.
template<typename xpu>
void PopulateFullIdxRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
using namespace rowsparse;
CHECK_EQ(dst->storage_type(), kRowSparseStorage);
nnvm::dim_t nnr = dst->shape()[0];
dst->CheckAndAllocAuxData(kIdx, mshadow::Shape1(nnr));
MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
IType* idx = dst->aux_data(kIdx).dptr<IType>();
mxnet_op::Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, nnr, idx);
});
}

// Fill a rsp NDArray with zeros by updating the aux shape.
template<typename xpu>
void FillZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,20 @@ def test_sparse_nd_output_fallback():
mx.nd.random_normal(shape=shape, out=out)
assert(np.sum(out.asnumpy()) != 0)

def test_sparse_nd_random():
shape = (100, 100)
fns = [mx.nd.random_uniform, mx.nd.random_normal, mx.nd.random_gamma]
for fn in fns:
rsp_out = mx.nd.zeros(shape=shape, stype='row_sparse')
dns_out = mx.nd.zeros(shape=shape, stype='default')
mx.random.seed(0)
np.random.seed(0)
fn(shape=shape, out=dns_out)
mx.random.seed(0)
np.random.seed(0)
fn(shape=shape, out=rsp_out)
assert_almost_equal(dns_out.asnumpy(), rsp_out.asnumpy())


def test_sparse_nd_astype():
stypes = ['row_sparse', 'csr']
Expand Down

0 comments on commit 8da42c2

Please sign in to comment.