diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 7d885ad47386..735da31b8b41 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -29,7 +29,7 @@ namespace mxnet {
 namespace op {
 
 template<>
-void SparseEmbeddingOpForwardRspImpl<cpu>(mshadow::Stream<cpu>* s,
+void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,
                                           const TBlob& data,
                                           const NDArray& weight,
                                           const OpReqType req,
@@ -37,6 +37,7 @@ void SparseEmbeddingOpForwardRspImpl<cpu>(mshadow::Stream<cpu>* s,
   if (req == kNullOp) return;
   using namespace rowsparse;
   using namespace mxnet_op;
+  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
   // zeros weight
   if (req == kWriteTo && !weight.storage_initialized()) {
     size_t out_size = output.shape_.Size();
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index f029f0209957..4021f2b3a217 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -61,7 +61,7 @@ struct AddTakeGradRspGPUKernel {
 };
 
 template<>
-void SparseEmbeddingOpForwardRspImpl<gpu>(mshadow::Stream<gpu>* s,
+void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
                                           const TBlob& data,
                                           const NDArray& weight,
                                           const OpReqType req,
@@ -69,6 +69,7 @@ void SparseEmbeddingOpForwardRspImpl<gpu>(mshadow::Stream<gpu>* s,
   if (req == kNullOp) return;
   using namespace rowsparse;
   using namespace mxnet_op;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
   // zeros weight
   if (req == kWriteTo && !weight.storage_initialized()) {
     size_t out_size = output.shape_.Size();
@@ -85,8 +86,9 @@ void SparseEmbeddingOpForwardRspImpl<gpu>(mshadow::Stream<gpu>* s,
     DType max = static_cast<DType>(weight.shape()[0] - 1);
     DType* data_ptr = data.dptr<DType>();
     size_t data_size = data.shape_.Size();
-    int32_t* is_valid_ptr = NULL;
-    CUDA_CALL(cudaMalloc(&is_valid_ptr, sizeof(int32_t)));
+    Tensor<gpu, 1, char> workspace = ctx.requested[0]
+        .get_space_typed<gpu, 1, char>(Shape1(sizeof(int32_t)), s);
+    int32_t* is_valid_ptr = reinterpret_cast<int32_t*>(workspace.dptr_);
     Kernel<set_zero, gpu>::Launch(s, 1, is_valid_ptr);
     Kernel<is_valid_check, gpu>::Launch(s, data_size, is_valid_ptr, data_ptr, min, max);
     CUDA_CALL(cudaMemcpy(&is_valid, is_valid_ptr, sizeof(int32_t),
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index b0f06de9ae48..4043e76cfdae 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -364,7 +364,7 @@ inline void EmbeddingOpForwardRspImpl(mshadow::Stream<cpu>* s,
 
 // Embedding forward implementation with row_sparse weight
 template<typename xpu>
-void SparseEmbeddingOpForwardRspImpl(mshadow::Stream<xpu>* s,
+void SparseEmbeddingOpForwardRspImpl(const OpContext& ctx,
                                      const TBlob& data,
                                      const NDArray& weight,
                                      const OpReqType req,
@@ -406,10 +406,9 @@ void SparseEmbeddingOpForwardEx(const nnvm::NodeAttrs& attrs,
   const auto data_stype = data.storage_type();
   const auto weight_stype = weight.storage_type();
   const auto out_stype = out.storage_type();
-  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   if (data_stype == kDefaultStorage && weight_stype == kRowSparseStorage &&
       out_stype == kDefaultStorage) {
-    SparseEmbeddingOpForwardRspImpl<xpu>(s, data.data(), weight, req[0], out.data());
+    SparseEmbeddingOpForwardRspImpl<xpu>(ctx, data.data(), weight, req[0], out.data());
   } else {
     LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
   }