Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Support elementwise multiplication (elemwise_mul) between dense (dns) and CSR sparse matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
Hao Jin committed May 19, 2018
1 parent bea5fd1 commit 98b0d93
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 27 deletions.
87 changes: 87 additions & 0 deletions src/operator/tensor/elemwise_binary_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,93 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<xpu> *s,
});
}

/*!
 * \brief Kernel for performing an elemwise op between a dense and a csr matrix,
 *        producing output with the same sparsity pattern as the csr input
 *        (one thread handles one row of the csr matrix).
 * \tparam req type of write request (from MXNET_ASSIGN_REQ_SWITCH)
 * \tparam OP the binary operator to apply to each stored element
 * \tparam reverse if true compute OP(dns, csr), otherwise OP(csr, dns)
 * \param i global thread id == row index
 * \param out output data array (size nnz)
 * \param dns_data data array of dense input
 * \param csr_data data array of csr input
 * \param csr_indices indices array of csr input
 * \param csr_indptr indptr array of csr input
 * \param num_rows number of rows of both inputs
 * \param num_cols number of columns of both inputs
 */
template<int req, typename OP, bool reverse>
struct ElemwiseDnsCsrCsrKernel {
  template<typename DType, typename IType, typename CType>
  MSHADOW_XINLINE static void Map(int i, DType* out, DType* dns_data,
                                  const DType* csr_data, const IType* csr_indices,
                                  const CType* csr_indptr, const nnvm::dim_t num_rows,
                                  const nnvm::dim_t num_cols) {
    if (i < num_rows) {
      // Iterate with the indptr's own element type: an `int` index would
      // truncate when the number of non-zeros exceeds the range of int.
      for (CType j = csr_indptr[i]; j < csr_indptr[i+1]; ++j) {
        KERNEL_ASSIGN(out[j], req, reverse ?
                               OP::Map(dns_data[i * num_cols + csr_indices[j]], csr_data[j]) :
                               OP::Map(csr_data[j], dns_data[i * num_cols + csr_indices[j]]));
      }
    }
  }
};

/*! \brief DNS -op- CSR binary operator for non-canonical NDArray.
 *
 * Writes the elementwise result into a CSR output that reuses the csr
 * input's sparsity pattern; only zero-preserving ops (mul/div) qualify.
 *
 * \param attrs node attributes (kept for the operator signature)
 * \param ctx operator context supplying the device stream
 * \param dns the dense operand (kDefaultStorage)
 * \param csr the csr operand (kCSRStorage)
 * \param req write request; only kWriteTo is supported
 * \param output csr NDArray receiving the result
 * \param reverse operand-order flag forwarded to ElemwiseDnsCsrCsrKernel
 */
template<typename xpu, typename OP>
void ElemwiseBinaryOp::DnsCsrCsrOp(const nnvm::NodeAttrs &attrs,
                                   const OpContext &ctx,
                                   const NDArray &dns,
                                   const NDArray &csr,
                                   const OpReqType req,
                                   const NDArray &output,
                                   const bool reverse) {
  using namespace mshadow;
  using namespace mxnet_op;
  using namespace csr;
  CHECK_EQ(dns.storage_type(), kDefaultStorage);
  CHECK_EQ(csr.storage_type(), kCSRStorage);
  // Check kNullOp before the kWriteTo equality check: in the original order
  // the kNullOp check was unreachable, since CHECK_EQ(req, kWriteTo) already
  // fails for kNullOp with a misleading message.
  CHECK(req != kNullOp);
  CHECK_EQ(req, kWriteTo) << "elemwise(dns, csr) = csr only supports kWriteTo";
  // Only mul/div map the implicit zeros of the csr operand back to zero,
  // which is what allows the output to reuse the input's sparsity pattern.
  const bool supported_op = std::is_same<OP, mshadow_op::mul>::value ||
                            std::is_same<OP, mshadow_op::div>::value;
  CHECK(supported_op == true) << "elemwise(dns, csr) = csr only supports mul/div";
  const nnvm::dim_t num_csr_rows = csr.shape()[0];
  const nnvm::dim_t num_csr_cols = csr.shape()[1];
  const nnvm::dim_t nnz = csr.storage_shape()[0];
  Stream<xpu> *s = ctx.get_stream<xpu>();

  // Output mirrors the input csr layout: indptr of rows+1, indices/data of nnz.
  output.CheckAndAlloc({Shape1(num_csr_rows + 1), Shape1(nnz)});
  if (csr.storage_initialized()) {
    TBlob csr_data = csr.data();
    TBlob csr_indices = csr.aux_data(kIdx);
    TBlob csr_indptr = csr.aux_data(kIndPtr);
    MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {
      MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {
        MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {
          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
            if (reverse) {
              Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, true>, xpu>::Launch(
                s, num_csr_rows, output.data().dptr<DType>(), dns.data().dptr<DType>(),
                csr_data.dptr<DType>(), csr_indices.dptr<IType>(), csr_indptr.dptr<CType>(),
                num_csr_rows, num_csr_cols);
            } else {
              Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, false>, xpu>::Launch(
                s, num_csr_rows, output.data().dptr<DType>(), dns.data().dptr<DType>(),
                csr_data.dptr<DType>(), csr_indices.dptr<IType>(), csr_indptr.dptr<CType>(),
                num_csr_rows, num_csr_cols);
            }
            // Sparsity pattern is unchanged: copy indices/indptr verbatim.
            Copy(output.aux_data(kIdx).FlatTo1D<xpu, IType>(),
                 csr.aux_data(kIdx).FlatTo1D<xpu, IType>(), s);
            Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, CType>(),
                 csr.aux_data(kIndPtr).FlatTo1D<xpu, CType>(), s);
          });
        });
      });
    });
  } else {
    // All-zero csr operand: mul/div against dense stays all-zero.
    FillZerosCsrImpl(s, output);
  }
}


