fix small memory leak of sparse embedding (apache#9025)
* disable empty output of ndarray.slice & fix small mem leak of sparse embedding

* revert

* replace cudamalloc with resource request
ZiyueHuang authored and piiswrong committed Dec 12, 2017
1 parent dc75bcb commit cc34358
Showing 3 changed files with 9 additions and 7 deletions.
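
For context: the leak was the classic unmatched-cudaMalloc pattern. The GPU forward path allocated a 4-byte device flag for the bounds check with cudaMalloc and never called cudaFree, so each forward call leaked one small device allocation. A minimal before/after sketch of the fix, condensed from the diff below (ctx is the operator's OpContext and s its GPU stream):

    // Before: raw device allocation with no matching cudaFree;
    // leaks sizeof(int32_t) bytes of device memory on every call.
    int32_t* is_valid_ptr = NULL;
    CUDA_CALL(cudaMalloc(&is_valid_ptr, sizeof(int32_t)));

    // After: borrow scratch space from MXNet's resource manager.
    // The workspace is owned by the engine and reclaimed after the
    // operator runs, so nothing needs to be freed explicitly.
    Tensor<gpu, 1, char> workspace = ctx.requested[0]
        .get_space_typed<gpu, 1, char>(Shape1(sizeof(int32_t)), s);
    int32_t* is_valid_ptr = reinterpret_cast<int32_t*>(workspace.dptr_);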
src/operator/tensor/indexing_op.cc (3 changes: 2 additions & 1 deletion)
@@ -29,14 +29,15 @@ namespace mxnet {
 namespace op {

 template<>
-void SparseEmbeddingOpForwardRspImpl<cpu>(mshadow::Stream<cpu>* s,
+void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,
                                           const TBlob& data,
                                           const NDArray& weight,
                                           const OpReqType req,
                                           const TBlob& output) {
   if (req == kNullOp) return;
   using namespace rowsparse;
   using namespace mxnet_op;
+  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
   // zeros weight
   if (req == kWriteTo && !weight.storage_initialized()) {
     size_t out_size = output.shape_.Size();
src/operator/tensor/indexing_op.cu (8 changes: 5 additions & 3 deletions)
@@ -61,14 +61,15 @@ struct AddTakeGradRspGPUKernel {
 };

 template<>
-void SparseEmbeddingOpForwardRspImpl<gpu>(mshadow::Stream<gpu>* s,
+void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
                                           const TBlob& data,
                                           const NDArray& weight,
                                           const OpReqType req,
                                           const TBlob& output) {
   if (req == kNullOp) return;
   using namespace rowsparse;
   using namespace mxnet_op;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
   // zeros weight
   if (req == kWriteTo && !weight.storage_initialized()) {
     size_t out_size = output.shape_.Size();
@@ -85,8 +86,9 @@ void SparseEmbeddingOpForwardRspImpl<gpu>(mshadow::Stream<gpu>* s,
     DType max = static_cast<DType>(weight.shape()[0] - 1);
     DType* data_ptr = data.dptr<DType>();
     size_t data_size = data.shape_.Size();
-    int32_t* is_valid_ptr = NULL;
-    CUDA_CALL(cudaMalloc(&is_valid_ptr, sizeof(int32_t)));
+    Tensor<gpu, 1, char> workspace = ctx.requested[0]
+      .get_space_typed<gpu, 1, char>(Shape1(sizeof(int32_t)), s);
+    int32_t* is_valid_ptr = reinterpret_cast<int32_t*>(workspace.dptr_);
     Kernel<set_zero, gpu>::Launch(s, 1, is_valid_ptr);
     Kernel<is_valid_check, gpu>::Launch(s, data_size, is_valid_ptr, data_ptr, min, max);
     CUDA_CALL(cudaMemcpy(&is_valid, is_valid_ptr, sizeof(int32_t),
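
A note on a precondition that is not visible in this diff: pulling workspace from ctx.requested[0] assumes the operator declares a temp-space resource when it is registered. In MXNet's NNVM registration that declaration looks roughly like the sketch below (an assumption about code outside this change, not part of the commit):

    // Sketch (assumption, not in this diff): an operator that takes
    // scratch memory from ctx.requested[0] must request it up front.
    .set_attr<FResourceRequest>("FResourceRequest",
      [](const NodeAttrs& attrs) {
        return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
      })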
src/operator/tensor/indexing_op.h (5 changes: 2 additions & 3 deletions)
@@ -364,7 +364,7 @@ inline void EmbeddingOpForwardRspImpl(mshadow::Stream<xpu>* s,

 // Embedding forward implementation with row_sparse weight
 template<typename xpu>
-void SparseEmbeddingOpForwardRspImpl(mshadow::Stream<xpu>* s,
+void SparseEmbeddingOpForwardRspImpl(const OpContext& ctx,
                                      const TBlob& data,
                                      const NDArray& weight,
                                      const OpReqType req,
@@ -406,10 +406,9 @@ void SparseEmbeddingOpForwardEx(const nnvm::NodeAttrs& attrs,
   const auto data_stype = data.storage_type();
   const auto weight_stype = weight.storage_type();
   const auto out_stype = out.storage_type();
-  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   if (data_stype == kDefaultStorage && weight_stype == kRowSparseStorage &&
       out_stype == kDefaultStorage) {
-    SparseEmbeddingOpForwardRspImpl<xpu>(s, data.data(), weight, req[0], out.data());
+    SparseEmbeddingOpForwardRspImpl<xpu>(ctx, data.data(), weight, req[0], out.data());
   } else {
     LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
   }
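
The header change is what enables the fix: the implementation now receives the whole OpContext instead of a bare stream, so it can fetch its stream and request workspace from a single argument. As a reading aid, an abridged sketch of the relevant OpContext members (based on MXNet's headers; field list not exhaustive):

    // Abridged sketch of mxnet::OpContext; only the members used by
    // this change are shown.
    struct OpContext {
      std::vector<Resource> requested;   // resources granted to this op
      template<typename xpu>
      mshadow::Stream<xpu>* get_stream() const;  // execution stream
      // ... other members omitted
    };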
