diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index d7c0574f962d..9f26e8f8ce94 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -543,8 +543,10 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, self.epsilon = epsilon def create_state(self, index, weight): - return (zeros(weight.shape, weight.context, dtype=weight.dtype), # mean - zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance + return (zeros(weight.shape, weight.context, dtype=weight.dtype, + stype=weight.stype), # mean + zeros(weight.shape, weight.context, dtype=weight.dtype, + stype=weight.stype)) # variance def update(self, index, weight, grad, state): assert(isinstance(weight, NDArray)) @@ -649,11 +651,11 @@ def __init__(self, learning_rate=0.001, gamma1=0.9, gamma2=0.9, def create_state(self, index, weight): if self.centered: return ( - zeros(weight.shape, weight.context), # n - zeros(weight.shape, weight.context), # g - zeros(weight.shape, weight.context)) # delta + zeros(weight.shape, weight.context, stype=weight.stype), # n + zeros(weight.shape, weight.context, stype=weight.stype), # g + zeros(weight.shape, weight.context, stype=weight.stype)) # delta else: - return (zeros(weight.shape, weight.context),) # n + return (zeros(weight.shape, weight.context, stype=weight.stype),) # n def update(self, index, weight, grad, state): assert(isinstance(weight, NDArray)) diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index f112862f5048..9575d1d51dd8 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -223,7 +223,7 @@ void SetShapeType(const nnvm::Op* op, void SetDependency(std::vector *p_read_vars, std::vector *p_write_vars, std::vector *p_requested, - std::vector *p_auxidx, + std::vector *p_mutate_idx, const nnvm::Op* op, const nnvm::NodeAttrs& attrs, const Context& ctx, @@ -235,7 +235,7 @@ void SetDependency(std::vector *p_read_vars, std::vector& read_vars = *p_read_vars; std::vector& write_vars = *p_write_vars; std::vector& requested = *p_requested; - std::vector& auxidx = *p_auxidx; + std::vector& mutate_idx = *p_mutate_idx; if (tmp_resource.count(op)) { int ntmp = 0; @@ -261,9 +261,9 @@ void SetDependency(std::vector *p_read_vars, write_vars.push_back(i.var()); } if (mutate.count(op)) { - auxidx = mutate[op](attrs); - std::sort(auxidx.begin(), auxidx.end()); - for (auto & i : auxidx) { + mutate_idx = mutate[op](attrs); + std::sort(mutate_idx.begin(), mutate_idx.end()); + for (auto & i : mutate_idx) { write_vars.push_back(ndinputs[i].var()); } } @@ -293,36 +293,48 @@ void PushFCompute(const FCompute& fn, const std::vector& write_vars, const std::vector& requested, const std::vector& ndinputs, - const std::vector& ndoutputs) { + const std::vector& ndoutputs, + const std::vector& mutate_idx) { using namespace common; bool is_train = AutogradRuntime::Get()->IsTraining(); Engine::Get()->PushAsync( - [ctx, attrs, fn, ndinputs, ndoutputs, requested, is_train]( + [ctx, attrs, fn, ndinputs, ndoutputs, requested, is_train, mutate_idx]( RunContext rctx, engine::CallbackOnComplete on_complete) { std::vector input_blobs, output_blobs; - std::vector temp_in_src, temp_in_dst, temp_out_src, temp_out_dst; + // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays + std::vector pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src; + // mapping from index in input_blobs to index in pre_temp_dst + std::unordered_map in_temp_idx_map; + // populate input blobs and output blobs + SetupDefaultBlobs(ndinputs, &input_blobs, &pre_temp_src, &pre_temp_dst, &in_temp_idx_map); + SetupDefaultBlobs(ndoutputs, &output_blobs, &post_temp_dst, &post_temp_src); + // add mutable inputs to post temp list + for (const auto idx : mutate_idx) { + if (in_temp_idx_map.find(idx) != in_temp_idx_map.end()) { + post_temp_src.push_back(pre_temp_dst[in_temp_idx_map[idx]]); + post_temp_dst.push_back(ndinputs[idx]); + } + } OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested}; - GetDefaultBlobs(ndinputs, &input_blobs, &temp_in_src, &temp_in_dst); - GetDefaultBlobs(ndoutputs, &output_blobs, &temp_out_src, &temp_out_dst); std::vector req(output_blobs.size(), kWriteTo); if (ctx.dev_mask() == gpu::kDevMask) { #if MXNET_USE_CUDA - CastNonDefaultStorage(temp_in_src, temp_in_dst, opctx); + CastNonDefaultStorage(pre_temp_src, pre_temp_dst, opctx); fn(attrs, opctx, input_blobs, req, output_blobs); // cast to original storage type, if necessary - CastNonDefaultStorage(temp_out_dst, temp_out_src, opctx); + CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx); rctx.get_stream()->Wait(); #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif } else { - CastNonDefaultStorage(temp_in_src, temp_in_dst, opctx); + CastNonDefaultStorage(pre_temp_src, pre_temp_dst, opctx); fn(attrs, opctx, input_blobs, req, output_blobs); // cast to original storage type, if necessary - CastNonDefaultStorage(temp_out_dst, temp_out_src, opctx); + CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx); } on_complete(); }, ctx, read_vars, write_vars, FnProperty::kNormal, @@ -365,7 +377,8 @@ void PushOperator(const OpStatePtr& state, const std::vector& write_vars, const std::vector& requested, const std::vector& ndinputs, - const std::vector& ndoutputs) { + const std::vector& ndoutputs, + const std::vector& mutate_idx) { using namespace common; static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); @@ -379,28 +392,40 @@ void PushOperator(const OpStatePtr& state, if (fcompute != nullptr) { CHECK(exec_type == ExecType::kSync || exec_type == ExecType::kAsync); Engine::Get()->PushAsync( - [state, fcompute, ndinputs, ndoutputs, requested, is_train, exec_type]( + [state, fcompute, ndinputs, ndoutputs, requested, is_train, exec_type, mutate_idx]( RunContext rctx, engine::CallbackOnComplete on_complete) { OpContext opctx{is_train, rctx, on_complete, requested}; std::vector input_blobs, output_blobs; - std::vector temp_in_src, temp_in_dst, temp_out_src, temp_out_dst; - GetDefaultBlobs(ndinputs, &input_blobs, &temp_in_src, &temp_in_dst); - GetDefaultBlobs(ndoutputs, &output_blobs, &temp_out_src, &temp_out_dst); + // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays + std::vector pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src; + // mapping from index in input_blobs to index in pre_temp_dst + std::unordered_map in_temp_idx_map; + // populate input blobs and output blobs + SetupDefaultBlobs(ndinputs, &input_blobs, &pre_temp_src, &pre_temp_dst, &in_temp_idx_map); + SetupDefaultBlobs(ndoutputs, &output_blobs, &post_temp_dst, &post_temp_src); + // add mutable inputs to post temp list + for (const auto idx : mutate_idx) { + if (in_temp_idx_map.find(idx) != in_temp_idx_map.end()) { + post_temp_src.push_back(pre_temp_dst[in_temp_idx_map[idx]]); + post_temp_dst.push_back(ndinputs[idx]); + } + } std::vector req(output_blobs.size(), kWriteTo); if (rctx.get_ctx().dev_mask() == gpu::kDevMask) { #if MXNET_USE_CUDA - CastNonDefaultStorage(temp_in_src, temp_in_dst, opctx); + CastNonDefaultStorage(pre_temp_src, pre_temp_dst, opctx); fcompute(state, opctx, input_blobs, req, output_blobs); - CastNonDefaultStorage(temp_out_dst, temp_out_src, opctx); + CastNonDefaultStorage(temp_out_dst, post_temp_dst, opctx); +>>>>>>> include mutatable inputs in storage fallback. refactor executor #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif } else { - CastNonDefaultStorage(temp_in_src, temp_in_dst, opctx); + CastNonDefaultStorage(pre_temp_src, pre_temp_dst, opctx); fcompute(state, opctx, input_blobs, req, output_blobs); - CastNonDefaultStorage(temp_out_dst, temp_out_src, opctx); + CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx); } if (exec_type == ExecType::kSync) { if (rctx.get_ctx().dev_mask() == gpu::kDevMask) { @@ -463,8 +488,8 @@ void ImperativeInvokeImpl(const Context& default_ctx, std::vector read_vars, write_vars; std::vector requested; - std::vector auxidx; - SetDependency(&read_vars, &write_vars, &requested, &auxidx, + std::vector mutate_idx; + SetDependency(&read_vars, &write_vars, &requested, &mutate_idx, op, attrs, ctx, ndinputs, ndoutputs); FCompute fn = common::GetFCompute(op, "FCompute", ctx); @@ -482,7 +507,7 @@ void ImperativeInvokeImpl(const Context& default_ctx, attrs, &ndinputs, &ndoutputs); } PushFCompute(fn, op, attrs, ctx, read_vars, write_vars, - requested, ndinputs, ndoutputs); + requested, ndinputs, ndoutputs, mutate_idx); } else if (createop.count(op)) { auto state = createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types); @@ -492,7 +517,7 @@ void ImperativeInvokeImpl(const Context& default_ctx, } write_vars.push_back(state.get_var()); PushOperator(state, op, attrs, ctx, read_vars, write_vars, - requested, ndinputs, ndoutputs); + requested, ndinputs, ndoutputs, mutate_idx); } else { LOG(FATAL) << "Operator " << op->name << " is not implemented for " diff --git a/src/common/utils.h b/src/common/utils.h index 86bc4a730d6b..b41510ccb091 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -31,18 +31,30 @@ template void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output); /* - * \brief get the corresponding tensor blobs from default storage NDArrays. - * If any NDArray is of non-default storage, it will be added to `temp_src` - * \return true if any input storage needs to be casted + * \brief setup default-storage tblobs from source NDArrays. If any source NDArray has non-default + * storage, it creates a temp NDArray with default storage and uses the temp tblob. The + * function also records the indices of non-default source NDArrays and the indices of + * their corresponding temporary NDArrays in the temp array. + * \param src list of source NDArray + * \param blobs list of tblobs to return + * \param temp_src list of source NDArrays which requires temporary default storage representation + * \param temp_dst list of temporary destination NDArrays for default storage representation + * \param idx_map mapping from indices in source NDArrays to indices in temp_dst. When not set, + indices are not recorded + * \return true if any source NDArray need to cast storage */ -inline bool GetDefaultBlobs(const std::vector& src, - std::vector *blobs, - std::vector *temp_src, - std::vector *temp_dst) { +inline bool SetupDefaultBlobs(const std::vector& src, + std::vector *blobs, + std::vector *temp_src, + std::vector *temp_dst, + std::unordered_map *idx_map = nullptr) { bool require_cast = false; for (size_t i = 0; i < src.size(); i++) { auto& nd = src[i]; if (nd.storage_type() != kDefaultStorage) { + if (idx_map != nullptr) { + (*idx_map)[i] = temp_dst->size(); + } NDArray temp(nd.shape(), nd.ctx(), false); temp_src->emplace_back(nd); temp_dst->emplace_back(temp); @@ -56,10 +68,15 @@ inline bool GetDefaultBlobs(const std::vector& src, } /* - * \brief cast the NDArrays in `src` to NDArrays in `dst`. This is only used - * for storage fallback mechanism in executor. + * \brief cast the NDArrays in `src` and store the result in NDArrays in `dst`. + * This is only used for storage fallback in executor. * When storage_fallback is false, and `MXNET_EXEC_STORAGE_FALLBACK` == 0, * storage fallback is disallowed. + * \param src list of source NDArray to cast + * \param dst list of destionation NDArray which hold the result of cast_storage operation + * \param ctx operator context for cast_storage operation + * \param storage_fallback whether storage_fallback is allowed. When set to false, + * its value depends on `MXNET_EXEC_STORAGE_FALLBACK`. */ template inline void CastNonDefaultStorage(const std::vector& src, @@ -89,6 +106,7 @@ inline bool ContainsNonDefaultStorage(const StorageTypeVector& vstorage) { return false; } +// Check if any NDArray in the list has default storage inline bool ContainsDefaultStorage(const std::vector& ndarrays) { for (const auto &nd : ndarrays) { if (nd.storage_type() == kDefaultStorage) { diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index acc9ba41a710..f387b4bc9c65 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -23,42 +23,85 @@ const OperatorProperty* OpPropGetOpProperty(const NodeAttrs& attrs); namespace exec { -// stateful compute executor -class StatefulComputeExecutor : public OpExecutor { +// abstract OpExecutor which provides storage fallback procedure on +// non-default inputs and outputs +class StorageFallbackOpExecutor : public OpExecutor { public: - void Run(RunContext rctx, bool is_gpu) override { + explicit StorageFallbackOpExecutor(const std::vector &mutate_idx) + : mutate_idx_(mutate_idx) {} + + void Setup() override { + using namespace common; + in_data_.clear(); out_data_.clear(); + pre_temp_src_.clear(); pre_temp_dst_.clear(); + post_temp_src_.clear(); post_temp_dst_.clear(); + in_temp_idx_map_.clear(); + SetupDefaultBlobs(in_array, &in_data_, &pre_temp_src_, &pre_temp_dst_, &in_temp_idx_map_); + SetupDefaultBlobs(out_array, &out_data_, &post_temp_dst_, &post_temp_src_); + for (const auto idx : mutate_idx_) { + if (in_temp_idx_map_.find(idx) != in_temp_idx_map_.end()) { + post_temp_src_.push_back(pre_temp_dst_[in_temp_idx_map_[idx]]); + post_temp_dst_.push_back(in_array[idx]); + } + } + } + + protected: + // storage fallback before fcompute is launched + void PreFCompute(bool is_gpu) { + using namespace common; + if (is_gpu) { +#if MXNET_USE_CUDA + CastNonDefaultStorage(pre_temp_src_, pre_temp_dst_, op_ctx); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + CastNonDefaultStorage(pre_temp_src_, pre_temp_dst_, op_ctx); + } + } + + // storage fallback after fcompute is completed + void PostFCompute(bool is_gpu) { using namespace common; - op_ctx.run_ctx = rctx; if (is_gpu) { #if MXNET_USE_CUDA - CastNonDefaultStorage(temp_in_src_, temp_in_dst_, op_ctx); - CastNonDefaultStorage(temp_out_src_, temp_out_dst_, op_ctx); - fcompute_(state_, op_ctx, in_data_, req, out_data_); - CastNonDefaultStorage(temp_out_dst_, temp_out_src_, op_ctx); + CastNonDefaultStorage(post_temp_src_, post_temp_dst_, op_ctx); #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif } else { - CastNonDefaultStorage(temp_in_src_, temp_in_dst_, op_ctx); - CastNonDefaultStorage(temp_out_src_, temp_out_dst_, op_ctx); - fcompute_(state_, op_ctx, in_data_, req, out_data_); - CastNonDefaultStorage(temp_out_dst_, temp_out_src_, op_ctx); + CastNonDefaultStorage(post_temp_src_, post_temp_dst_, op_ctx); } + } + + // default storage tensor blobs for fcompute + std::vector in_data_, out_data_; + // source NDArray for cast storage + std::vector pre_temp_src_, post_temp_src_; + // destination NDArray for cast storage + std::vector pre_temp_dst_, post_temp_dst_; + // mapping from index in input_blobs to index in pre_temp_dst + std::unordered_map in_temp_idx_map_; + // indices of mutatable inputs + std::vector mutate_idx_; +}; + + +// stateful compute executor +class StatefulComputeExecutor : public StorageFallbackOpExecutor { + public: + void Run(RunContext rctx, bool is_gpu) override { + op_ctx.run_ctx = rctx; + PreFCompute(is_gpu); + fcompute_(state_, op_ctx, in_data_, req, out_data_); + PostFCompute(is_gpu); #if MKL_EXPERIMENTAL == 1 mkl_tblobs_prv_to_cpu(in_data_); mkl_tblobs_prv_to_cpu(out_data_); #endif } - void Setup() override { - using namespace common; - in_data_.clear(); out_data_.clear(); - temp_in_src_.clear(); temp_in_dst_.clear(); - temp_out_src_.clear(); temp_out_dst_.clear(); - GetDefaultBlobs(in_array, &in_data_, &temp_in_src_, &temp_in_dst_); - GetDefaultBlobs(out_array, &out_data_, &temp_out_src_, &temp_out_dst_); - } - ExecType exec_type() const override { return exec_type_; } @@ -69,16 +112,16 @@ class StatefulComputeExecutor : public OpExecutor { explicit StatefulComputeExecutor(const OpStatePtr& state, const FStatefulCompute& fcompute, - ExecType exec_type) - : state_(state), fcompute_(fcompute), exec_type_(exec_type) {} + ExecType exec_type, + const std::vector &mutate_idx) + : StorageFallbackOpExecutor(mutate_idx), + state_(state), fcompute_(fcompute), exec_type_(exec_type) {} private: friend Graph AttachOpExecs(Graph g); OpStatePtr state_; FStatefulCompute fcompute_; ExecType exec_type_; - std::vector in_data_, out_data_; - std::vector temp_in_src_, temp_in_dst_, temp_out_src_, temp_out_dst_; }; @@ -114,57 +157,34 @@ class StatefulComputeExExecutor : public OpExecutor { // fcompute executor -class FComputeExecutor : public OpExecutor { +class FComputeExecutor : public StorageFallbackOpExecutor { public: void Run(RunContext rctx, bool is_gpu) override { using namespace common; - // TODO(haibin) avoid repeating this if all inputs are already in default-storage op_ctx.run_ctx = rctx; - if (is_gpu) { -#if MXNET_USE_CUDA - CastNonDefaultStorage(temp_in_src_, temp_in_dst_, op_ctx); - CastNonDefaultStorage(temp_out_src_, temp_out_dst_, op_ctx); - fcompute_(attrs_, op_ctx, in_data_, req, out_data_); - CastNonDefaultStorage(temp_out_dst_, temp_out_src_, op_ctx); -#else - LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; -#endif - } else { - CastNonDefaultStorage(temp_in_src_, temp_in_dst_, op_ctx); - CastNonDefaultStorage(temp_out_src_, temp_out_dst_, op_ctx); - fcompute_(attrs_, op_ctx, in_data_, req, out_data_); - CastNonDefaultStorage(temp_out_dst_, temp_out_src_, op_ctx); - } + PreFCompute(is_gpu); + fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + PostFCompute(is_gpu); #if MKL_EXPERIMENTAL == 1 mkl_tblobs_prv_to_cpu(in_data_); mkl_tblobs_prv_to_cpu(out_data_); #endif } - void Setup() override { - using namespace common; - in_data_.clear(); out_data_.clear(); - temp_in_src_.clear(); temp_in_dst_.clear(); - temp_out_src_.clear(); temp_out_dst_.clear(); - GetDefaultBlobs(in_array, &in_data_, &temp_in_src_, &temp_in_dst_); - GetDefaultBlobs(out_array, &out_data_, &temp_out_src_, &temp_out_dst_); - } - ExecType exec_type() const override { return exec_type_; } explicit FComputeExecutor(const NodeAttrs& attrs, FCompute fcompute, - ExecType exec_type) - : attrs_(attrs), fcompute_(fcompute), exec_type_(exec_type) { + ExecType exec_type, const std::vector &mutate_idx) + : StorageFallbackOpExecutor(mutate_idx), + attrs_(attrs), fcompute_(fcompute), exec_type_(exec_type) { } private: NodeAttrs attrs_; FCompute fcompute_; ExecType exec_type_; - std::vector in_data_, out_data_; - std::vector temp_in_src_, temp_in_dst_, temp_out_src_, temp_out_dst_; }; // fcompute_ex executor @@ -247,7 +267,8 @@ Graph AttachOpExecs(Graph g) { FStatefulCompute fcompute = common::GetFCompute( op, "FStatefulCompute", vctx[i]); if (fcompute != nullptr) { - ret[i] = std::make_shared(state, fcompute, exec_type); + ret[i] = std::make_shared(state, fcompute, + exec_type, mutate_index); } else { FStatefulComputeEx fcompute_ex = common::GetFCompute( op, "FStatefulComputeEx", vctx[i]); @@ -266,7 +287,7 @@ Graph AttachOpExecs(Graph g) { if (fcompute != nullptr) { ret[i] = std::make_shared( dynamic_cast(ret[fwd_id].get())->state_, - fcompute, exec_type); + fcompute, exec_type, mutate_index); } else { FStatefulComputeEx fcompute_ex = common::GetFCompute( op, "FStatefulComputeEx", vctx[i]); @@ -285,7 +306,7 @@ Graph AttachOpExecs(Graph g) { inode.source->attrs, fcomp_ex, exec_type); } else if (fcompute != nullptr) { ret[i] = std::make_shared( - inode.source->attrs, fcompute, exec_type); + inode.source->attrs, fcompute, exec_type, mutate_index); } else { LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name; } diff --git a/src/nnvm/legacy_op_util.cc b/src/nnvm/legacy_op_util.cc index 2bba5f1c3655..6e601780080b 100644 --- a/src/nnvm/legacy_op_util.cc +++ b/src/nnvm/legacy_op_util.cc @@ -60,19 +60,20 @@ class OperatorState { opr_ = opr; fwd_init_ = bwd_init_ = false; - in_data_.resize(prop->ListArguments().size()); + in_data_fwd_.resize(prop->ListArguments().size()); + in_data_bwd_.resize(prop->ListArguments().size()); out_data_.resize(prop->NumOutputs()); aux_data_.resize(prop->ListAuxiliaryStates().size()); - in_grad_.resize(in_data_.size()); + in_grad_.resize(in_data_fwd_.size()); out_grad_.resize(prop->NumVisibleOutputs()); std::vector out_grad_ptr(out_grad_.size()); for (size_t i = 0; i < out_grad_.size(); ++i) { out_grad_ptr[i] = &out_grad_[i]; } - std::vector in_data_ptr(in_data_.size()); - for (size_t i = 0; i < in_data_.size(); ++i) { - in_data_ptr[i] = &in_data_[i]; + std::vector in_data_ptr(in_data_fwd_.size()); + for (size_t i = 0; i < in_data_fwd_.size(); ++i) { + in_data_ptr[i] = &in_data_bwd_[i]; } std::vector out_data_ptr(out_data_.size()); for (size_t i = 0; i < out_data_.size(); ++i) { @@ -89,16 +90,19 @@ class OperatorState { const std::vector& req, const std::vector& outputs) { if (!fwd_init_) { - CHECK_EQ(inputs.size(), in_data_.size() + aux_data_.size()); + CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size()); CHECK_EQ(outputs.size(), out_data_.size()); - for (size_t i = 0; i < in_data_.size(); ++i) in_data_[i] = inputs[i]; + // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones + // referred by arg_data_ptr_ will be overriden + for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i]; + for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i]; for (size_t i = 0; i < aux_data_.size(); ++i) { - aux_data_[i] = inputs[i + in_data_.size()]; + aux_data_[i] = inputs[i + in_data_fwd_.size()]; } for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i]; fwd_init_ = true; } - opr_->Forward(ctx, in_data_, req, out_data_, aux_data_); + opr_->Forward(ctx, in_data_fwd_, req, out_data_, aux_data_); } void Backward(const OpContext &ctx, @@ -108,6 +112,8 @@ class OperatorState { if (!bwd_init_) { CHECK(fwd_init_); CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size()); + // override tblobs pointed by arg_data_ptr_ since they might not contain + // initialized data during forward pass. for (size_t i = 0; i < arg_data_ptr_.size(); ++i) { *arg_data_ptr_[i] = inputs[i]; } @@ -118,13 +124,19 @@ class OperatorState { for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i]; bwd_init_ = true; } - opr_->Backward(ctx, out_grad_, in_data_, out_data_, req, in_grad_, aux_data_); + opr_->Backward(ctx, out_grad_, in_data_bwd_, out_data_, req, in_grad_, aux_data_); } private: Operator *opr_; bool fwd_init_, bwd_init_; - std::vector in_data_, aux_data_, out_data_, in_grad_, out_grad_; + // input data blobs for forward and backward + // in_data_fwd_ and in_data_bwd_ will hold different tblobs when StorageFallbackOpExecutor + // performs storage fallback on a non-default input NDArray. The one in in_data_fwd_ is + // generated when setting up forward executor, while the one in in_data_bwd_ is generated + // when setting up backward executor. + std::vector in_data_fwd_, in_data_bwd_; + std::vector aux_data_, out_data_, in_grad_, out_grad_; std::vector arg_data_ptr_; }; diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index e31046afb0c1..270aae592f14 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -360,8 +360,8 @@ void FCompExFallback(const nnvm::NodeAttrs& attrs, using namespace mxnet::common; std::vector in_blobs, out_blobs; std::vector temp_in_src, temp_in_dst, temp_out_src, temp_out_dst; - GetDefaultBlobs(inputs, &in_blobs, &temp_in_src, &temp_in_dst); - GetDefaultBlobs(outputs, &out_blobs, &temp_out_src, &temp_out_dst); + SetupDefaultBlobs(inputs, &in_blobs, &temp_in_src, &temp_in_dst); + SetupDefaultBlobs(outputs, &out_blobs, &temp_out_src, &temp_out_dst); CastNonDefaultStorage(temp_in_src, temp_in_dst, ctx, true); fcompute(attrs, ctx, in_blobs, req, out_blobs); CastNonDefaultStorage(temp_out_dst, temp_out_src, ctx, true); diff --git a/tests/cpp/operator/ndarray_test.cc b/tests/cpp/operator/ndarray_test.cc deleted file mode 100644 index f2ed30793881..000000000000 --- a/tests/cpp/operator/ndarray_test.cc +++ /dev/null @@ -1,6 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file ndarray_test.cc - * \brief ndarray unit test utility functions - * \author Haibin Lin -*/ diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 8399194efbba..b6c902b13370 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -358,6 +358,7 @@ def test_adam(): {'rescale_grad': 0.8, 'wd': 0.05}] for kwarg in kwargs: compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32) + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32, g_stype='row_sparse') # RMSProp class PyRMSProp(mx.optimizer.Optimizer): @@ -498,6 +499,7 @@ def test_rms(): {'rescale_grad': 0.8, 'wd': 0.05, 'centered': True, 'clip_weights': 0.01}] for kwarg in kwargs: compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32) + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32, g_stype='row_sparse') if __name__ == '__main__': test_adam()