/*!
* \brief Kernel for performing elemwise op between dense and rsp tensor
* \param i global thread id
Expand Down
78 changes: 52 additions & 26 deletions src/operator/tensor/elemwise_binary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,44 +232,54 @@ class ElemwiseBinaryOp : public OpBase {
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename OP>
static void CsrCsrOp(mshadow::Stream<cpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output);
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output);

/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename OP>
static void CsrCsrOp(mshadow::Stream<gpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output);
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output);

/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
template<typename xpu, typename OP>
static void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
const bool reverse);
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
const bool reverse);

/*! \brief DNS -op- CSR binary operator for non-canonical NDArray, producing a
 *  CSR output.  Unlike DnsCsrDnsOp above, the result keeps the csr operand's
 *  sparsity pattern, so only zero-preserving ops (mul/div) apply.
 *  NOTE(review): `reverse` flips the operand order passed to OP — confirm the
 *  exact orientation against the implementation in elemwise_binary_op-inl.h. */
template<typename xpu, typename OP>
static void DnsCsrCsrOp(const nnvm::NodeAttrs &attrs,
                        const OpContext &ctx,
                        const NDArray &lhs,
                        const NDArray &rhs,
                        OpReqType req,
                        const NDArray &output,
                        const bool reverse);

/*! \brief DNS -op- RSP binary operator for non-canonical NDArray */
template<typename xpu, typename OP>
static void DnsRspDnsOp(mshadow::Stream<xpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
const bool reverse);
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
const bool reverse);

public:
/*!
Expand Down Expand Up @@ -336,6 +346,14 @@ class ElemwiseBinaryOp : public OpBase {
dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
dispatch_mode, dispatch_ex);
}
if (!dispatched &&
((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
// csr, dns -> csr
// dns, csr -> csr
dispatched = storage_type_assign(&out_stype, kCSRStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
Expand Down Expand Up @@ -540,6 +558,14 @@ class ElemwiseBinaryOp : public OpBase {
req[0], outputs[0], lhs_may_be_dense, rhs_may_be_dense, false, false);
} else if (lhs_stype == kCSRStorage && rhs_stype == kCSRStorage) {
ComputeEx<xpu, OP>(attrs, ctx, inputs, req, outputs);
} else if (((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage)) &&
out_stype == kCSRStorage) {
const NDArray& dns = (lhs_stype == kDefaultStorage)? inputs[0] : inputs[1];
const NDArray& csr = (lhs_stype == kCSRStorage)? inputs[0] : inputs[1];
const bool reverse = (lhs_stype == kCSRStorage);

DnsCsrCsrOp<xpu, OP>(attrs, ctx, dns, csr, req[0], outputs[0], reverse);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
Expand Down
4 changes: 3 additions & 1 deletion src/operator/tensor/elemwise_binary_op_basic.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ NNVM_REGISTER_OP(_backward_sub)
mshadow_op::negation>);

NNVM_REGISTER_OP(elemwise_mul)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::mul>);
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::mul>)
.set_attr<FComputeEx>("FComputeEx<gpu>",
ElemwiseBinaryOp::ComputeDnsLRValueEx<gpu, op::mshadow_op::mul, true, true>);

NNVM_REGISTER_OP(_backward_mul)
.set_attr<FCompute>("FCompute<gpu>",
Expand Down

0 comments on commit 98b0d93

Please sign in to comment.