Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-404] elemwise_add/sub between rsp and rsp on GPU #11179

Merged
merged 2 commits into from
Jun 20, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Support for elemwise_add/sub between rsp and rsp on GPU
  • Loading branch information
Hao Jin committed Jun 11, 2018
commit d6d281ce35232530cba09b45e56d4f3f4e172408
2 changes: 1 addition & 1 deletion include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class NDArray {
}

/* \brief Check whether the two arrays are the same array */
inline bool IsSame(const NDArray& other) {
inline bool IsSame(const NDArray& other) const {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@piiswrong I made the change here so that I can also call this function when I have a const NDArray object.

return ptr_ == other.ptr_ &&
shape_ == other.shape_ &&
byte_offset_ == other.byte_offset_ &&
Expand Down
18 changes: 1 addition & 17 deletions src/operator/tensor/elemwise_binary_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,6 @@
namespace mxnet {
namespace op {

/*!
 * \brief GPU stub for elementwise binary op between two row-sparse NDArrays.
 *        Always aborts: rsp + rsp / rsp - rsp had no GPU implementation, so
 *        hitting this path is a fatal error. (This stub is removed by this
 *        change in favor of a real GPU implementation in
 *        elemwise_binary_op_basic.cu.)
 * \param s       GPU stream (unused)
 * \param attrs   node attributes (unused)
 * \param ctx     op context (unused)
 * \param lhs     left operand (unused)
 * \param rhs     right operand (unused)
 * \param req     write request type (unused)
 * \param output  output array (unused)
 * \param lhs_may_be_dense, rhs_may_be_dense, allow_inplace, scatter
 *                dispatch flags (unused)
 */
template<typename OP>
void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<gpu> *s,
                                const nnvm::NodeAttrs &attrs,
                                const OpContext &ctx,
                                const NDArray &lhs,
                                const NDArray &rhs,
                                const OpReqType req,
                                const NDArray &output,
                                const bool lhs_may_be_dense,
                                const bool rhs_may_be_dense,
                                const bool allow_inplace,
                                const bool scatter) {
  LOG(FATAL) << "GPU not supported for RspRspOp";
}


/*! \brief binary op handling for the following row sparse inputs/outputs
rsp, rsp -> rsp,
dns, rsp -> rsp,
Expand Down Expand Up @@ -622,7 +606,7 @@ void ElemwiseBinaryOp::DnsRspDnsOp(mshadow::Stream<xpu> *s,
const bool reverse) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(dns.storage_type(), kDefaultStorage);
CHECK(dns.storage_type() == kDefaultStorage || dns.storage_type() == kRowSparseStorage);
CHECK_EQ(rsp.storage_type(), kRowSparseStorage);
CHECK_EQ(output.data().Size(), dns.data().Size());
CHECK(req != kAddTo);
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/elemwise_binary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ class ElemwiseBinaryOp : public OpBase {
if (!dispatched && rsp && ContainsOnlyStorage(*in_attrs, kRowSparseStorage)) {
// rsp, rsp, ... -> rsp
dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
dispatch_mode, dispatch_ex);
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched && csr && ContainsOnlyStorage(*in_attrs, kCSRStorage)) {
// csr, csr, ... -> csr
Expand Down
134 changes: 134 additions & 0 deletions src/operator/tensor/elemwise_binary_op_basic.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,146 @@
 * \file elemwise_binary_op_basic.cu
 * \brief GPU Implementation of basic elementwise binary functions.
*/
#include <cub/cub.cuh>
#include "./elemwise_binary_op.h"
#include "./elemwise_binary_op-inl.h"

namespace mxnet {
namespace op {

/*!
 * \brief Per-element GPU kernel that scatter-accumulates one row-sparse
 *        operand's values into the (already allocated and zero-initialized)
 *        output data blob of a row-sparse result.
 *        One thread handles one element of the source data blob.
 */
template<typename OP>
struct RspElemwiseKernel {
  /*!
   * \param i            flattened element index into `data`
   * \param out          output data blob (compacted non-zero rows)
   * \param lookup_table inclusive prefix sum over row-presence flags: the
   *                     entry at an original row index is (output row + 1)
   * \param data         source values (nz_rows x num_cols)
   * \param indices      original row index of each source row
   * \param nz_rows      number of non-zero rows in the source
   * \param num_cols     number of columns per row
   */
  template<typename DType, typename IType>
  static MSHADOW_XINLINE void Map(int i, DType* out, const IType* lookup_table,
                                  const DType* data, const IType* indices,
                                  const nnvm::dim_t nz_rows, const nnvm::dim_t num_cols) {
    if (i >= nz_rows * num_cols) return;  // guard against over-launched threads
    const nnvm::dim_t src_row = i / num_cols;
    const nnvm::dim_t src_col = i % num_cols;
    // Map the source row's original index to its compacted output row
    // (lookup table values are 1-based, hence the -1).
    const nnvm::dim_t dst_row = lookup_table[indices[src_row]] - 1;
    const nnvm::dim_t dst_idx = dst_row * num_cols + src_col;
    out[dst_idx] = OP::Map(out[dst_idx], data[i]);
  }
};

/*!
 * \brief GPU implementation of elementwise binary op OP between two
 *        row-sparse (rsp) NDArrays, producing a row-sparse output.
 *        kNullOp is a no-op; kAddTo and scatter are not supported (checked
 *        below). When both operands have non-zero rows, the union of their
 *        row sets is computed via a row-flag table + inclusive prefix sum,
 *        the output is zero-filled, and each operand is scatter-accumulated
 *        into it. When at most one operand is non-empty, the result reduces
 *        to a copy (or negation, for subtraction) of the non-empty operand.
 * \param s       GPU stream to launch kernels on
 * \param attrs   node attributes (unused here)
 * \param ctx     op context, used to request temporary GPU workspace
 * \param lhs     left row-sparse operand
 * \param rhs     right row-sparse operand
 * \param req     write request type (kWriteTo / kWriteInplace / kNullOp)
 * \param output  row-sparse output array
 * \param lhs_may_be_dense, rhs_may_be_dense, allow_inplace, scatter
 *                flags from the generic dispatch; scatter is unsupported here
 */
template<typename OP>
void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<gpu> *s,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have unit test for write inplace?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In-place case shares the same code as in-place case between dns and rsp, which already has a unit test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW correctness is double-checked in benchmark script during the warmup.

                                const nnvm::NodeAttrs &attrs,
                                const OpContext &ctx,
                                const NDArray &lhs,
                                const NDArray &rhs,
                                const OpReqType req,
                                const NDArray &output,
                                const bool lhs_may_be_dense,
                                const bool rhs_may_be_dense,
                                const bool allow_inplace,
                                const bool scatter) {
  using namespace mshadow;
  using namespace mxnet_op;
  using namespace mshadow::expr;
  using namespace rowsparse;

  if (req == kNullOp) return;

  CHECK(!scatter) << "scatter is not supported in RspRspOp on GPU yet...";
  CHECK(lhs.storage_type() == kRowSparseStorage && rhs.storage_type() == kRowSparseStorage);
  CHECK(output.storage_type() == kRowSparseStorage);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it support kAddTo? CHECK_NE(kAddTo)?

  // Accumulation into existing output values is not implemented on this path.
  CHECK(req != kAddTo);

  const nnvm::dim_t num_rows = output.shape()[0];
  MSHADOW_TYPE_SWITCH(lhs.data().type_flag_, DType, {
    MSHADOW_IDX_TYPE_SWITCH(lhs.aux_data(kIdx).type_flag_, IType, {
      if (lhs.storage_initialized() && rhs.storage_initialized()) {
        const nnvm::dim_t lhs_nz_rows = lhs.storage_shape()[0];
        const nnvm::dim_t rhs_nz_rows = rhs.storage_shape()[0];
        const nnvm::dim_t num_cols = lhs.data().Size() / lhs_nz_rows;
        // Optimize for the case where one of the rsps is actually dense
        if ((lhs_nz_rows == num_rows || rhs_nz_rows == num_rows) && req == kWriteInplace) {
          const NDArray& dns = (output.IsSame(lhs)) ? lhs : rhs;
          const NDArray& rsp = (output.IsSame(lhs)) ? rhs : lhs;
          const bool reverse = !(lhs_nz_rows == num_rows);
          ElemwiseBinaryOp::DnsRspDnsOp<gpu, OP>(s, attrs, ctx, dns, rsp, req, output, reverse);
          return;
        }
        CHECK(req == kWriteTo) << "Should be kWriteTo but got " << req;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this function assumes req is never kNullOp, better document it in the header.

        const TBlob& lhs_indices = lhs.aux_data(kIdx);
        const TBlob& rhs_indices = rhs.aux_data(kIdx);
        // common_row_table: one IType slot per output row; first used as a
        // 0/1 "row present in lhs or rhs" flag array, then prefix-summed in
        // place so entry r holds (compacted output row of r) + 1.
        size_t common_row_table_bytes = num_rows * sizeof(IType);
        IType* common_row_table = NULL;
        void* temp_storage_ptr = NULL;
        size_t temp_storage_bytes = 0;
        // Standard CUB two-phase pattern: a call with a null temp pointer
        // only queries the required temp storage size, it does no work.
        cub::DeviceScan::InclusiveSum(temp_storage_ptr,
                                      temp_storage_bytes,
                                      common_row_table,
                                      common_row_table,
                                      num_rows,
                                      mshadow::Stream<gpu>::GetStream(s));
        // Single workspace request: [row table | CUB temp storage].
        size_t workspace_bytes = common_row_table_bytes + temp_storage_bytes;
        Tensor<gpu, 1, char> workspace =
          ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_bytes), s);
        common_row_table = reinterpret_cast<IType*>(workspace.dptr_);
        temp_storage_ptr = workspace.dptr_ + common_row_table_bytes;
        // Flag every row index that appears in either operand.
        mxnet_op::Kernel<set_zero, gpu>::Launch(s, num_rows, common_row_table);
        Kernel<MarkRspRowFlgKernel, gpu>::Launch(
          s, lhs_nz_rows, common_row_table, lhs_indices.dptr<IType>(), lhs_nz_rows);
        Kernel<MarkRspRowFlgKernel, gpu>::Launch(
          s, rhs_nz_rows, common_row_table, rhs_indices.dptr<IType>(), rhs_nz_rows);
        // Inclusive prefix sum turns the flags into 1-based compacted output
        // row numbers; the last entry is the total number of output rows.
        cub::DeviceScan::InclusiveSum(temp_storage_ptr,
                                      temp_storage_bytes,
                                      common_row_table,
                                      common_row_table,
                                      num_rows,
                                      mshadow::Stream<gpu>::GetStream(s));
        nnvm::dim_t nnr_out = 0;
        // NOTE(review): this copies sizeof(nnvm::dim_t) bytes out of an IType
        // array; if IType is narrower than dim_t this over-reads the last
        // slot -- TODO confirm the widths match here. Also assumes the scan
        // on stream `s` has completed before this blocking copy -- verify
        // stream synchronization semantics.
        CUDA_CALL(cudaMemcpy(&nnr_out, &common_row_table[num_rows-1], sizeof(nnvm::dim_t),
                             cudaMemcpyDeviceToHost));
        output.CheckAndAlloc({mshadow::Shape1(nnr_out)});
        // Materialize the compacted row-index array of the output.
        Kernel<FillRspRowIdxKernel, gpu>::Launch(
          s, num_rows, output.aux_data(kIdx).dptr<IType>(), common_row_table, num_rows);
        // Zero-fill the output, then scatter-accumulate each operand: rows
        // present in only one operand combine with the implicit zero of the
        // other. lhs always adds; rhs applies OP (plus or minus).
        Kernel<set_zero, gpu>::Launch(s, nnr_out * num_cols, output.data().dptr<DType>());
        Kernel<RspElemwiseKernel<mshadow_op::plus>, gpu>::Launch(
          s, lhs_nz_rows * num_cols, output.data().dptr<DType>(), common_row_table,
          lhs.data().dptr<DType>(), lhs_indices.dptr<IType>(), lhs_nz_rows, num_cols);
        Kernel<RspElemwiseKernel<OP>, gpu>::Launch(
          s, rhs_nz_rows * num_cols, output.data().dptr<DType>(), common_row_table,
          rhs.data().dptr<DType>(), rhs_indices.dptr<IType>(), rhs_nz_rows, num_cols);
      } else {
        // At most one operand has any non-zero rows: the result is just a
        // (possibly negated) copy of the non-empty operand, or all zeros.
        if (lhs.storage_initialized()) {
          if (req == kWriteTo) {
            output.CheckAndAlloc({lhs.aux_shape(kIdx)});
            Copy(output.data().FlatTo1D<gpu, DType>(),
                 lhs.data().FlatTo1D<gpu, DType>(), s);
            Copy(output.aux_data(kIdx).FlatTo1D<gpu, IType>(),
                 lhs.aux_data(kIdx).FlatTo1D<gpu, IType>(), s);
          } else if (req == kWriteInplace && rhs.IsSame(output)) {
            LOG(FATAL) << "Inplace on an empty rhs is not supported";
          }
          // kWriteInplace with output == lhs needs no work: output already
          // holds lhs and the empty rhs contributes nothing.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about kWriteInplace in all these branches? should we add a check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extra checks and tests added.

        } else if (rhs.storage_initialized()) {
          if (req == kWriteTo) {
            output.CheckAndAlloc({rhs.aux_shape(kIdx)});
          } else if (req == kWriteInplace && lhs.IsSame(output)) {
            LOG(FATAL) << "Inplace on an empty lhs is not supported";
          }
          // For subtraction with an empty lhs the result is -rhs (negation
          // also covers the in-place-into-rhs case); for addition, copy rhs.
          if (std::is_same<OP, mshadow_op::minus>::value) {
            Kernel<op_with_req<mshadow_op::negation, kWriteTo>, gpu>::Launch(
              s, rhs.data().Size(), output.data().dptr<DType>(), rhs.data().dptr<DType>());
          } else if (req == kWriteTo) {
            Copy(output.data().FlatTo1D<gpu, DType>(),
                 rhs.data().FlatTo1D<gpu, DType>(), s);
          }
          if (req == kWriteTo) {
            Copy(output.aux_data(kIdx).FlatTo1D<gpu, IType>(),
                 rhs.aux_data(kIdx).FlatTo1D<gpu, IType>(), s);
          }
        } else {
          // Both operands empty: the output is an all-zero rsp.
          FillZerosRspImpl(s, output);
        }
      }
    });
  });
}

NNVM_REGISTER_OP(elemwise_add)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::plus>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseBinaryOp::ComputeEx<gpu, op::mshadow_op::plus>);
Expand Down