From 1637011c2a9e168feefafe60573bb2181c3a5d19 Mon Sep 17 00:00:00 2001
From: Eric Junyuan Xie
Date: Mon, 2 Oct 2017 11:48:51 -0700
Subject: [PATCH] Autograd with multiple devices (#8124)

* add cross device autograd

* place device
---
 include/mxnet/imperative.h              |  78 ++++++-------
 include/mxnet/ndarray.h                 |  13 +++
 src/imperative/cached_op.cc             |  10 +-
 src/imperative/imperative.cc            |  53 +++++----
 src/imperative/imperative_utils.h       |  90 ++++++++++++---
 src/kvstore/comm.h                      |   2 +-
 src/ndarray/ndarray.cc                  | 146 ++++++++++++++----------
 src/ndarray/ndarray_function.cu         |   2 +-
 src/operator/operator_common.h          |   5 +
 src/operator/tensor/init_op.h           |  20 ++--
 src/operator/tensor/sparse_retain-inl.h |   4 +-
 src/operator/tensor/square_sum-inl.h    |   4 +-
 tests/python/gpu/test_operator_gpu.py   |  26 +++++
 13 files changed, 298 insertions(+), 155 deletions(-)

diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index ead02914..d26e86f4 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -36,6 +36,45 @@ namespace mxnet {
 /*! \brief runtime functions for NDArray */
 class Imperative {
  public:
+  /*! \brief */
+  class AGInfo {
+   public:
+    Context ctx;
+    OpReqType grad_req;
+    OpStatePtr state;
+    std::vector<NDArray> outputs;
+    std::vector<NDArray> out_grads;
+    bool fresh_out_grad;
+
+    AGInfo() :
+      grad_req(kNullOp), fresh_out_grad(false) {}
+
+    static void Clear(const nnvm::NodePtr& node) {
+      if (node == nullptr || node->info.empty()) return;
+      AGInfo& info = Get(node);
+      if (info.grad_req != kNullOp) return;
+      node->info.clear();
+    }
+
+    static AGInfo& Get(const nnvm::NodePtr& node) {
+      return dmlc::get<AGInfo>(node->info);
+    }
+
+    static AGInfo& Create(const nnvm::NodePtr& node) {
+      node->info.construct<AGInfo>();
+      return Get(node);
+    }
+
+    static bool IsNone(const NDArray& arr) {
+      return arr.entry_.node == nullptr || arr.entry_.node->info.empty();
+    }
+
+    static bool IsVariable(const nnvm::NodePtr& node) {
+      AGInfo& info = Get(node);
+      return info.grad_req != kNullOp && info.outputs.size() == 1
+             && info.out_grads.size() == 1;
+    }
+  };
   class CachedOp {
    public:
     explicit CachedOp(const nnvm::Symbol& sym);
@@ -141,44 +180,6 @@ class Imperative {
 
  private:
   friend class NDArray;
-  /*! \brief */
-  class AGInfo {
-   public:
-    OpReqType grad_req;
-    OpStatePtr state;
-    std::vector<NDArray> outputs;
-    std::vector<NDArray> out_grads;
-    bool fresh_out_grad;
-
-    AGInfo() :
-      grad_req(kNullOp), fresh_out_grad(false) {}
-
-    static void Clear(const nnvm::NodePtr& node) {
-      if (node == nullptr || node->info.empty()) return;
-      AGInfo& info = Get(node);
-      if (info.grad_req != kNullOp) return;
-      node->info.clear();
-    }
-
-    static AGInfo& Get(const nnvm::NodePtr& node) {
-      return dmlc::get<AGInfo>(node->info);
-    }
-
-    static AGInfo& Create(const nnvm::NodePtr& node) {
-      node->info.construct<AGInfo>();
-      return Get(node);
-    }
-
-    static bool IsNone(const NDArray& arr) {
-      return arr.entry_.node == nullptr || arr.entry_.node->info.empty();
-    }
-
-    static bool IsVariable(const nnvm::NodePtr& node) {
-      AGInfo& info = Get(node);
-      return info.grad_req != kNullOp && info.outputs.size() == 1
-             && info.out_grads.size() == 1;
-    }
-  };
   /*! \brief make constructor protected. */
   Imperative() {}
   /*!
\brief find the input/output ndarrays that are needed for backward */ @@ -189,7 +190,6 @@ class Imperative { std::vector *p_save_outputs); void RunGraph( const bool retain_graph, - const Context& default_ctx, const nnvm::IndexedGraph& idx, const std::vector arrays, size_t node_start, size_t node_end, diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 0af48b5e..84ee9fa5 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -892,6 +892,19 @@ size_t num_aux_data(NDArrayStorageType stype); */ void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0); +/*! + * \brief issue an copy operation from one NDArray to another + * the two ndarray can sit on different devices + * this operation will be scheduled by the engine + * + * \param from the ndarray we want to copy data from + * \param to the target ndarray + * \param priority Priority of the action. + * \note The function name explicitly marks the order of from and to + * due to different possible convention carried by copy function. + */ +void CopyFromTo(const NDArray &from, const NDArray& to, int priority = 0); + /*! * \brief Perform elementwise sum over each data from source, store result into out. * \param source the ndarray we want to sum diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index f6f8afb9..c3653719 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -195,7 +195,8 @@ nnvm::Graph Imperative::CachedOp::GetForwardGraph( bool match = true; match &= CheckAndInferShape(&g, std::move(shape_inputs), true); match &= CheckAndInferType(&g, std::move(dtype_inputs), true); - match &= CheckAndInferStorageType(&g, inputs[0]->ctx().dev_mask(), + exec::DevMaskVector dev_mask(g.indexed_graph().num_nodes(), inputs[0]->ctx().dev_mask()); + match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(storage_type_inputs), true); if (!match) { @@ -282,7 +283,8 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph( node_range, entry_range); match &= CheckAndInferType(&g, std::move(dtypes), false, node_range, entry_range); - match &= CheckAndInferStorageType(&g, inputs[0]->ctx().dev_mask(), std::move(stypes), + exec::DevMaskVector dev_mask(idx.num_nodes(), inputs[0]->ctx().dev_mask()); + match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(stypes), false, node_range, entry_range); if (!match) { @@ -352,7 +354,7 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector& inputs, const auto& dispatch_modes = g.GetAttr("dispatch_mode"); Imperative::Get()->RunGraph( - false, default_ctx, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), + false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); for (size_t i = 0; i < idx.num_node_entries(); ++i) { @@ -422,7 +424,7 @@ void Imperative::CachedOp::Backward( const auto& dispatch_modes = g.GetAttr("dispatch_mode"); Imperative::Get()->RunGraph( - retain_graph, default_ctx, idx, arrays, num_forward_nodes, idx.num_nodes(), + retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); if (retain_graph) { diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index 48050ec8..fc35c492 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -122,11 +122,13 @@ void Imperative::MarkVariables( info.outputs.emplace_back(variables[i]->Detach()); info.out_grads.emplace_back(gradients[i]->Detach()); info.grad_req = 
static_cast(grad_reqs[i]); + info.ctx = variables[i]->ctx(); gradients[i]->entry_ = nnvm::NodeEntry{ nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0}; AGInfo& grad_info = AGInfo::Create(gradients[i]->entry_.node); grad_info.outputs.emplace_back(gradients[i]->Detach()); + grad_info.ctx = gradients[i]->ctx(); } } @@ -207,6 +209,7 @@ void Imperative::RecordOp( node->attrs.name = "node_" + std::to_string(node_count_++); AGInfo& info = AGInfo::Create(node); info.state = state; + info.ctx = outputs[0]->ctx(); if (p_save_inputs == nullptr) { p_save_inputs = &(local_buff->save_inputs); @@ -225,6 +228,7 @@ void Imperative::RecordOp( nnvm::NodeEntry entry{nnvm::Symbol::CreateVariable( "null" + std::to_string(variable_count_++)).outputs[0].node, 0, 0}; AGInfo& input_info = AGInfo::Create(entry.node); + input_info.ctx = inputs[i]->ctx(); if (save_inputs[i]) { input_info.outputs.emplace_back(*inputs[i]); } else { @@ -263,7 +267,6 @@ void Imperative::RecordOp( void Imperative::RunGraph( const bool retain_graph, - const Context& default_ctx, const nnvm::IndexedGraph& idx, const std::vector arrays, size_t node_start, size_t node_end, @@ -288,6 +291,7 @@ void Imperative::RunGraph( for (size_t i = node_start; i < node_end; ++i) { const nnvm::IndexedGraph::Node& node = idx[i]; if (node.source->op() == nullptr) continue; + auto num_outputs = node.source->num_outputs(); ndinputs.clear(); ndinputs.reserve(node.inputs.size()); for (const auto& j : node.inputs) { @@ -295,15 +299,16 @@ void Imperative::RunGraph( CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index; } ndoutputs.clear(); - ndoutputs.reserve(node.source->num_outputs()); + ndoutputs.reserve(num_outputs); req.clear(); - req.reserve(node.source->num_outputs()); - for (size_t j = 0; j < node.source->num_outputs(); ++j) { + req.reserve(num_outputs); + for (size_t j = 0; j < num_outputs; ++j) { size_t eid = idx.entry_id(i, j); ndoutputs.emplace_back(arrays[eid]); req.push_back(array_reqs[eid]); CHECK(!ndoutputs.back()->is_none()); } + const Context& ctx = ndoutputs[0]->ctx(); const DispatchMode dispatch_mode = dispatch_modes[i]; if (node.source->op() == bwd_cached_op) { const auto& cached_op = dmlc::get(node.source->attrs.parsed); @@ -320,19 +325,19 @@ void Imperative::RunGraph( arg_dtypes.emplace_back(ndinputs[i]->dtype()); } states[i] = createop[node.source->op()]( - node.source->attrs, default_ctx, arg_shapes, arg_dtypes); - InvokeOp(default_ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]); + node.source->attrs, ctx, arg_shapes, arg_dtypes); + InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]); if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]); } else if (is_layer_backward.get(node.source->op(), false)) { nnvm::Node* fwd_node = node.source->control_deps[0].get(); auto fwd_node_id = idx.node_id(fwd_node); - InvokeOp(default_ctx, node.source->attrs, ndinputs, ndoutputs, + InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[fwd_node_id]); if (recording) { RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]); } } else { - InvokeOp(default_ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode); + InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode); if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs); } @@ -378,6 +383,7 @@ std::vector Imperative::Backward( for (size_t i = 0; i < 
outputs.size(); ++i) { ograd_entries.emplace_back(NodeEntry{Node::Create(), 0, 0}); AGInfo& info = AGInfo::Create(ograd_entries.back().node); + info.ctx = outputs[i]->ctx(); if (ograds[i] != nullptr) { info.outputs.emplace_back(*ograds[i]); } else { @@ -495,7 +501,7 @@ std::vector Imperative::Backward( } // Assign context - Context default_ctx = outputs[0]->ctx(); + auto vctx = PlaceDevice(idx); // Infer shape type { @@ -518,9 +524,11 @@ std::vector Imperative::Backward( StorageTypeVector stypes; stypes.reserve(idx.num_node_entries()); for (const auto& i : arrays) stypes.emplace_back(i->storage_type()); - CheckAndInferStorageType( - &graph, default_ctx.dev_mask(), std::move(stypes), false, - node_range, entry_range); + exec::DevMaskVector dev_mask; + dev_mask.reserve(idx.num_nodes()); + for (const auto& i : vctx) dev_mask.emplace_back(i.dev_mask()); + CheckAndInferStorageType(&graph, std::move(dev_mask), std::move(stypes), false, + node_range, entry_range); } // Calculate ref count @@ -544,13 +552,18 @@ std::vector Imperative::Backward( const auto& dtypes = graph.GetAttr("dtype"); const auto& stypes = graph.GetAttr("storage_type"); const auto& dispatch_modes = graph.GetAttr("dispatch_mode"); - for (size_t i = num_forward_entries; i < arrays.size(); ++i) { - if (!arrays[i]->is_none()) continue; - if (stypes[i] == kDefaultStorage) { - *arrays[i] = NDArray(shapes[i], default_ctx, true, dtypes[i]); - } else { - *arrays[i] = NDArray(static_cast(stypes[i]), - shapes[i], default_ctx, true, dtypes[i]); + + for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) { + auto num_outputs = idx[i].source->num_outputs(); + for (size_t j = 0; j < num_outputs; ++j) { + auto eid = idx.entry_id(i, j); + if (!arrays[eid]->is_none()) continue; + if (stypes[eid] == kDefaultStorage) { + *arrays[eid] = NDArray(shapes[eid], vctx[i], true, dtypes[eid]); + } else { + *arrays[eid] = NDArray(static_cast(stypes[eid]), + shapes[eid], vctx[i], true, dtypes[eid]); + } } } @@ -559,7 +572,7 @@ std::vector Imperative::Backward( bool prev_recording = set_is_recording(create_graph); bool prev_training = set_is_training(is_train); - RunGraph(retain_graph, default_ctx, idx, arrays, num_forward_nodes, idx.num_nodes(), + RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); set_is_recording(prev_recording); diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index c5e63ad1..7925b7b9 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -311,23 +311,33 @@ inline void PushFComputeEx(const FComputeEx& fn, const std::vector& p_inputs, const std::vector& p_outputs, const std::vector& req) { + static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); + bool is_train = Imperative::Get()->is_training(); + ExecType exec_type = ExecType::kSync; + if (fexec_type.count(op)) { + exec_type = fexec_type[op](attrs); + } std::vector inputs, outputs; DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs); - Engine::Get()->PushAsync([ctx, is_train, attrs, fn, inputs, outputs, requested, req]( + const auto& run = [ctx, exec_type, is_train, attrs, fn, inputs, outputs, requested, req]( RunContext rctx, engine::CallbackOnComplete on_complete) { - std::vector input_blobs, output_blobs; - OpContext opctx{is_train, rctx, - engine::CallbackOnComplete(), - requested}; + OpContext opctx{is_train, rctx, on_complete, requested}; fn(attrs, opctx, inputs, req, outputs); - if (ctx.dev_mask() == 
gpu::kDevMask) {
-        rctx.get_stream<gpu>()->Wait();
+      if (exec_type == ExecType::kSync) {
+        if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
+          rctx.get_stream<gpu>()->Wait();
+        }
+        on_complete();
       }
-      on_complete();
-    }, ctx, read_vars, write_vars, FnProperty::kNormal,
-    0, PROFILER_MESSAGE(op->name.c_str()));
+    };
+  if (exec_type == ExecType::kLocal) {
+    run(RunContext{ctx, nullptr}, engine::CallbackOnComplete());
+  } else {
+    Engine::Get()->PushAsync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
+                             0, PROFILER_MESSAGE(op->name.c_str()));
+  }
 }
 
 inline void PushOperator(const OpStatePtr& state,
@@ -500,20 +510,16 @@ inline bool CheckAndInferType(nnvm::Graph* p_g, nnvm::DTypeVector&& dtypes,
   return false;
 }
 
-inline bool CheckAndInferStorageType(nnvm::Graph* p_g, const int dev_mask,
+inline bool CheckAndInferStorageType(nnvm::Graph* p_g, exec::DevMaskVector&& dev_mask,
                                      StorageTypeVector&& storage_types, bool use_inputs,
                                      std::pair<uint32_t, uint32_t> node_range = {0, 0},
                                      std::pair<uint32_t, uint32_t> entry_range = {0, 0}) {
   using namespace nnvm;
   nnvm::Graph& g = *p_g;
-  bool dev_match = false;
-  if (g.attrs.count("dev_mask")) {
-    const auto& prev_vdev = g.GetAttr<exec::DevMaskVector>("dev_mask");
-    if (prev_vdev.size() && prev_vdev[0] == dev_mask) dev_match = true;
-  }
+  bool dev_match = g.attrs.count("dev_mask") &&
+                   g.GetAttr<exec::DevMaskVector>("dev_mask") == dev_mask;
   if (!dev_match) {
-    exec::DevMaskVector vdev(g.indexed_graph().num_nodes(), dev_mask);
-    g.attrs["dev_mask"] = std::make_shared<dmlc::any>(std::move(vdev));
+    g.attrs["dev_mask"] = std::make_shared<dmlc::any>(std::move(dev_mask));
   }
 
   if (dev_match && use_inputs) {
@@ -580,6 +586,54 @@ inline bool CheckAndInferStorageType(nnvm::Graph* p_g, const int dev_mask,
   return false;
 }
 
+
+inline std::vector<Context> PlaceDevice(const nnvm::IndexedGraph& idx) {
+  static const auto& _copyto = Op::Get("_copyto");
+
+  std::vector<Context> vctx(
+      idx.num_nodes(), Context::Create(static_cast<Context::DeviceType>(-1), 0));
+  size_t num_unknown = idx.num_nodes();
+  // forward pass
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    if (!idx[i].source->info.empty()) {
+      vctx[i] = dmlc::get<Imperative::AGInfo>(idx[i].source->info).ctx;
+      --num_unknown;
+    } else if (idx[i].source->op() == _copyto) {
+      CHECK_GT(idx[i].source->control_deps.size(), 0);
+      auto fwd_nid = idx.node_id(idx[i].source->control_deps[0].get());
+      CHECK_EQ(idx[fwd_nid].source->op(), _copyto);
+      vctx[i] = vctx[idx[fwd_nid].inputs[0].node_id];
+      --num_unknown;
+    } else if (idx[i].inputs.size()) {
+      vctx[i] = vctx[idx[i].inputs[0].node_id];
+      --num_unknown;
+    }
+  }
+  // backward pass
+  for (int i = idx.num_nodes() - 1; i >= 0; --i) {
+    if (vctx[i].dev_type == -1) continue;
+    if (idx[i].source->op() == _copyto) {
+      auto in_nid = idx[i].inputs[0].node_id;
+      if (vctx[in_nid].dev_type != -1) continue;
+      CHECK_GT(idx[i].source->control_deps.size(), 0);
+      auto fwd_nid = idx.node_id(idx[i].source->control_deps[0].get());
+      CHECK_EQ(idx[fwd_nid].source->op(), _copyto);
+      vctx[in_nid] = vctx[fwd_nid];
+      --num_unknown;
+      continue;
+    }
+    for (const auto& j : idx[i].inputs) {
+      if (vctx[j.node_id].dev_type != -1) continue;
+      vctx[j.node_id] = vctx[i];
+      --num_unknown;
+    }
+  }
+  CHECK_EQ(num_unknown, 0) << "Unable to decide context for nodes";
+
+  return vctx;
+}
+
+
 inline MemoryPlanVector PlanMemory(
     nnvm::Graph* p_g, nnvm::StorageVector&& storage,
     const std::vector<uint32_t>& ref_count,
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index f5b63343..5f3ae4e0 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -264,7 +264,7 @@ class CommCPU : public Comm {
     CHECK_EQ(indices.dtype(), dst->aux_type(rowsparse::kIdx))
       << "CopyRetainedRowsToGPU only
supports same data type for idx array and dst aux_data(0)"; if (!src.storage_initialized() || indices.data().Size() == 0U) { - op::FillZerosRspImpl(gpu_stream, dst); + op::FillZerosRspImpl(gpu_stream, *dst); return; } using namespace mshadow; diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 44652061..6dce19da 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -396,9 +396,9 @@ size_t num_aux_data(NDArrayStorageType stype) { // Make a copy of a CSR NDArray template -inline void CopyFromToCsrImpl(const NDArray from, NDArray *to, RunContext ctx) { +inline void CopyFromToCsrImpl(const NDArray& from, const NDArray& to, RunContext ctx) { using namespace mshadow; - CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + CHECK_EQ(from.storage_type(), to.storage_type()) << "Copying with different storage type"; // if source storage is not initialized, fill destination with zeros auto s = ctx.get_stream(); if (!from.storage_initialized()) { @@ -406,25 +406,25 @@ inline void CopyFromToCsrImpl(const NDArray from, NDArray *to, RunContext ctx) { return; } // Allocate storage - to->CheckAndAllocAuxData(csr::kIndPtr, from.aux_shape(csr::kIndPtr)); - to->CheckAndAllocAuxData(csr::kIdx, from.aux_shape(csr::kIdx)); - to->CheckAndAllocData(from.aux_shape(csr::kIdx)); - TBlob val = to->data(); - TBlob indptr = to->aux_data(csr::kIndPtr); - TBlob idx = to->aux_data(csr::kIdx); + to.CheckAndAllocAuxData(csr::kIndPtr, from.aux_shape(csr::kIndPtr)); + to.CheckAndAllocAuxData(csr::kIdx, from.aux_shape(csr::kIdx)); + to.CheckAndAllocData(from.aux_shape(csr::kIdx)); + TBlob val = to.data(); + TBlob indptr = to.aux_data(csr::kIndPtr); + TBlob idx = to.aux_data(csr::kIdx); ndarray::Copy(from.data(), &val, - from.ctx(), to->ctx(), ctx); + from.ctx(), to.ctx(), ctx); ndarray::Copy(from.aux_data(csr::kIndPtr), &indptr, - from.ctx(), to->ctx(), ctx); + from.ctx(), to.ctx(), ctx); ndarray::Copy(from.aux_data(csr::kIdx), &idx, - from.ctx(), to->ctx(), ctx); + from.ctx(), to.ctx(), ctx); } // Make a copy of a row-sparse NDArray template -inline void CopyFromToRspImpl(const NDArray from, NDArray *to, RunContext ctx) { +inline void CopyFromToRspImpl(const NDArray& from, const NDArray& to, RunContext ctx) { using namespace mshadow; - CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + CHECK_EQ(from.storage_type(), to.storage_type()) << "Copying with different storage type"; // if source is zeros, fill destination with zeros, too auto s = ctx.get_stream(); if (!from.storage_initialized()) { @@ -432,40 +432,40 @@ inline void CopyFromToRspImpl(const NDArray from, NDArray *to, RunContext ctx) { return; } auto aux_shape = from.aux_shape(rowsparse::kIdx); - to->CheckAndAlloc({aux_shape}); - TBlob val = to->data(); - TBlob idx = to->aux_data(rowsparse::kIdx); + to.CheckAndAlloc({aux_shape}); + TBlob val = to.data(); + TBlob idx = to.aux_data(rowsparse::kIdx); ndarray::Copy(from.data(), &val, - from.ctx(), to->ctx(), ctx); + from.ctx(), to.ctx(), ctx); ndarray::Copy(from.aux_data(rowsparse::kIdx), &idx, - from.ctx(), to->ctx(), ctx); + from.ctx(), to.ctx(), ctx); } // Make a copy of a dense NDArray template -inline void CopyFromToDnsImpl(const NDArray from, NDArray *to, RunContext ctx) { +inline void CopyFromToDnsImpl(const NDArray& from, const NDArray& to, RunContext ctx) { using namespace mshadow; - CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; - TBlob tmp = to->data(); + 
CHECK_EQ(from.storage_type(), to.storage_type()) << "Copying with different storage type"; + TBlob tmp = to.data(); ndarray::Copy(from.data(), &tmp, - from.ctx(), to->ctx(), ctx); + from.ctx(), to.ctx(), ctx); } // Make a copy of an NDArray based on storage type template -void CopyFromToImpl(const NDArray from, NDArray *to, RunContext rctx) { +void CopyFromToImpl(const NDArray& from, const NDArray& to, RunContext rctx) { using namespace std; using namespace mshadow; // if storage type doesn't match, cast the storage first auto from_stype = from.storage_type(); - auto to_stype = to->storage_type(); + auto to_stype = to.storage_type(); CHECK(from_stype == kDefaultStorage || to_stype == kDefaultStorage || from_stype == to_stype) << "Copying ndarray of stype = " << from_stype << " to stype = " << to_stype << " is not supported"; const auto from_ctx = from.ctx(); - const auto to_ctx = to->ctx(); + const auto to_ctx = to.ctx(); bool is_train = Imperative::Get()->is_training(); std::vector requested; if (is_same::value && from_stype != to_stype) { @@ -478,7 +478,7 @@ void CopyFromToImpl(const NDArray from, NDArray *to, RunContext rctx) { requested}; if (from_ctx == to_ctx && from_stype != to_stype) { // same ctx, different stypes, use cast op directly without copying - common::CastStorageDispatch(opctx, from, *to); + common::CastStorageDispatch(opctx, from, to); } else { NDArray casted_nd; // an intermediate result before copying from to to if (from_stype == to_stype) { @@ -510,49 +510,44 @@ void CopyFromToImpl(const NDArray from, NDArray *to, RunContext rctx) { } } -void CopyFromTo(const NDArray &from, NDArray *to, int priority) { - if (from.var() == to->var()) { +void CopyFromTo(const NDArray& from, const NDArray& to, int priority) { + if (from.var() == to.var()) { // skip to copy to itself return; } - CHECK(from.shape() == to->shape()) + CHECK(from.shape() == to.shape()) << "operands shape mismatch" - << "from.shape = " << from.shape() << " to.shape=" << to->shape(); + << "from.shape = " << from.shape() << " to.shape=" << to.shape(); CHECK(from.shape().ndim() != 0) << "source operands have zero dimension shape"; // important: callback must always capture by value - NDArray ret = *to; int a = from.ctx().dev_mask(); - int b = to->ctx().dev_mask(); + int b = to.ctx().dev_mask(); std::vector const_vars; - if (from.var() != ret.var()) const_vars.push_back(from.var()); + if (from.var() != to.var()) const_vars.push_back(from.var()); if (a == cpu::kDevMask && b == cpu::kDevMask) { - Engine::Get()->PushSync([from, ret](RunContext ctx) { - NDArray nd(ret); - CopyFromToImpl(from, &nd, ctx); - }, from.ctx(), const_vars, {ret.var()}, + Engine::Get()->PushSync([from, to](RunContext ctx) { + CopyFromToImpl(from, to, ctx); + }, from.ctx(), const_vars, {to.var()}, FnProperty::kNormal, priority, PROFILER_MESSAGE("CopyCPU2CPU")); } else { #if MXNET_USE_CUDA if (a == cpu::kDevMask && b == gpu::kDevMask) { - Engine::Get()->PushSync([from, ret](RunContext ctx) { - NDArray nd(ret); - CopyFromToImpl(from, &nd, ctx); - }, ret.ctx(), const_vars, {ret.var()}, + Engine::Get()->PushSync([from, to](RunContext ctx) { + CopyFromToImpl(from, to, ctx); + }, to.ctx(), const_vars, {to.var()}, FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("CopyCPU2GPU")); } else if (a == gpu::kDevMask && b == cpu::kDevMask) { - Engine::Get()->PushSync([from, ret](RunContext ctx) { - NDArray nd(ret); - CopyFromToImpl(from, &nd, ctx); - }, from.ctx(), const_vars, {ret.var()}, + Engine::Get()->PushSync([from, to](RunContext ctx) { + 
CopyFromToImpl(from, to, ctx); + }, from.ctx(), const_vars, {to.var()}, FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE("CopyGPU2CPU")); } else if (a == gpu::kDevMask && b == gpu::kDevMask) { - Engine::Get()->PushSync([from, ret](RunContext ctx) { - NDArray nd(ret); - CopyFromToImpl(from, &nd, ctx); - }, from.ctx(), const_vars, {ret.var()}, - from.dtype() != ret.dtype() ? FnProperty::kNormal : FnProperty::kCopyFromGPU, + Engine::Get()->PushSync([from, to](RunContext ctx) { + CopyFromToImpl(from, to, ctx); + }, from.ctx(), const_vars, {to.var()}, + from.dtype() != to.dtype() ? FnProperty::kNormal : FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE("CopyGPU2GPU")); } else { LOG(FATAL) << "unknown device mask"; @@ -563,6 +558,11 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) { } } + +void CopyFromTo(const NDArray& from, NDArray *to, int priority) { + CopyFromTo(from, *to, priority); +} + void ElementwiseSum(const std::vector &source, NDArray *out, int priority) { std::vector const_vars; const_vars.reserve(source.size()); @@ -651,10 +651,6 @@ void ClipOp(const NDArray &src, } } -inline void CopyFromToSimple(const NDArray &from, NDArray *to) { - CopyFromTo(from, to, 0); -} - template void SampleOP(const real_t &a, const real_t &b, @@ -1064,7 +1060,7 @@ NDArray NDArray::Copy(Context ctx) const { LOG(FATAL) << "NDArray::Copy cannot copy undefined storage-type ndarray to ctx.dev_type=" << ctx.dev_type << ", ctx.dev_id=" << ctx.dev_id; } - CopyFromTo(*this, &ret); + CopyFromTo(*this, ret); return ret; } @@ -1243,12 +1239,46 @@ MXNET_REGISTER_NDARRAY_FUN(fill_element_0index) // register API function // those with underscore will be registered at NDArray +void CopyFromToSimple( + const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CopyFromTo(inputs[0], outputs[0], 0); +} // copy function is special // that we need to remove kAcceptEmptyMutateTarget from it -MXNET_REGISTER_NDARRAY_FUN(_copyto) -.set_function(CopyFromToSimple) -.set_type_mask(kNDArrayArgBeforeScalar); +NNVM_REGISTER_OP(_copyto) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FInferShape", op::ElemwiseShape<1, 1>) +.set_attr("FInferType", + [](const NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { + return !op::type_is_none((*in_type)[0]) && !op::type_is_none((*out_type)[0]); + }) +.set_attr("FInferStorageType", + [](const NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + op::dispatch_mode_assign(dispatch_mode, DispatchMode::kFComputeEx); + if (op::storage_type_is_none((*out_attrs)[0])) { + (*out_attrs)[0] = (*in_attrs)[0]; + } + return true; + }) +.set_attr("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kLocal; + }) +.set_attr("FGradient", op::ElemwiseGradUseNone{"_copyto"}) +.set_attr("TIsBackward", true) +.set_attr("FComputeEx", CopyFromToSimple) +.set_attr("FComputeEx", CopyFromToSimple) +.add_argument("data", "NDArray", "input data"); + void Imdecode(NDArray *ret, NDArray mean, size_t index, size_t x0, size_t y0, size_t x1, size_t y1, size_t n_channels, diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu index 46f47ed2..62ea7de4 100644 --- a/src/ndarray/ndarray_function.cu +++ b/src/ndarray/ndarray_function.cu @@ -115,7 +115,7 @@ void ElementwiseSumRspImpl(mshadow::Stream* s, } } if (init == 0) { - FillZerosRspImpl(s, out); + FillZerosRspImpl(s, *out); return; } const dim_t 
num_rows = out->shape()[0]; diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index fcb46407..cff963a7 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -112,6 +112,11 @@ inline bool type_is_none(const int& x) { return x == -1; } +/*! \brief check if type is none (-1) */ +inline bool storage_type_is_none(const int& x) { + return x == -1; +} + /*! \brief check if shape is scalar({1}). */ inline bool shape_is_scalar(const TShape& x) { return x.ndim() == 1 && x.Size() == 1; diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index eadf58a3..477f8d9c 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -222,22 +222,22 @@ void PopulateFullIdxRspImpl(mshadow::Stream *s, NDArray *dst) { // Fill a rsp NDArray with zeros by updating the aux shape. template -void FillZerosRspImpl(mshadow::Stream *s, NDArray *dst) { - if (!dst->storage_initialized()) return; +void FillZerosRspImpl(mshadow::Stream *s, const NDArray& dst) { + if (!dst.storage_initialized()) return; // reset the shapes if it's not zeros - auto storage_shape = dst->storage_shape(); + auto storage_shape = dst.storage_shape(); storage_shape[0] = 0; - dst->set_aux_shape(rowsparse::kIdx, TShape(mshadow::Shape1(0))); + dst.set_aux_shape(rowsparse::kIdx, TShape(mshadow::Shape1(0))); } // Fill a CSR NDArray with zeros by updating the aux shape. template -void FillZerosCsrImpl(mshadow::Stream *s, NDArray *dst) { - if (!dst->storage_initialized()) return; +void FillZerosCsrImpl(mshadow::Stream *s, const NDArray& dst) { + if (!dst.storage_initialized()) return; // reset the shapes if it's not zeros TShape new_shape(mshadow::Shape1(0)); - dst->set_aux_shape(csr::kIndPtr, new_shape); - dst->set_aux_shape(csr::kIdx, new_shape); + dst.set_aux_shape(csr::kIndPtr, new_shape); + dst.set_aux_shape(csr::kIdx, new_shape); } template @@ -255,10 +255,10 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(req[0], kWriteTo) << "kWriteTo is expected for FillComputeZerosEx"; if (stype == kRowSparseStorage) { NDArray nd(outputs[0]); - FillZerosRspImpl(s, &nd); + FillZerosRspImpl(s, nd); } else if (stype == kCSRStorage) { NDArray nd(outputs[0]); - FillZerosCsrImpl(s, &nd); + FillZerosCsrImpl(s, nd); } else { LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs); } diff --git a/src/operator/tensor/sparse_retain-inl.h b/src/operator/tensor/sparse_retain-inl.h index 3f458567..c7751f57 100644 --- a/src/operator/tensor/sparse_retain-inl.h +++ b/src/operator/tensor/sparse_retain-inl.h @@ -267,7 +267,7 @@ void SparseRetainOpForwardRspImpl(mshadow::Stream *s, if (!input_nd.storage_initialized() || idx_data.Size() == 0U || input_nd.shape()[0] == 0) { - FillZerosRspImpl(s, output_nd); + FillZerosRspImpl(s, *output_nd); return; } @@ -387,7 +387,7 @@ void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs, const TBlob idx_data = inputs[sr::kIdx].data(); if (idx_data.Size() == 0U) { NDArray output = outputs[sr::kArr]; - FillZerosRspImpl(s, &output); + FillZerosRspImpl(s, output); return; } diff --git a/src/operator/tensor/square_sum-inl.h b/src/operator/tensor/square_sum-inl.h index f5a70b8a..cd1059d4 100644 --- a/src/operator/tensor/square_sum-inl.h +++ b/src/operator/tensor/square_sum-inl.h @@ -278,7 +278,7 @@ void SquareSumRspImpl(const nnvm::NodeAttrs& attrs, Kernel::Launch(s, out_data_size, output->data().dptr()); }) } else if (output->storage_type() == kRowSparseStorage) { - FillZerosRspImpl(s, output); + 
    FillZerosRspImpl(s, *output);
   } else {
     LOG(FATAL) << "SquareSumRspImpl only supports row-sparse/dense output storage type";
   }
@@ -348,7 +348,7 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(igrad->storage_type(), kRowSparseStorage);
   CHECK_EQ(req, kWriteTo);
   if (!input.storage_initialized()) {
-    FillZerosRspImpl(s, igrad);
+    FillZerosRspImpl(s, *igrad);
     return;
   }
 
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index ec658448..988b8b0b 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1419,6 +1419,32 @@ def test_cuda_rtc():
     assert (y.asnumpy() == 12).all()
 
 
+def test_cross_device_autograd():
+    x = mx.nd.random.uniform(shape=(10,))
+    x.attach_grad()
+
+    with mx.autograd.record():
+        y = mx.nd.tanh(x)
+        y = y.copyto(mx.gpu(0))
+        y = mx.nd.tanh(y)
+        y = y.copyto(mx.cpu(0))
+        y = mx.nd.tanh(y)
+        y = y.copyto(mx.gpu(0))
+        y = y.copyto(mx.gpu(0))
+
+        y.backward()
+
+    dx = x.grad.asnumpy()
+    x.grad[:] = 0
+
+    with mx.autograd.record():
+        y = x
+        for i in range(3):
+            y = mx.nd.tanh(y)
+        y.backward()
+
+    assert_almost_equal(dx, x.grad.asnumpy())
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
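
For reference, here is a minimal usage sketch (not part of the patch) of what this change enables from the Python API: an explicit cross-device copy made inside an autograd recording session now participates in the backward pass. It mirrors the new test_cross_device_autograd test above and assumes an MXNet build with CUDA and at least one visible GPU (mx.gpu(0)).

    import mxnet as mx

    x = mx.nd.random.uniform(shape=(10,), ctx=mx.cpu(0))
    x.attach_grad()

    with mx.autograd.record():
        y = mx.nd.tanh(x)          # computed on the CPU
        y = y.copyto(mx.gpu(0))    # recorded cross-device copy (_copyto)
        y = mx.nd.tanh(y)          # computed on the GPU
        y = y.copyto(mx.cpu(0))    # copy the result back to the CPU

    y.backward()                   # gradients flow back across both device boundaries
    print(x.grad.asnumpy())        # x.grad lives on x's context (the CPU)

During backward, each node runs on the context recorded for its forward node (AGInfo::ctx), with _copyto nodes bridging devices, so no manual gradient copying is needed.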