[MXNET-404] elemwise_add/sub between rsp and rsp on GPU #11179
@@ -22,12 +22,146 @@

```cpp
 * \file elemwise_binary_scalar_op.cu
 * \brief GPU Implementation of unary function.
 */
#include <cub/cub.cuh>
#include "./elemwise_binary_op.h"
#include "./elemwise_binary_op-inl.h"

namespace mxnet {
namespace op {

template<typename OP>
struct RspElemwiseKernel {
  template<typename DType, typename IType>
  static MSHADOW_XINLINE void Map(int i, DType* out, const IType* lookup_table,
                                  const DType* data, const IType* indices,
                                  const nnvm::dim_t nz_rows, const nnvm::dim_t num_cols) {
    if (i < nz_rows * num_cols) {
      const nnvm::dim_t row = i / num_cols;
      const nnvm::dim_t col = i % num_cols;
      // lookup_table holds 1-based compacted output row positions, hence the -1
      const nnvm::dim_t out_row = lookup_table[indices[row]] - 1;
      const nnvm::dim_t out_idx = out_row * num_cols + col;
      out[out_idx] = OP::Map(out[out_idx], data[i]);
    }
  }
};
```
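The scatter above assumes `lookup_table` is the inclusive prefix sum of row-occupancy flags, so `lookup_table[indices[row]] - 1` is the compacted output row for any occupied source row. A tiny host-side sketch of just that index math (the table and indices are made-up example values, not from the PR):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Inclusive scan of occupancy flags [1, 0, 0, 1, 0, 1] over 6 rows:
  const int64_t lookup_table[] = {1, 1, 1, 2, 2, 3};
  const int64_t indices[] = {0, 3};  // non-zero rows of one operand
  const int64_t num_cols = 4;
  // Element i of the operand's data scatters to this flat output index:
  for (int64_t i = 0; i < 2 * num_cols; ++i) {
    const int64_t row = i / num_cols;
    const int64_t col = i % num_cols;
    const int64_t out_row = lookup_table[indices[row]] - 1;  // rows 0 and 1 here
    assert(out_row * num_cols + col < 3 * num_cols);         // fits 3 output rows
  }
  return 0;
}
```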
> **Review comment** (on `RspRspOp`): Do we have a unit test for write in-place?
>
> **Author:** The in-place case shares the same code path as the in-place case between dns and rsp, which already has a unit test.
>
> **Author:** BTW, correctness is double-checked in the benchmark script during warmup.

```cpp
template<typename OP>
void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<gpu> *s,
                                const nnvm::NodeAttrs &attrs,
                                const OpContext &ctx,
                                const NDArray &lhs,
                                const NDArray &rhs,
                                const OpReqType req,
                                const NDArray &output,
                                const bool lhs_may_be_dense,
                                const bool rhs_may_be_dense,
                                const bool allow_inplace,
                                const bool scatter) {
  using namespace mshadow;
  using namespace mxnet_op;
  using namespace mshadow::expr;
  using namespace rowsparse;

  if (req == kNullOp) return;

  CHECK(!scatter) << "scatter is not supported in RspRspOp on GPU yet...";
  CHECK(lhs.storage_type() == kRowSparseStorage && rhs.storage_type() == kRowSparseStorage);
  CHECK(output.storage_type() == kRowSparseStorage);
```

> **Review comment:** Does it support kAddTo? CHECK_NE(kAddTo)?
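For context, `kAddTo` asks an operator to accumulate into the existing output rather than overwrite it, which this sparse scatter cannot honor without first merging the output's existing rows. A minimal, hypothetical illustration of the request-type semantics (the names mirror MXNet's `OpReqType`, but this is a standalone sketch, not MXNet code):

```cpp
#include <cstdio>

enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Simplified sketch of how an output request is honored; in MXNet this is
// the KERNEL_ASSIGN / op_with_req machinery, reduced here to a function.
void assign(float* out, OpReqType req, float val) {
  if (req == kNullOp) return;        // do nothing
  if (req == kAddTo)  *out += val;   // accumulate: needs existing contents
  else                *out = val;    // kWriteTo / kWriteInplace: overwrite
}

int main() {
  float out = 5.f;
  assign(&out, kAddTo, 2.f);  // out == 7: accumulation depends on old values,
  printf("%g\n", out);        // which the freshly compacted rsp output lacks
  return 0;
}
```

The diff answers the question by rejecting `kAddTo` outright, as the check just below shows.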
```cpp
  CHECK(req != kAddTo);

  const nnvm::dim_t num_rows = output.shape()[0];
  MSHADOW_TYPE_SWITCH(lhs.data().type_flag_, DType, {
    MSHADOW_IDX_TYPE_SWITCH(lhs.aux_data(kIdx).type_flag_, IType, {
      if (lhs.storage_initialized() && rhs.storage_initialized()) {
        const nnvm::dim_t lhs_nz_rows = lhs.storage_shape()[0];
        const nnvm::dim_t rhs_nz_rows = rhs.storage_shape()[0];
        const nnvm::dim_t num_cols = lhs.data().Size() / lhs_nz_rows;
        // Optimize for the case where one of the rsps is actually dense
        if ((lhs_nz_rows == num_rows || rhs_nz_rows == num_rows) && req == kWriteInplace) {
          const NDArray& dns = (output.IsSame(lhs)) ? lhs : rhs;
          const NDArray& rsp = (output.IsSame(lhs)) ? rhs : lhs;
          const bool reverse = !(lhs_nz_rows == num_rows);
          ElemwiseBinaryOp::DnsRspDnsOp<gpu, OP>(s, attrs, ctx, dns, rsp, req, output, reverse);
          return;
        }
        CHECK(req == kWriteTo) << "Should be kWriteTo but got " << req;
```

> **Review comment:** If this function assumes `req` is never `kNullOp`, better to document it in the header.
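A note on the pattern used in the next hunk: `cub::DeviceScan::InclusiveSum` follows CUB's two-pass convention, where a first call with a null workspace pointer only reports the required temp-storage size and a second, identical call actually runs the scan. A minimal standalone sketch (the flag values are made up for illustration; error handling omitted):

```cpp
#include <cub/cub.cuh>
#include <cstdio>

int main() {
  const int n = 6;
  const int h_in[n] = {1, 0, 0, 1, 0, 1};  // row-occupancy flags
  int *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(int));
  cudaMalloc(&d_out, n * sizeof(int));
  cudaMemcpy(d_in, h_in, n * sizeof(int), cudaMemcpyHostToDevice);

  // Pass 1: null temp pointer -> CUB only writes the workspace size.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, d_in, d_out, n);

  // Pass 2: same call with allocated workspace performs the scan.
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, d_in, d_out, n);

  int last = 0;  // last element of an inclusive sum == total count
  cudaMemcpy(&last, d_out + n - 1, sizeof(int), cudaMemcpyDeviceToHost);
  printf("non-zero rows in output: %d\n", last);  // prints 3

  cudaFree(d_in); cudaFree(d_out); cudaFree(d_temp);
  return 0;
}
```

The diff folds the scanned table and CUB's temp storage into a single requested workspace rather than issuing separate allocations.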
```cpp
        const TBlob& lhs_indices = lhs.aux_data(kIdx);
        const TBlob& rhs_indices = rhs.aux_data(kIdx);
        size_t common_row_table_bytes = num_rows * sizeof(IType);
        IType* common_row_table = NULL;
        void* temp_storage_ptr = NULL;
        size_t temp_storage_bytes = 0;
        // First call with a null workspace pointer only computes temp_storage_bytes
        cub::DeviceScan::InclusiveSum(temp_storage_ptr,
                                      temp_storage_bytes,
                                      common_row_table,
                                      common_row_table,
                                      num_rows,
                                      mshadow::Stream<gpu>::GetStream(s));
        size_t workspace_bytes = common_row_table_bytes + temp_storage_bytes;
        Tensor<gpu, 1, char> workspace =
          ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_bytes), s);
        common_row_table = reinterpret_cast<IType*>(workspace.dptr_);
        temp_storage_ptr = workspace.dptr_ + common_row_table_bytes;
        // Flag every row occupied by lhs and/or rhs, then inclusive-scan the flags
        // so common_row_table[r] holds the 1-based output row of each occupied row r
        mxnet_op::Kernel<set_zero, gpu>::Launch(s, num_rows, common_row_table);
        Kernel<MarkRspRowFlgKernel, gpu>::Launch(
          s, lhs_nz_rows, common_row_table, lhs_indices.dptr<IType>(), lhs_nz_rows);
        Kernel<MarkRspRowFlgKernel, gpu>::Launch(
          s, rhs_nz_rows, common_row_table, rhs_indices.dptr<IType>(), rhs_nz_rows);
        cub::DeviceScan::InclusiveSum(temp_storage_ptr,
                                      temp_storage_bytes,
                                      common_row_table,
                                      common_row_table,
                                      num_rows,
                                      mshadow::Stream<gpu>::GetStream(s));
        // The last scan element is the output's non-zero row count. Copy sizeof(IType)
        // bytes so the read matches the element type of common_row_table
        // (sizeof(nnvm::dim_t) would over-read when IType is narrower than dim_t).
        IType nnr_out_val = 0;
        CUDA_CALL(cudaMemcpy(&nnr_out_val, &common_row_table[num_rows-1], sizeof(IType),
                             cudaMemcpyDeviceToHost));
        const nnvm::dim_t nnr_out = static_cast<nnvm::dim_t>(nnr_out_val);
        output.CheckAndAlloc({mshadow::Shape1(nnr_out)});
        Kernel<FillRspRowIdxKernel, gpu>::Launch(
          s, num_rows, output.aux_data(kIdx).dptr<IType>(), common_row_table, num_rows);
        Kernel<set_zero, gpu>::Launch(s, nnr_out * num_cols, output.data().dptr<DType>());
        // Scatter lhs with plus, then rhs with OP (plus for add, minus for sub)
        Kernel<RspElemwiseKernel<mshadow_op::plus>, gpu>::Launch(
          s, lhs_nz_rows * num_cols, output.data().dptr<DType>(), common_row_table,
          lhs.data().dptr<DType>(), lhs_indices.dptr<IType>(), lhs_nz_rows, num_cols);
        Kernel<RspElemwiseKernel<OP>, gpu>::Launch(
          s, rhs_nz_rows * num_cols, output.data().dptr<DType>(), common_row_table,
          rhs.data().dptr<DType>(), rhs_indices.dptr<IType>(), rhs_nz_rows, num_cols);
      } else {
        if (lhs.storage_initialized()) {
          if (req == kWriteTo) {
            output.CheckAndAlloc({lhs.aux_shape(kIdx)});
            Copy(output.data().FlatTo1D<gpu, DType>(),
                 lhs.data().FlatTo1D<gpu, DType>(), s);
            Copy(output.aux_data(kIdx).FlatTo1D<gpu, IType>(),
                 lhs.aux_data(kIdx).FlatTo1D<gpu, IType>(), s);
          } else if (req == kWriteInplace && rhs.IsSame(output)) {
            LOG(FATAL) << "Inplace on an empty rhs is not supported";
          }
```

> **Review comment:** What about kWriteInplace in all these branches? Should we add a check?
>
> **Author:** Extra checks and tests added.
```cpp
        } else if (rhs.storage_initialized()) {
          if (req == kWriteTo) {
            output.CheckAndAlloc({rhs.aux_shape(kIdx)});
          } else if (req == kWriteInplace && lhs.IsSame(output)) {
            LOG(FATAL) << "Inplace on an empty lhs is not supported";
          }
          if (std::is_same<OP, mshadow_op::minus>::value) {
            // Only rhs has data: output = -rhs for subtraction
            Kernel<op_with_req<mshadow_op::negation, kWriteTo>, gpu>::Launch(
              s, rhs.data().Size(), output.data().dptr<DType>(), rhs.data().dptr<DType>());
          } else if (req == kWriteTo) {
            Copy(output.data().FlatTo1D<gpu, DType>(),
                 rhs.data().FlatTo1D<gpu, DType>(), s);
          }
          if (req == kWriteTo) {
            Copy(output.aux_data(kIdx).FlatTo1D<gpu, IType>(),
                 rhs.aux_data(kIdx).FlatTo1D<gpu, IType>(), s);
          }
        } else {
          // Both inputs are empty rsp: the output is all zeros
          FillZerosRspImpl(s, output);
        }
      }
    });
  });
}
```
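Putting the pieces together, here is a hedged, CPU-only sketch of the whole row-merge pipeline on made-up inputs. It assumes `MarkRspRowFlgKernel` sets a flag per occupied row and `FillRspRowIdxKernel` emits the compacted row indices, which matches how they are used above:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t num_rows = 6, num_cols = 2;
  // lhs occupies rows {0, 3}, rhs occupies rows {3, 5}
  std::vector<int64_t> lhs_idx = {0, 3}, rhs_idx = {3, 5};
  std::vector<float> lhs_data = {1, 1, 2, 2};       // 2 rows x 2 cols
  std::vector<float> rhs_data = {10, 10, 20, 20};

  // 1) Mark occupied rows (MarkRspRowFlgKernel equivalent)
  std::vector<int64_t> table(num_rows, 0);
  for (int64_t r : lhs_idx) table[r] = 1;
  for (int64_t r : rhs_idx) table[r] = 1;
  // 2) Inclusive scan: table[r] becomes the 1-based compacted position
  for (int64_t r = 1; r < num_rows; ++r) table[r] += table[r - 1];
  const int64_t nnr_out = table[num_rows - 1];      // 3 output rows
  // 3) Emit output row indices (FillRspRowIdxKernel equivalent)
  std::vector<int64_t> out_idx(nnr_out);
  for (int64_t r = 0; r < num_rows; ++r)
    if (table[r] != (r ? table[r - 1] : 0)) out_idx[table[r] - 1] = r;
  // 4) Scatter both operands (RspElemwiseKernel, OP = plus for elemwise_add)
  std::vector<float> out(nnr_out * num_cols, 0.f);
  auto scatter = [&](const std::vector<int64_t>& idx,
                     const std::vector<float>& data, float sign) {
    for (size_t i = 0; i < data.size(); ++i) {
      const int64_t row = i / num_cols, col = i % num_cols;
      const int64_t out_row = table[idx[row]] - 1;
      out[out_row * num_cols + col] += sign * data[i];
    }
  };
  scatter(lhs_idx, lhs_data, 1.f);
  scatter(rhs_idx, rhs_data, 1.f);  // use -1.f for elemwise_sub
  for (int64_t r = 0; r < nnr_out; ++r)
    printf("row %lld: %g %g\n", (long long)out_idx[r],
           out[r * num_cols], out[r * num_cols + 1]);
  // Expected: row 0: 1 1 / row 3: 12 12 / row 5: 20 20
  return 0;
}
```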

```cpp
NNVM_REGISTER_OP(elemwise_add)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::plus>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseBinaryOp::ComputeEx<gpu, op::mshadow_op::plus>);
```
> **Author** (to @piiswrong): I made the change here so that I can also call this function when I have a `const NDArray` object.
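The `const NDArray` remark comes down to plain C++ reference binding: a const object cannot be passed to a parameter declared as a non-const reference. A minimal illustration with a stand-in type (not MXNet's `NDArray`):

```cpp
struct NDArray { int id; };  // stand-in for illustration only

void takes_mutable(NDArray& arr) { arr.id = 1; }
void takes_const(const NDArray& arr) { (void)arr; }

int main() {
  const NDArray output{0};
  // takes_mutable(output);  // error: cannot bind a const object to NDArray&
  takes_const(output);       // fine: why the parameter is const NDArray&
  return 0;
}
```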