From 5431e12f11fc5446f6ec2a25098b4e3b67ee7eb3 Mon Sep 17 00:00:00 2001 From: Eric Junyuan Xie Date: Fri, 15 Jun 2018 18:22:16 -0700 Subject: [PATCH] Static alloc for hybridblock (#11313) Thanks for the contribution, this is now merged --- include/mxnet/c_api.h | 5 - include/mxnet/imperative.h | 89 --- include/mxnet/ndarray.h | 8 + include/mxnet/op_attr_types.h | 33 +- python/mxnet/_ctypes/ndarray.py | 16 +- python/mxnet/gluon/block.py | 74 ++- src/c_api/c_api_ndarray.cc | 26 +- src/engine/threaded_engine.cc | 3 +- src/executor/attach_op_execs_pass.cc | 165 +++--- src/executor/attach_op_resource_pass.cc | 16 +- src/executor/exec_pass.h | 28 +- src/executor/graph_executor.cc | 2 +- src/imperative/cached_op.cc | 750 +++++++++++++++++++----- src/imperative/cached_op.h | 174 ++++++ src/imperative/imperative.cc | 90 +-- src/imperative/imperative_utils.cc | 120 ++++ src/imperative/imperative_utils.h | 256 ++++++-- tests/python/unittest/test_gluon.py | 67 ++- 18 files changed, 1399 insertions(+), 523 deletions(-) create mode 100644 src/imperative/cached_op.h create mode 100644 src/imperative/imperative_utils.cc diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 55c26bc980b2..4dd858a51c4b 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -987,11 +987,6 @@ MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle, int num_flags, const char** keys, const char** vals, - int num_inputs, - const char** input_names, - int num_params, - const char** param_names, - NDArrayHandle* params, CachedOpHandle *out); /*! * \brief free cached operator diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h index 758ce8513213..7ea60df33028 100644 --- a/include/mxnet/imperative.h +++ b/include/mxnet/imperative.h @@ -35,23 +35,6 @@ #include "./ndarray.h" namespace mxnet { -/*! \brief CachedOp Parameters */ -struct CachedOpConfig : public dmlc::Parameter { - uint32_t inline_limit; - uint32_t forward_bulk_size; - uint32_t backward_bulk_size; - DMLC_DECLARE_PARAMETER(CachedOpConfig) { - DMLC_DECLARE_FIELD(inline_limit) - .set_default(2) - .describe("Maximum number of operators that can be inlined."); - DMLC_DECLARE_FIELD(forward_bulk_size) - .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)) - .describe("Segment size of bulk execution during forward pass."); - DMLC_DECLARE_FIELD(backward_bulk_size) - .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)) - .describe("Segment size of bulk execution during backward pass."); - } -}; /*! \brief runtime functions for NDArray */ class Imperative { public: @@ -94,67 +77,6 @@ class Imperative { && info.out_grads.size() == 1; } }; - class CachedOp { - public: - CachedOp( - const nnvm::Symbol& sym, - const std::vector >& flags, - const std::vector arg_names, - const std::unordered_map >& params); - uint32_t num_inputs() { - return fwd_graph_.indexed_graph().input_nodes().size(); - } - uint32_t num_outputs() { - return fwd_graph_.outputs.size(); - } - uint32_t num_backward_inputs() { - return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size(); - } - std::vector& save_inputs() { - return save_inputs_; - } - std::vector& save_outputs() { - return save_outputs_; - } - const std::unordered_set& mutable_input_nodes() { - return fwd_graph_.indexed_graph().mutable_input_nodes(); - } - nnvm::Graph GetForwardGraph(const bool recording, - const std::vector& inputs); - nnvm::Graph GetBackwardGraph(const OpStatePtr& state, - const std::vector& reqs, - const std::vector& inputs); - std::vector Gradient(const nnvm::NodePtr& node, - const std::vector& ograds); - void Forward(const std::shared_ptr& op_ptr, - const std::vector& args, - const std::vector& outputs); - void Backward(const bool retain_graph, - const OpStatePtr& state, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs); - - private: - struct CachedOpState { - std::vector buff; - std::vector states; - }; - std::mutex mutex_; - CachedOpConfig config_; - nnvm::Graph fwd_graph_; - nnvm::Graph grad_graph_; - nnvm::Graph full_graph_; - std::unordered_map > params_; - bool inlining_; - std::vector ograd_entries_; - std::vector curr_grad_req_; - std::vector bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_; - std::vector fwd_args_idx_; - std::vector fwd_params_idx_; - std::vector bwd_input_eid_; - std::vector save_inputs_, save_outputs_; - }; /*! \brief whether operator recording is on. */ bool is_training() const { return is_train_; @@ -222,15 +144,6 @@ class Imperative { uint32_t num_inputs, uint32_t num_outputs, std::vector *p_save_inputs, std::vector *p_save_outputs); - void RunGraph( - const bool retain_graph, - const nnvm::IndexedGraph& idx, - const std::vector arrays, - size_t node_start, size_t node_end, - std::vector&& array_reqs, - std::vector&& ref_count, - std::vector *p_states, - const DispatchModeVector& dispatch_modes); /*! \brief indicate whether is training. */ #if DMLC_CXX11_THREAD_LOCAL static thread_local bool is_train_; @@ -247,7 +160,5 @@ class Imperative { int backward_bulk_size_{0}; }; -using CachedOpPtr = std::shared_ptr; - } // namespace mxnet #endif // MXNET_IMPERATIVE_H_ diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index e243eb71c477..ae96fd87b0db 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -155,6 +155,14 @@ class NDArray { return byte_offset_ > 0 || shape() != ptr_->storage_shape; } + /* \brief Check whether the two arrays are the same array */ + inline bool IsSame(const NDArray& other) { + return ptr_ == other.ptr_ && + shape_ == other.shape_ && + byte_offset_ == other.byte_offset_ && + dtype_ == other.dtype_; + } + /*! * \return the shape of current NDArray. */ diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 3969d8445be1..f4694efad297 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -126,25 +126,36 @@ class OpStatePtr { template static OpStatePtr Create(Args&&... args) { OpStatePtr ret; - ret.ptr_ = std::make_shared(); - ret.ptr_->var_ = Engine::Get()->NewVariable(); - ret.ptr_->state_.construct(std::forward(args)...); + auto state = new T(std::forward(args)...); + auto var = Engine::Get()->NewVariable(); + ret.ptr_.reset( + new OpState(var, state), + [](OpState* p) { + Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var); + delete reinterpret_cast(p->state); + delete p; + }); return ret; } /* \brief Get engine variable associated with this state */ engine::VarHandle get_var() const { - return ptr_->var_; + return ptr_->var; } /* \brief Get state of type T */ template T& get_state() const { - return dmlc::get(ptr_->state_); + return *reinterpret_cast(ptr_->state); } /* \brief clear state */ void reset() { ptr_.reset(); } + /* \brief checks whether the managed object is managed only by the current + OpStatePtr instance */ + bool unique() const { + return ptr_.unique(); + } /* \brief Whether state is empty */ explicit operator bool() const { return ptr_ ? true : false; @@ -153,16 +164,12 @@ class OpStatePtr { private: /* \brief state structure */ struct OpState { - OpState() {} + engine::VarHandle var; + void* state; + + OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {} OpState(const OpState& other) = delete; OpState& operator=(const OpState& other) = delete; - - ~OpState() { - Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_); - } - - engine::VarHandle var_; - dmlc::any state_; }; /* \brief shared pointer to state */ std::shared_ptr ptr_; diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py index d2cae0c45aaa..f324545a2352 100644 --- a/python/mxnet/_ctypes/ndarray.py +++ b/python/mxnet/_ctypes/ndarray.py @@ -105,28 +105,14 @@ def _imperative_invoke(handle, ndargs, keys, vals, out): class CachedOp(object): """Cached operator handle.""" __slots__ = ["handle"] - def __init__(self, sym, flags=(), inputs=None, params=None): + def __init__(self, sym, flags=()): self.handle = CachedOpHandle() - param_names = [] - param_arrays = [] - if inputs is None: - assert params is None, "When inputs is None params must also be None." - inputs = sym.list_inputs() - elif params is not None: - for name, arrs in params.items(): - param_arrays.extend(arrs) - param_names.extend([name] * len(arrs)) check_call(_LIB.MXCreateCachedOpEx( sym.handle, len(flags), c_str_array([key for key, _ in flags]), c_str_array([str(val) for _, val in flags]), - len(inputs), - c_str_array(inputs), - len(param_names), - c_str_array(param_names), - c_handle_array(param_arrays), ctypes.byref(self.handle))) def __del__(self): diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 3b97c0578cae..293fafab487b 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -502,8 +502,16 @@ def hybridize(self, active=True, **kwargs): ---------- active : bool, default True Whether to turn hybrid on or off. - **kwargs : string - Additional flags for hybridized operator. + static_alloc : bool, default False + Statically allocate memory to improve speed. Memory usage may increase. + static_shape : bool, default False + Optimize for invariant input shapes between iterations. Must also + set static_alloc to True. Change of input shapes is still allowed + but slower. + forward_bulk_size : int, default 15 + Segment size of bulk execution during forward pass. + backward_bulk_size : int, default 15 + Segment size of bulk execution during backward pass. """ for cld in self._children.values(): cld.hybridize(active, **kwargs) @@ -696,7 +704,7 @@ def __init__(self, prefix=None, params=None): self._out_format = None self._in_format = None self._active = False - self._flags = {} + self._flags = [] def __setattr__(self, name, value): """Registers parameters.""" @@ -723,39 +731,43 @@ def _get_graph(self, *args): return self._cached_graph def _build_cache(self, *args): - inputs, out = self._get_graph(*args) - input_names = [i.name for i in inputs] - + data, out = self._get_graph(*args) + data_names = {data.name : i for i, data in enumerate(data)} params = self.collect_params() + input_names = out.list_inputs() + param_names = set(params.keys()) - expected_names = set(out.list_inputs()) + expected_names = set(input_names) for name in expected_names: - assert name in param_names or name in input_names, \ + assert name in param_names or name in data_names, \ "Unknown input to HybridBlock: %s"%name - used_input_names = [i for i in input_names if i in expected_names] - if len(used_input_names) != len(input_names): - unused = ', '.join(['%d-th'%i for i, name in enumerate(input_names) + used_data_names = [i for i in data_names if i in expected_names] + if len(used_data_names) != len(data_names): + unused = ', '.join(['%d-th'%i for name, i in data_names.items() if name not in expected_names]) warnings.warn("The %s input to HybridBlock is not used by any " "computation. Is this intended?"%unused, stacklevel=4) - used_param_names = set(i for i in param_names if i in expected_names) + used_param_names = [i for i in param_names if i in expected_names] if len(used_param_names) != len(param_names): - unused = ', '.join(list(param_names - used_param_names)) + unused = ', '.join(list(param_names - set(used_param_names))) warnings.warn("Parameter %s is not used by any computation. " "Is this intended?"%unused, stacklevel=4) - used_params = {k: params[k] for k in used_param_names} - try: - param_dict = {k: v.list_data() for k, v in used_params.items()} - except DeferredInitializationError: - self._deferred_infer_shape(*args) - for i in used_params.values(): - i._finish_deferred_init() - param_dict = {k: v.list_data() for k, v in used_params.items()} - - self._cached_op = ndarray.CachedOp(out, self._flags, input_names, param_dict) + data_indices = [] + param_indices = [] + self._cached_op_args = [] + for i, name in enumerate(input_names): + if name in data_names: + data_indices.append(i) + self._cached_op_args.append((True, data_names[name])) + else: + param_indices.append(i) + self._cached_op_args.append((False, params[name])) + flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \ + self._flags + self._cached_op = ndarray.CachedOp(out, flags) def _deferred_infer_shape(self, *args): try: @@ -771,7 +783,19 @@ def _call_cached_op(self, *args): args, fmt = _flatten(args, "input") assert fmt == self._in_format, "Invalid input format" - out = self._cached_op(*args) + try: + cargs = [args[i] if is_arg else i.data() + for is_arg, i in self._cached_op_args] + except DeferredInitializationError: + self._deferred_infer_shape(*args) + cargs = [] + for is_arg, i in self._cached_op_args: + if is_arg: + cargs.append(args[i]) + else: + i._finish_deferred_init() + cargs.append(i.data()) + out = self._cached_op(*cargs) if isinstance(out, NDArray): out = [out] return _regroup(out, self._out_format)[0] @@ -792,7 +816,7 @@ def register_child(self, block, name=None): def hybridize(self, active=True, **kwargs): self._active = active - self._flags = kwargs.items() + self._flags = list(kwargs.items()) self._clear_cached_op() if active and self._forward_hooks or self._forward_pre_hooks: warnings.warn('"{}" is being hybridized while still having forward hook/pre-hook. ' diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 9aabe04656e5..34bd4b20aa54 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -36,6 +36,7 @@ #include "../common/utils.h" #include "../common/exec_utils.h" #include "../imperative/imperative_utils.h" +#include "../imperative/cached_op.h" using namespace mxnet; @@ -160,12 +161,8 @@ int MXCreateCachedOp(SymbolHandle handle, std::vector input_names; input_names.reserve(inputs.size()); for (const auto& i : inputs) input_names.push_back(i->attrs.name); - *out = new std::shared_ptr( - new Imperative::CachedOp( - *sym, - std::vector >(), - input_names, - std::unordered_map >())); + *out = new CachedOpPtr(new CachedOp( + *sym, std::vector >())); API_END(); } @@ -173,11 +170,6 @@ int MXCreateCachedOpEx(SymbolHandle handle, int num_flags, const char** keys, const char** vals, - int num_args, - const char** arg_names, - int num_params, - const char** param_names, - NDArrayHandle* params, CachedOpHandle *out) { nnvm::Symbol* sym = static_cast(handle); @@ -186,17 +178,7 @@ int MXCreateCachedOpEx(SymbolHandle handle, for (int i = 0; i < num_flags; ++i) { flags.push_back({keys[i], vals[i]}); } - std::vector args; - for (int i = 0; i < num_args; ++i) { - args.push_back(arg_names[i]); - } - std::unordered_map > param_dict; - for (int i = 0; i < num_params; ++i) { - param_dict[param_names[i]].emplace_back( - *reinterpret_cast(params[i])); - } - *out = new std::shared_ptr( - new Imperative::CachedOp(*sym, flags, args, param_dict)); + *out = new CachedOpPtr(new CachedOp(*sym, flags)); API_END(); } diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index dc0436e02a8e..e70cc197c0c3 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -278,6 +278,8 @@ void ThreadedEngine::DeleteOperator(OprHandle op) { } void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) { + BulkFlush(); + ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op); OprBlock* opr_block = OprBlock::New(); opr_block->opr = threaded_opr; @@ -323,7 +325,6 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx, << device_count_; } #endif - BulkFlush(); ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait); opr->temporary = true; const bool profiling = profiler_->IsProfiling(profiler::Profiler::kImperative); diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 697e4869a049..72919d90c620 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -134,6 +134,10 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor { return state_.get_var(); } + OpStatePtr state() const override { + return state_; + } + explicit StatefulComputeExecutor(const OpStatePtr& state, const FStatefulCompute& fcompute, ExecType exec_type, @@ -142,7 +146,6 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor { state_(state), fcompute_(fcompute), exec_type_(exec_type) {} private: - friend Graph AttachOpExecs(Graph g); OpStatePtr state_; FStatefulCompute fcompute_; ExecType exec_type_; @@ -170,13 +173,16 @@ class StatefulComputeExExecutor : public OpExecutor { return state_.get_var(); } + OpStatePtr state() const override { + return state_; + } + explicit StatefulComputeExExecutor(const OpStatePtr& state, const FStatefulComputeEx& fcompute, ExecType exec_type) : state_(state), fcompute_(fcompute), exec_type_(exec_type) {} private: - friend Graph AttachOpExecs(Graph g); OpStatePtr state_; FStatefulComputeEx fcompute_; ExecType exec_type_; @@ -241,16 +247,15 @@ class FComputeExExecutor : public OpExecutor { ExecType exec_type_; }; -// pass to attach operator executors -Graph AttachOpExecs(Graph g) { +void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) { using nnvm::DTypeVector; using nnvm::ShapeVector; using nnvm::FMutateInputs; - auto& fcreate_op_state = nnvm::Op::GetAttr("FCreateOpState"); - auto& fmutate_inputs = nnvm::Op::GetAttr("FMutateInputs"); - auto& fexec_type = nnvm::Op::GetAttr("FExecType"); - auto& is_layer_backward = nnvm::Op::GetAttr("TIsLayerOpBackward"); + static auto& fcreate_op_state = nnvm::Op::GetAttr("FCreateOpState"); + static auto& fmutate_inputs = nnvm::Op::GetAttr("FMutateInputs"); + static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); + static auto& is_layer_backward = nnvm::Op::GetAttr("TIsLayerOpBackward"); const auto& vdtype = g.GetAttr("dtype"); const auto& vshape = g.GetAttr("shape"); @@ -259,81 +264,87 @@ Graph AttachOpExecs(Graph g) { // get the graph const auto& idx = g.indexed_graph(); - std::vector > ret(idx.num_nodes()); + OpExecVector& ret = *p_ret; // initialize the nodes - for (size_t i = 0; i < idx.num_nodes(); ++i) { - const auto& inode = idx[i]; - if (inode.source->is_variable()) continue; - const nnvm::Op *op = inode.source->op(); - ExecType exec_type = ExecType::kSync; - std::vector mutate_index; - if (fmutate_inputs.count(op)) { - mutate_index = fmutate_inputs[op](inode.source->attrs); - } - if (fexec_type.count(op)) { - exec_type = fexec_type[op](inode.source->attrs); + const auto& inode = idx[i]; + if (inode.source->is_variable()) return; + const nnvm::Op *op = inode.source->op(); + ExecType exec_type = ExecType::kSync; + std::vector mutate_index; + if (fmutate_inputs.count(op)) { + mutate_index = fmutate_inputs[op](inode.source->attrs); + } + if (fexec_type.count(op)) { + exec_type = fexec_type[op](inode.source->attrs); + } + CHECK(dispatch_modes[i] != DispatchMode::kUndefined); + if (fcreate_op_state.count(op)) { + std::vector ishape; + std::vector itype; + for (const auto& e : inode.inputs) { + ishape.emplace_back(vshape[idx.entry_id(e)]); + itype.emplace_back(vdtype[idx.entry_id(e)]); } - CHECK(dispatch_modes[i] != DispatchMode::kUndefined); - if (fcreate_op_state.count(op)) { - std::vector ishape; - std::vector itype; - for (const auto& e : inode.inputs) { - ishape.emplace_back(vshape[idx.entry_id(e)]); - itype.emplace_back(vdtype[idx.entry_id(e)]); - } - OpStatePtr state = fcreate_op_state[op]( - inode.source->attrs, vctx[i], ishape, itype); - FStatefulComputeEx fcompute_ex = common::GetFCompute( - op, "FStatefulComputeEx", vctx[i]); - // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx - if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared(state, fcompute_ex, exec_type); - } else { - FStatefulCompute fcompute = common::GetFCompute( - op, "FStatefulCompute", vctx[i]); - CHECK(fcompute != nullptr) - << "One of FStatefulCompute and FStatefulComputeEx must be registered " - << "for stateful operator " << op->name; - ret[i] = std::make_shared(state, fcompute, - exec_type, mutate_index); - } - } else if (is_layer_backward.get(op, false)) { - CHECK_GE(inode.control_deps.size(), 1); - uint32_t fwd_id = inode.control_deps[0]; - CHECK(vctx[fwd_id] == vctx[i]); - CHECK(ret[fwd_id] != nullptr); - FStatefulComputeEx fcompute_ex = common::GetFCompute( - op, "FStatefulComputeEx", vctx[i]); - // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx - if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared( - dynamic_cast(ret[fwd_id].get())->state_, - fcompute_ex, exec_type); - } else { - FStatefulCompute fcompute = common::GetFCompute( - op, "FStatefulCompute", vctx[i]); - CHECK(fcompute != nullptr) - << "One of FStatefulCompute and FStatefulComputeEx must be registered " - << "for stateful operator " << op->name; - ret[i] = std::make_shared( - dynamic_cast(ret[fwd_id].get())->state_, - fcompute, exec_type, mutate_index); - } + OpStatePtr state = fcreate_op_state[op]( + inode.source->attrs, vctx[i], ishape, itype); + FStatefulComputeEx fcompute_ex = common::GetFCompute( + op, "FStatefulComputeEx", vctx[i]); + // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx + if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { + ret[i] = std::make_shared(state, fcompute_ex, exec_type); } else { - FCompute fcompute = common::GetFCompute(op, "FCompute", vctx[i]); - FComputeEx fcomp_ex = common::GetFCompute(op, "FComputeEx", vctx[i]); - if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared( - inode.source->attrs, fcomp_ex, exec_type); - } else if (fcompute != nullptr) { - ret[i] = std::make_shared( - inode.source->attrs, fcompute, exec_type, mutate_index); - } else { - LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name; - } + FStatefulCompute fcompute = common::GetFCompute( + op, "FStatefulCompute", vctx[i]); + CHECK(fcompute != nullptr) + << "One of FStatefulCompute and FStatefulComputeEx must be registered " + << "for stateful operator " << op->name; + ret[i] = std::make_shared(state, fcompute, + exec_type, mutate_index); + } + } else if (is_layer_backward.get(op, false)) { + CHECK_GE(inode.control_deps.size(), 1); + uint32_t fwd_id = inode.control_deps[0]; + CHECK(vctx[fwd_id] == vctx[i]); + CHECK(ret[fwd_id] != nullptr); + FStatefulComputeEx fcompute_ex = common::GetFCompute( + op, "FStatefulComputeEx", vctx[i]); + // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx + if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { + ret[i] = std::make_shared( + ret[fwd_id].get()->state(), fcompute_ex, exec_type); + } else { + FStatefulCompute fcompute = common::GetFCompute( + op, "FStatefulCompute", vctx[i]); + CHECK(fcompute != nullptr) + << "One of FStatefulCompute and FStatefulComputeEx must be registered " + << "for stateful operator " << op->name; + ret[i] = std::make_shared( + ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index); } + } else { + FCompute fcompute = common::GetFCompute(op, "FCompute", vctx[i]); + FComputeEx fcomp_ex = common::GetFCompute(op, "FComputeEx", vctx[i]); + if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { + ret[i] = std::make_shared( + inode.source->attrs, fcomp_ex, exec_type); + } else if (fcompute != nullptr) { + ret[i] = std::make_shared( + inode.source->attrs, fcompute, exec_type, mutate_index); + } else { + LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name; + } + } +} + + +// pass to attach operator executors +Graph AttachOpExecs(Graph g) { + const auto& idx = g.indexed_graph(); + OpExecVector ret(idx.num_nodes()); + for (size_t i = 0; i < idx.num_nodes(); ++i) { + CreateOpExecs(g, &ret, i); } g.attrs["op_execs"] = std::make_shared(ret); return g; diff --git a/src/executor/attach_op_resource_pass.cc b/src/executor/attach_op_resource_pass.cc index 681866296e1c..56122cda6ff0 100644 --- a/src/executor/attach_op_resource_pass.cc +++ b/src/executor/attach_op_resource_pass.cc @@ -30,12 +30,15 @@ namespace mxnet { namespace exec { -Graph AttachOpResources(Graph g) { +void AttachOpResources( + const Graph& g, + const OpExecVector& op_execs, + size_t start_nid, + size_t end_nid) { static auto& fresource = nnvm::Op::GetAttr("FResourceRequest"); static auto& fresource_ex = nnvm::Op::GetAttr("FResourceRequestEx"); - auto& op_execs = nnvm::get(*g.attrs.at("op_execs")); const auto& vctx = g.GetAttr("context"); const auto& vdispatch = g.GetAttr("dispatch_mode"); const auto& dev_masks = g.GetAttr("dev_mask"); @@ -43,7 +46,7 @@ Graph AttachOpResources(Graph g) { // Use global resource pool for each executor for now. std::map cached_temp; // Resource allocation - for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { + for (uint32_t nid = start_nid; nid < end_nid; ++nid) { const auto& inode = idx[nid]; if (inode.source->is_variable()) continue; const Context &ctx = vctx[nid]; @@ -84,7 +87,12 @@ Graph AttachOpResources(Graph g) { requested.push_back(ResourceManager::Get()->Request(ctx, ResourceRequest::kTempSpace)); } } - return g; } + +void AttachOpResources(const Graph& g) { + const auto& op_execs = g.GetAttr("op_execs"); + AttachOpResources(g, op_execs, 0, g.indexed_graph().num_nodes()); +} + } // namespace exec } // namespace mxnet diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 99b1b162eaee..26a249118940 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -82,6 +82,10 @@ class OpExecutor { virtual engine::VarHandle var() const { return nullptr; } + /*! \return return operator state */ + virtual OpStatePtr state() const { + return OpStatePtr(); + } }; /*! @@ -102,6 +106,14 @@ using ContextVector = std::vector; */ using DevMaskVector = std::vector; +/*! + * \brief create OpExecutor for a node in graph + * + * \param g input graph + * \param p_ret OpExecVector for input and output + * \param i the id of the node + */ +void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i); /*! * \brief Attach OpExecutor to the graph attributes. * @@ -115,12 +127,20 @@ Graph AttachOpExecs(Graph g); * \brief Attach Resource to the OpExecVector of the graph. * * \param g input graph need to contain op_exec attribute. + */ +void AttachOpResources(const Graph& g); +/*! + * \brief Attach Resource to the OpExecVector * - * \return graph with new attribute "op_exec" of type OpExecVector - * The fields on the OpExecVector are not yet been setup. + * \param g input graph + * \param op_execs OpExecutor vector + * \param start_nid starting node id + * \param end_nid end node id */ -Graph AttachOpResources(Graph g); - +void AttachOpResources(const Graph& g, + const OpExecVector& op_execs, + size_t start_nid, + size_t end_nid); /*! * \brief Discover chance of inplace addto operators. * i.e. z = plus(z, source_op), and encourage it to become z += source_op. diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index e28867d5488e..831b5f900237 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -912,7 +912,7 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, } g = AttachOpExecs(g); - g = AttachOpResources(g); + AttachOpResources(g); graph_ = std::move(g); if (shared_exec != nullptr) { diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 140b5a5d81e0..b17fae4b3cf3 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -19,16 +19,78 @@ #include #include #include "./imperative_utils.h" +#include "./cached_op.h" +#include "../executor/exec_pass.h" +#include "../profiler/profiler.h" + namespace mxnet { DMLC_REGISTER_PARAMETER(CachedOpConfig); -Imperative::CachedOp::CachedOp( +struct CachedOp::GraphInfo { + nnvm::Graph fwd_graph; + nnvm::Graph full_graph; + std::vector bwd_output_reqs; + std::vector bwd_input_eid; +}; + +struct CachedOp::DynamicRuntime { + GraphInfo info; + std::vector buff; + std::vector op_states; +}; + +struct CachedOp::CachedOpState { + CachedOpState(const Context& context_, + const nnvm::Graph& fwd_graph_, + const nnvm::Graph& full_graph_) { + context = context_; + info.fwd_graph = fwd_graph_; + info.full_graph = full_graph_; + + size_t max_nodes = info.full_graph.indexed_graph().num_nodes(); + size_t max_entries = info.full_graph.indexed_graph().num_node_entries(); + info.fwd_graph.attrs["context"] = std::make_shared( + std::vector(info.fwd_graph.indexed_graph().num_nodes(), context)); + info.full_graph.attrs["context"] = std::make_shared( + std::vector(max_nodes, context)); + + buff.resize(max_entries); + arrays.resize(max_entries); + array_reqs.resize(max_entries); + dynamic_entries.resize(max_entries, false); + op_states.resize(max_nodes); + execs.resize(max_nodes); + opr_segs.resize(max_nodes); + } + + std::mutex mutex; + Context context; + GraphInfo info; + + bool recording = false; + bool fwd_alloc = false; + bool bwd_alloc = false; + bool fwd_exec_init = false; + bool bwd_exec_init = false; + + std::vector buff; + std::vector arrays; + std::vector array_reqs; + + std::vector op_states; + std::vector > execs; + std::vector opr_segs; + + std::vector dynamic_entries; + std::multimap fwd_reuse_pool; + std::multimap bwd_reuse_pool; +}; + +CachedOp::CachedOp( const nnvm::Symbol& sym, - const std::vector >& flags, - const std::vector arg_names, - const std::unordered_map >& params) { + const std::vector >& flags) { using namespace nnvm; using namespace imperative; static const std::vector zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; @@ -68,34 +130,22 @@ Imperative::CachedOp::CachedOp( fwd_graph_.attrs["forward_ref_count"] = std::make_shared(std::move(ref_count)); - inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit; + inlining_ = !config_.static_alloc && + (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit; } // Set params { const auto& idx = fwd_graph_.indexed_graph(); - std::unordered_map arg_name_to_id; - for (size_t i = 0; i < idx.input_nodes().size(); ++i) { - const auto& name = idx[idx.input_nodes()[i]].source->attrs.name; - auto iter = params.find(name); - if (iter == params.end()) { - arg_name_to_id[name] = i; - continue; - } - fwd_params_idx_.push_back(i); - for (const auto& param : iter->second) { - params_[param.ctx()].emplace_back(param); + if (config_.data_indices.ndim() || config_.param_indices.ndim()) { + CHECK_EQ(config_.data_indices.ndim() + config_.param_indices.ndim(), + idx.input_nodes().size()); + } else { + std::vector tmp; + for (size_t i = 0; i < idx.input_nodes().size(); ++i) { + tmp.push_back(i); } - } - - CHECK_EQ(arg_name_to_id.size(), arg_names.size()) - << "CachedOp expects " << arg_name_to_id.size() - << " inputs, given " << arg_names.size(); - - for (const auto& name : arg_names) { - auto iter = arg_name_to_id.find(name); - CHECK(iter != arg_name_to_id.end()) << "Unexpected input name " << name; - fwd_args_idx_.push_back(iter->second); + config_.data_indices.assign(tmp.begin(), tmp.end()); } } @@ -107,9 +157,14 @@ Imperative::CachedOp::CachedOp( } std::vector xs; - std::vector args = sym.ListInputs(Symbol::kReadOnlyArgs); - xs.reserve(args.size()); - for (const auto& i : args) xs.emplace_back(NodeEntry{i, 0, 0}); + const auto& idx = fwd_graph_.indexed_graph(); + for (size_t i = 0; i < idx.input_nodes().size(); ++i) { + auto nid = idx.input_nodes()[i]; + if (idx.mutable_input_nodes().count(nid)) continue; + fwd_input_to_grad_output_[i] = xs.size(); + xs.emplace_back(NodeEntry{idx[nid].weak_ref.lock(), 0, 0}); + } + CHECK_GT(xs.size(), 0) << "There are no inputs in computation graph that require gradients."; @@ -125,7 +180,7 @@ Imperative::CachedOp::CachedOp( size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries(); full_graph_.outputs = fwd_graph_.outputs; - curr_grad_req_ = std::vector(grad_graph_.outputs.size(), true); + bwd_output_reqs_ = std::vector(grad_graph_.outputs.size(), kWriteTo); for (const auto& i : grad_graph_.outputs) full_graph_.outputs.emplace_back(i); const auto& idx = full_graph_.indexed_graph(); @@ -169,7 +224,10 @@ Imperative::CachedOp::CachedOp( } } -std::vector Imperative::CachedOp::Gradient( +CachedOp::~CachedOp() { +} + +std::vector CachedOp::Gradient( const nnvm::NodePtr& node, const std::vector& ograds) { using namespace nnvm; @@ -206,13 +264,15 @@ std::vector Imperative::CachedOp::Gradient( return ret; } -nnvm::Graph Imperative::CachedOp::GetForwardGraph( - const bool recording, const std::vector& inputs) { + +bool CachedOp::SetForwardGraph( + GraphInfo* info, + const bool recording, + const std::vector& inputs) { using namespace nnvm; using namespace imperative; - std::lock_guard lock(mutex_); CHECK_EQ(inputs.size(), num_inputs()); - nnvm::Graph& g = fwd_graph_; + nnvm::Graph& g = info->fwd_graph; ShapeVector shape_inputs; DTypeVector dtype_inputs; @@ -237,18 +297,22 @@ nnvm::Graph Imperative::CachedOp::GetForwardGraph( g.attrs.erase("forward_mem_plan"); g.attrs.erase("full_mem_plan"); } else if (g.attrs.count(recording ? "full_mem_plan" : "forward_mem_plan")) { - return g; + return true; } const auto& idx = g.indexed_graph(); StorageVector storage(idx.num_node_entries(), exec::kBadStorageID); - for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID; const auto& stypes = g.GetAttr("storage_type"); CHECK_EQ(stypes.size(), storage.size()); for (size_t i = 0; i < stypes.size(); i++) { - if (stypes[i] != kDefaultStorage) - storage[i] = exec::kDynamicStorageID; + if (stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID; + } + for (const auto i : idx.input_nodes()) { + storage[idx.entry_id(i, 0)] = exec::kExternalStorageID; + } + for (size_t i = 0; i < idx.outputs().size(); ++i) { + storage[idx.entry_id(idx.outputs()[i])] = exec::kExternalStorageID; } auto mem_plan = PlanMemory( @@ -257,51 +321,50 @@ nnvm::Graph Imperative::CachedOp::GetForwardGraph( g.attrs[recording ? "full_mem_plan" : "forward_mem_plan"] = std::make_shared(std::move(mem_plan)); - return g; + return false; } -nnvm::Graph Imperative::CachedOp::GetBackwardGraph( - const OpStatePtr& op_state, +bool CachedOp::SetBackwardGraph( + GraphInfo* info, const std::vector& reqs, - const std::vector& inputs) { + const std::vector& inputs, + bool detect_inplace_addto) { using namespace nnvm; using namespace imperative; std::lock_guard lock(mutex_); - nnvm::Graph& g = full_graph_; - auto& state = op_state.get_state(); - bool req_match = true; - for (size_t i = 0; i < reqs.size(); ++i) { - if (curr_grad_req_[i] != (reqs[i] != kNullOp)) { - curr_grad_req_[i] = reqs[i] != kNullOp; - req_match = false; - } - } - if (!req_match) { + Context default_ctx = inputs[0]->ctx(); + nnvm::Graph& g = info->full_graph; + + if (info->bwd_output_reqs != reqs) { + info->bwd_output_reqs = reqs; + info->bwd_input_eid.clear(); g = nnvm::Graph(); g.outputs = fwd_graph_.outputs; for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) { - if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]); + if (info->bwd_output_reqs[i] == kNullOp) continue; + g.outputs.emplace_back(grad_graph_.outputs[i]); } - bwd_input_eid_.clear(); + g.attrs["context"] = std::make_shared( + std::vector(g.indexed_graph().num_nodes(), default_ctx)); } const auto& idx = g.indexed_graph(); - if (bwd_input_eid_.size() != inputs.size()) { - bwd_input_eid_.clear(); + if (info->bwd_input_eid.size() != inputs.size()) { + info->bwd_input_eid.clear(); for (const auto& i : bwd_ograd_dep_) { auto eid = idx.entry_id(ograd_entries_[i]); - bwd_input_eid_.push_back(eid); + info->bwd_input_eid.push_back(eid); } for (const auto& i : bwd_in_dep_) { auto eid = idx.entry_id(idx.input_nodes()[i], 0); - bwd_input_eid_.push_back(eid); + info->bwd_input_eid.push_back(eid); } for (const auto& i : bwd_out_dep_) { auto eid = idx.entry_id(idx.outputs()[i]); - bwd_input_eid_.push_back(eid); + info->bwd_input_eid.push_back(eid); } - CHECK_EQ(inputs.size(), bwd_input_eid_.size()); + CHECK_EQ(inputs.size(), info->bwd_input_eid.size()); } size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes(); @@ -312,25 +375,22 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph( for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) { for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)]; } - for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[bwd_input_eid_[i]]; + for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[info->bwd_input_eid[i]]; for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)]; g.attrs["backward_ref_count"] = std::make_shared(std::move(ref_count)); } - ShapeVector shapes(idx.num_node_entries(), TShape()); - DTypeVector dtypes(idx.num_node_entries(), -1); - StorageTypeVector stypes(idx.num_node_entries(), -1); - - for (size_t i = 0; i < num_forward_entries; ++i) { - shapes[i] = state.buff[i].shape(); - dtypes[i] = state.buff[i].dtype(); - stypes[i] = state.buff[i].storage_type(); - } + auto shapes = info->fwd_graph.GetAttr("shape"); + shapes.resize(idx.num_node_entries(), TShape()); + auto dtypes = info->fwd_graph.GetAttr("dtype"); + dtypes.resize(idx.num_node_entries(), -1); + auto stypes = info->fwd_graph.GetAttr("storage_type"); + stypes.resize(idx.num_node_entries(), -1); for (size_t i = 0; i < inputs.size(); ++i) { - shapes[bwd_input_eid_[i]] = inputs[i]->shape(); - dtypes[bwd_input_eid_[i]] = inputs[i]->dtype(); - stypes[bwd_input_eid_[i]] = inputs[i]->storage_type(); + shapes[info->bwd_input_eid[i]] = inputs[i]->shape(); + dtypes[info->bwd_input_eid[i]] = inputs[i]->dtype(); + stypes[info->bwd_input_eid[i]] = inputs[i]->storage_type(); } std::pair node_range, entry_range; @@ -342,79 +402,353 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph( node_range, entry_range); match &= CheckAndInferType(&g, std::move(dtypes), false, node_range, entry_range); - exec::DevMaskVector dev_mask(idx.num_nodes(), inputs[0]->ctx().dev_mask()); + exec::DevMaskVector dev_mask(idx.num_nodes(), default_ctx.dev_mask()); match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(stypes), false, node_range, entry_range); if (!match) { g.attrs.erase("backward_mem_plan"); } else if (g.attrs.count("backward_mem_plan")) { - return g; + return true; } StorageVector storage(idx.num_node_entries(), exec::kBadStorageID); + const auto& bwd_stypes = g.GetAttr("storage_type"); + for (size_t i = 0; i < bwd_stypes.size(); i++) { + if (bwd_stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID; + } for (size_t i = 0; i < num_forward_entries; ++i) storage[i] = exec::kExternalStorageID; for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID; for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = exec::kExternalStorageID; - for (size_t i = 0; i < stypes.size(); i++) { - if (stypes[i] != kDefaultStorage) - storage[i] = exec::kDynamicStorageID; - } auto mem_plan = PlanMemory( &g, std::move(storage), g.GetAttr >("backward_ref_count"), - {num_forward_nodes, idx.num_nodes()}, {num_forward_entries, idx.num_node_entries()}); + {num_forward_nodes, idx.num_nodes()}, + {num_forward_entries, idx.num_node_entries()}, + detect_inplace_addto); g.attrs["backward_mem_plan"] = std::make_shared(std::move(mem_plan)); - return g; + return false; } -void Imperative::CachedOp::Forward( - const std::shared_ptr& op_ptr, - const std::vector& args, - const std::vector& outputs) { +OpStatePtr CachedOp::GetCachedOpState( + const Context& ctx) { + std::lock_guard lock(mutex_); + for (const auto& i : cached_op_states_[ctx]) { + // only create one state per device when not using static memory + if (!config_.static_alloc || i.unique()) { + return i; + } + } + auto state_ptr = OpStatePtr::Create(ctx, fwd_graph_, full_graph_); + + cached_op_states_[ctx].push_back(state_ptr); + return state_ptr; +} + +void CachedOp::StaticAllocMemory( + const OpStatePtr& state_ptr, + bool recording, + bool keep_fwd) { using namespace nnvm; using namespace imperative; - static const auto cached_op = nnvm::Op::Get("_CachedOp"); - CHECK_EQ(args.size(), fwd_args_idx_.size()) - << "CachedOp requires " << fwd_args_idx_.size() - << " inputs but got " << args.size(); + auto& state = state_ptr.get_state(); + const auto& default_ctx = state.context; + nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph; + const auto& idx = g.indexed_graph(); + const auto& vstorage_inplace = g.GetAttr >("storage_inplace_index"); + const auto& mem_plan = g.GetAttr( + keep_fwd ? "backward_mem_plan" : (recording ? "full_mem_plan" : "forward_mem_plan")); + std::vector addto_entry; + if (g.attrs.count("addto_entry")) { + addto_entry = g.GetAttr >("addto_entry"); + } + size_t start_eid = + keep_fwd ? state.info.fwd_graph.indexed_graph().num_node_entries() : 0; + size_t end_eid = idx.num_node_entries(); + + if (!keep_fwd) state.fwd_alloc = false; + state.bwd_alloc = false; + for (size_t i = start_eid; i < state.buff.size(); ++i) { + state.buff[i] = NDArray(); + state.arrays[i] = &state.buff[i]; + state.array_reqs[i] = kNullOp; + state.dynamic_entries[i] = false; + } + + for (auto i : idx.input_nodes()) { + auto eid = idx.entry_id(i, 0); + if (eid >= start_eid) state.dynamic_entries[eid] = true; + } + for (auto i : idx.outputs()) { + auto eid = idx.entry_id(i); + if (eid >= start_eid) state.dynamic_entries[eid] = true; + } + + for (size_t i = start_eid; i < end_eid; ++i) { + if (addto_entry.size() && addto_entry[i]) { + state.array_reqs[i] = kAddTo; + } else if (vstorage_inplace[i] >= 0) { + state.array_reqs[i] = kWriteInplace; + } else if (vstorage_inplace[i] == -2) { + // -2 indicate that the entry is never referenced. + state.array_reqs[i] = kNullOp; + } else { + state.array_reqs[i] = kWriteTo; + } + } + + auto& reuse_pool = keep_fwd ? state.bwd_reuse_pool : state.fwd_reuse_pool; + reuse_pool = imperative::AllocateMemory( + g, idx, default_ctx, start_eid, end_eid, mem_plan, + state.arrays, &state.array_reqs, std::move(reuse_pool)); - Context default_ctx = args[0]->ctx(); + state.recording = recording; + if (keep_fwd) { + state.bwd_alloc = true; + } else { + state.fwd_alloc = true; + } +} +void CachedOp::StaticInitExec( + const OpStatePtr& state_ptr, + bool recording, + bool keep_fwd) { + using namespace nnvm; + using namespace imperative; - std::vector inputs(num_inputs()); - for (index_t i = 0; i < fwd_args_idx_.size(); ++i) { - inputs[fwd_args_idx_[i]] = args[i]; + auto& state = state_ptr.get_state(); + const auto& default_ctx = state.context; + nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph; + const auto& idx = g.indexed_graph(); + std::vector skip_plus_node; + if (g.attrs.count("skip_plus_node")) { + skip_plus_node = g.GetAttr >("skip_plus_node"); } - if (fwd_params_idx_.size()) { - CHECK(params_.find(default_ctx) != params_.end()) - << "CachedOp is not initialized on context " << default_ctx; + size_t start_nid = + keep_fwd ? state.info.fwd_graph.indexed_graph().num_nodes() : 0; + size_t end_nid = idx.num_nodes(); - for (size_t i = 0; i < fwd_params_idx_.size(); ++i) { - inputs[fwd_params_idx_[i]] = ¶ms_[default_ctx][i]; + if (!keep_fwd) state.fwd_exec_init = false; + state.bwd_exec_init = false; + + for (size_t i = start_nid; i < state.execs.size(); ++i) { + state.execs[i].reset(); + state.opr_segs[i] = EngineOprSeg(); + } + + if (!config_.static_shape) { + for (size_t i = start_nid; i < end_nid; ++i) { + state.opr_segs[i].next_nid = i + 1; + state.opr_segs[i].skip = skip_plus_node.size() && skip_plus_node[i]; + } + } else { + for (size_t i = start_nid; i < end_nid; ++i) { + exec::CreateOpExecs(g, &state.execs, i); + } + exec::AttachOpResources(g, state.execs, start_nid, end_nid); + + for (size_t i = start_nid; i < end_nid; ++i) { + bool skip = idx[i].source->is_variable(); + for (size_t j = 0; !skip && j < idx[i].inputs.size(); ++j) { + skip = state.dynamic_entries[idx.entry_id(idx[i].inputs[j])]; + } + for (size_t j = 0; !skip && j < idx[i].source->num_outputs(); ++j) { + skip = state.dynamic_entries[idx.entry_id(i, j)]; + } + if (skip) continue; + SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs); } + + size_t bulk_size = idx.num_nodes(); + std::unordered_set excludes; + if (recording || keep_fwd) { + bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size; + for (const auto& i : idx.outputs()) excludes.insert(idx.entry_id(i)); + for (const auto& i : idx.input_nodes()) excludes.insert(idx.entry_id(i, 0)); + } + + CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, excludes, + state.execs, skip_plus_node, &state.opr_segs); } - // Initialize + if (keep_fwd) { + state.bwd_exec_init = true; + } else { + state.fwd_exec_init = true; + } +} + +void CachedOp::StaticRunOps( + const Context& default_ctx, + const nnvm::Graph& g, + const OpStatePtr& state_ptr, + size_t start_nid, + size_t end_nid) { + static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); + static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); + + bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning; + bool is_training = Imperative::Get()->is_training(); + auto& state = state_ptr.get_state(); + const auto& idx = g.indexed_graph(); + const auto& dispatch_modes = g.GetAttr("dispatch_mode"); + const auto& op_execs = state.execs; + + std::vector ndinputs, ndoutputs; + nnvm::ShapeVector arg_shapes; + nnvm::DTypeVector arg_dtypes; + std::vector req; + + for (size_t i = start_nid; config_.static_shape && i < end_nid; ++i) { + if (op_execs[i]) op_execs[i]->op_ctx.is_train = is_training; + } + + for (size_t i = start_nid; i < end_nid; i = state.opr_segs[i].next_nid) { + const auto& opr_seg = state.opr_segs[i]; + if (opr_seg.skip) continue; + if (opr_seg.opr != nullptr) { + Engine::Get()->Push(opr_seg.opr.get(), default_ctx, 0, profiling); + } else { + const nnvm::IndexedGraph::Node& node = idx[i]; + if (node.source->is_variable()) continue; + auto num_outputs = node.source->num_outputs(); + ndinputs.clear(); + ndinputs.reserve(node.inputs.size()); + for (const auto& j : node.inputs) { + ndinputs.emplace_back(state.arrays[idx.entry_id(j)]); + CHECK(!ndinputs.back()->is_none()); + } + ndoutputs.clear(); + ndoutputs.reserve(num_outputs); + req.clear(); + req.reserve(num_outputs); + for (size_t j = 0; j < num_outputs; ++j) { + size_t eid = idx.entry_id(i, j); + ndoutputs.emplace_back(state.arrays[eid]); + req.push_back(state.array_reqs[eid]); + CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none()); + } + const DispatchMode dispatch_mode = dispatch_modes[i]; + if (createop.count(node.source->op())) { + arg_shapes.clear(); + arg_dtypes.clear(); + arg_shapes.reserve(ndinputs.size()); + arg_dtypes.reserve(ndinputs.size()); + for (size_t i = 0; i < ndinputs.size(); ++i) { + arg_shapes.emplace_back(ndinputs[i]->shape()); + arg_dtypes.emplace_back(ndinputs[i]->dtype()); + } + state.op_states[i] = createop[node.source->op()]( + node.source->attrs, default_ctx, arg_shapes, arg_dtypes); + Imperative::Get()->InvokeOp( + default_ctx, node.source->attrs, ndinputs, ndoutputs, req, + dispatch_mode, state.op_states[i]); + } else if (is_layer_backward.get(node.source->op(), false)) { + nnvm::Node* fwd_node = node.source->control_deps[0].get(); + auto fwd_node_id = idx.node_id(fwd_node); + Imperative::Get()->InvokeOp( + default_ctx, node.source->attrs, ndinputs, ndoutputs, + req, dispatch_mode, state.op_states[fwd_node_id]); + } else { + Imperative::Get()->InvokeOp( + default_ctx, node.source->attrs, ndinputs, ndoutputs, req, + dispatch_mode); + } + } + } +} + +OpStatePtr CachedOp::StaticForward( + const Context& default_ctx, + const std::vector& inputs, + const std::vector& outputs) { + using namespace nnvm; + using namespace imperative; + bool recording = Imperative::Get()->is_recording(); - nnvm::Graph g = GetForwardGraph(recording, inputs); + auto state_ptr = GetCachedOpState(default_ctx); + auto& state = state_ptr.get_state(); + std::lock_guard lock(state.mutex); + + bool match = SetForwardGraph(&state.info, recording, inputs); + match = match && state.recording == recording; + + nnvm::Graph& g = state.info.fwd_graph; const auto& idx = g.indexed_graph(); - size_t num_inputs = idx.input_nodes().size(); + if (!state.fwd_alloc || !match) { + StaticAllocMemory(state_ptr, recording, false); + } - for (size_t i = 0; i < inputs.size(); ++i) { - CHECK_EQ(inputs[i]->ctx(), default_ctx) - << "CachedOp requires all inputs to live on the same context. But " - << idx[idx.input_nodes()[0]].source->attrs.name << " is on " << default_ctx - << " while " << idx[idx.input_nodes()[i]].source->attrs.name << " is on " - << inputs[i]->ctx(); + if (config_.static_shape) { + for (auto i : config_.param_indices) { + auto nid = idx.input_nodes()[i]; + if (!state.arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) { + match = false; + auto ptr = &state.buff[idx.entry_id(nid, 0)]; + CHECK_EQ(state.arrays[idx.entry_id(nid, 0)], ptr); + *state.arrays[idx.entry_id(nid, 0)] = *inputs[i]; + state.dynamic_entries[idx.entry_id(nid, 0)] = false; + } + } + for (auto i : config_.data_indices) { + auto eid = idx.entry_id(idx.input_nodes()[i], 0); + state.arrays[eid] = inputs[i]; + } + } else { + for (size_t i = 0; i < num_inputs(); ++i) { + auto nid = idx.input_nodes()[i]; + state.arrays[idx.entry_id(nid, 0)] = inputs[i]; + } } - auto op_state_ptr = OpStatePtr::Create(); - auto& cached_op_state = op_state_ptr.get_state(); - auto& buff = cached_op_state.buff; - auto& states = cached_op_state.states; + if (!state.fwd_exec_init || !match) { + StaticInitExec(state_ptr, recording, false); + } + + const auto& dtypes = g.GetAttr("dtype"); + const auto& shapes = g.GetAttr("shape"); + const auto& stypes = g.GetAttr("storage_type"); + + for (size_t i = 0; i < outputs.size(); ++i) { + auto eid = idx.entry_id(idx.outputs()[i]); + state.arrays[eid] = outputs[i]; + if (!outputs[i]->is_none()) continue; + *outputs[i] = NDArray(static_cast(stypes[eid]), + shapes[eid], default_ctx, true, dtypes[eid]); + } + + StaticRunOps(default_ctx, g, state_ptr, 0, idx.num_nodes()); + + return recording ? state_ptr : OpStatePtr(); +} + + +OpStatePtr CachedOp::DynamicForward( + const Context& default_ctx, + const std::vector& inputs, + const std::vector& outputs) { + using namespace nnvm; + using namespace imperative; + + // Initialize + bool recording = Imperative::Get()->is_recording(); + auto op_state = OpStatePtr::Create(); + auto& runtime = op_state.get_state(); + { + auto state_ptr = GetCachedOpState(default_ctx); + auto& state = state_ptr.get_state(); + std::lock_guard lock(state.mutex); + SetForwardGraph(&state.info, recording, inputs); + runtime.info.fwd_graph = state.info.fwd_graph; + } + nnvm::Graph& g = runtime.info.fwd_graph; + const auto& idx = g.indexed_graph(); + size_t num_inputs = idx.input_nodes().size(); + auto& buff = runtime.buff; + auto& states = runtime.op_states; // Allocate entries states.resize(idx.num_nodes()); @@ -446,57 +780,98 @@ void Imperative::CachedOp::Forward( AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(), mem_plan, arrays, &array_reqs); + const auto& dtypes = g.GetAttr("dtype"); + const auto& shapes = g.GetAttr("shape"); + const auto& stypes = g.GetAttr("storage_type"); + + for (size_t i = 0; i < outputs.size(); ++i) { + auto eid = idx.entry_id(idx.outputs()[i]); + arrays[eid] = outputs[i]; + if (!outputs[i]->is_none()) continue; + *outputs[i] = NDArray(static_cast(stypes[eid]), + shapes[eid], default_ctx, true, dtypes[eid]); + } + const auto& dispatch_modes = g.GetAttr("dispatch_mode"); if (recording && !inlining_) Imperative::Get()->set_is_recording(false); - int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size); - Imperative::Get()->RunGraph( - false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), - std::move(ref_count), &states, dispatch_modes); + RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), + std::move(ref_count), &states, dispatch_modes); - Engine::Get()->set_bulk_size(prev_bulk_size); Imperative::Get()->set_is_recording(recording); - for (size_t i = 0; i < idx.num_node_entries(); ++i) { - if (arrays[i] == &buff[i]) continue; - buff[i].shape_ = arrays[i]->shape_; - buff[i].dtype_ = arrays[i]->dtype_; - buff[i].storage_type_ = arrays[i]->storage_type_; + return op_state; +} + +void CachedOp::Forward( + const std::shared_ptr& op_ptr, + const std::vector& inputs, + const std::vector& outputs) { + static const auto cached_op = nnvm::Op::Get("_CachedOp"); + + CHECK_EQ(inputs.size(), num_inputs()); + + Context default_ctx = inputs[0]->ctx(); + + const auto& idx = fwd_graph_.indexed_graph(); + for (size_t i = 0; i < inputs.size(); ++i) { + CHECK_EQ(inputs[i]->ctx(), default_ctx) + << "CachedOp requires all inputs to live on the same context. But " + << idx[idx.input_nodes()[0]].source->attrs.name + << " is on " << default_ctx << " while " + << idx[idx.input_nodes()[i]].source->attrs.name + << " is on " << inputs[i]->ctx(); } - if (recording && !inlining_) { + int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size); + + OpStatePtr op_state; + if (config_.static_alloc) { + op_state = StaticForward(default_ctx, inputs, outputs); + } else { + op_state = DynamicForward(default_ctx, inputs, outputs); + } + + Engine::Get()->set_bulk_size(prev_bulk_size); + + if (Imperative::Get()->is_recording() && !inlining_) { nnvm::NodeAttrs attrs; attrs.op = cached_op; attrs.name = "_cachedop"; attrs.parsed = op_ptr; Imperative::Get()->RecordOp( - std::move(attrs), inputs, outputs, op_state_ptr, + std::move(attrs), inputs, outputs, op_state, &save_inputs(), &save_outputs()); } } -void Imperative::CachedOp::Backward( +void CachedOp::DynamicBackward( const bool retain_graph, - const OpStatePtr& state, + const OpStatePtr& op_state, const std::vector& inputs, const std::vector& reqs, const std::vector& outputs) { using namespace nnvm; using namespace imperative; - CHECK(!Imperative::Get()->is_recording()) - << "CachedOp does not support higher order gradients. " - << "If you want to do backward with create_graph=True please " - << "do not use hybridize."; // Initialize - nnvm::Graph g = GetBackwardGraph(state, reqs, inputs); + Context default_ctx = outputs[0]->ctx(); + auto& runtime = op_state.get_state(); + { + auto state_ptr = GetCachedOpState(default_ctx); + auto& state = state_ptr.get_state(); + std::lock_guard lock(state.mutex); + state.info.fwd_graph = runtime.info.fwd_graph; + SetBackwardGraph(&state.info, reqs, inputs); + runtime.info.full_graph = state.info.full_graph; + runtime.info.bwd_input_eid = state.info.bwd_input_eid; + } + nnvm::Graph& g = runtime.info.full_graph; const auto& idx = g.indexed_graph(); - - auto& cached_op_state = state.get_state(); - auto& buff = cached_op_state.buff; - auto& states = cached_op_state.states; + auto& buff = runtime.buff; + auto& states = runtime.op_states; size_t num_forward_outputs = fwd_graph_.outputs.size(); size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes(); @@ -506,7 +881,7 @@ void Imperative::CachedOp::Backward( arrays.reserve(buff.size()); for (size_t i = 0; i < buff.size(); ++i) arrays.push_back(&buff[i]); for (size_t i = 0; i < inputs.size(); ++i) { - arrays[bwd_input_eid_[i]] = inputs[i]; + arrays[runtime.info.bwd_input_eid[i]] = inputs[i]; } for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) { if (reqs[i] == kNullOp) continue; @@ -530,20 +905,14 @@ void Imperative::CachedOp::Backward( if (ref_count[i] == 0) array_reqs[i] = kNullOp; } - Context default_ctx = outputs[0]->ctx(); const auto& mem_plan = g.GetAttr("backward_mem_plan"); AllocateMemory(g, idx, default_ctx, num_forward_entries, idx.num_node_entries(), mem_plan, arrays, &array_reqs); const auto& dispatch_modes = g.GetAttr("dispatch_mode"); - int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size); - - Imperative::Get()->RunGraph( - retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), - std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); - - Engine::Get()->set_bulk_size(prev_bulk_size); + RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), + std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); if (retain_graph) { buff.resize(num_forward_entries); @@ -553,6 +922,99 @@ void Imperative::CachedOp::Backward( } } +void CachedOp::StaticBackward( + const bool retain_graph, + const OpStatePtr& state_ptr, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs) { + using namespace nnvm; + using namespace imperative; + + Context default_ctx = outputs[0]->ctx(); + + auto& state = state_ptr.get_state(); + std::lock_guard lock(state.mutex); + + bool match = SetBackwardGraph(&state.info, reqs, inputs, true); + + nnvm::Graph& g = state.info.full_graph; + const auto& idx = g.indexed_graph(); + auto num_forward_nodes = state.info.fwd_graph.indexed_graph().num_nodes(); + + if (!state.bwd_alloc || !match) { + StaticAllocMemory(state_ptr, true, true); + } + + if (config_.static_shape) { + for (auto i : config_.param_indices) { + const auto iter = fwd_input_to_grad_output_.find(i); + if (iter == fwd_input_to_grad_output_.end()) continue; + auto entry = grad_graph_.outputs[iter->second]; + if (!idx.exist(entry.node.get())) continue; + auto eid = idx.entry_id(entry); + if (!state.arrays[eid]->IsSame(*outputs[iter->second]) || + !(state.array_reqs[eid] == reqs[iter->second])) { + match = false; + state.array_reqs[eid] = reqs[iter->second]; + *state.arrays[eid] = *outputs[iter->second]; + state.dynamic_entries[eid] = false; + } + } + for (auto i : config_.data_indices) { + const auto iter = fwd_input_to_grad_output_.find(i); + if (iter == fwd_input_to_grad_output_.end()) continue; + auto entry = grad_graph_.outputs[iter->second]; + if (!idx.exist(entry.node.get())) continue; + auto eid = idx.entry_id(entry); + state.array_reqs[eid] = reqs[iter->second]; + state.arrays[eid] = outputs[iter->second]; + } + } else { + for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) { + auto entry = grad_graph_.outputs[i]; + if (!idx.exist(entry.node.get())) continue; + auto eid = idx.entry_id(entry); + state.array_reqs[eid] = reqs[i]; + state.arrays[eid] = outputs[i]; + } + } + + if (!state.bwd_exec_init || !match) { + StaticInitExec(state_ptr, true, true); + } + + for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) { + auto eid = state.info.bwd_input_eid[i]; + if (state.dynamic_entries[eid]) state.arrays[eid] = inputs[i]; + } + + StaticRunOps(default_ctx, g, state_ptr, num_forward_nodes, idx.num_nodes()); +} + +void CachedOp::Backward( + const bool retain_graph, + const OpStatePtr& state, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs) { + using namespace imperative; + CHECK(!Imperative::Get()->is_recording()) + << "CachedOp does not support higher order gradients. " + << "If you want to do backward with create_graph=True please " + << "do not use hybridize."; + + int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size); + + if (config_.static_alloc) { + StaticBackward(retain_graph, state, inputs, reqs, outputs); + } else { + DynamicBackward(retain_graph, state, inputs, reqs, outputs); + } + + Engine::Get()->set_bulk_size(prev_bulk_size); +} + NNVM_REGISTER_OP(_CachedOp) .set_num_inputs([](const NodeAttrs& attrs) { diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h new file mode 100644 index 000000000000..60a40c5e4a52 --- /dev/null +++ b/src/imperative/cached_op.h @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_IMPERATIVE_CACHED_OP_H_ +#define MXNET_IMPERATIVE_CACHED_OP_H_ + +#include +#include +#include +#include +#include +#include + +namespace mxnet { +/*! \brief CachedOp Parameters */ +struct CachedOpConfig : public dmlc::Parameter { + uint32_t inline_limit; + uint32_t forward_bulk_size; + uint32_t backward_bulk_size; + bool static_alloc; + bool static_shape; + nnvm::Tuple data_indices; + nnvm::Tuple param_indices; + DMLC_DECLARE_PARAMETER(CachedOpConfig) { + DMLC_DECLARE_FIELD(static_alloc) + .set_default(false) + .describe("Statically allocate memory to improve speed. " + "Memory usage may increase."); + DMLC_DECLARE_FIELD(static_shape) + .set_default(false) + .describe("Optimize for invariant input shapes between iterations. " + "Must also set static_alloc to True. " + "Change of input shapes is still allowed but slower."); + DMLC_DECLARE_FIELD(inline_limit) + .set_default(2) + .describe("Maximum number of operators that can be inlined."); + DMLC_DECLARE_FIELD(forward_bulk_size) + .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)) + .describe("Segment size of bulk execution during forward pass."); + DMLC_DECLARE_FIELD(backward_bulk_size) + .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)) + .describe("Segment size of bulk execution during backward pass."); + DMLC_DECLARE_FIELD(data_indices) + .set_default(nnvm::Tuple()) + .describe("Position of argument variables."); + DMLC_DECLARE_FIELD(param_indices) + .set_default(nnvm::Tuple()) + .describe("Position of parameters."); + } +}; + +class CachedOp { + public: + CachedOp( + const nnvm::Symbol& sym, + const std::vector >& flags); + ~CachedOp(); + uint32_t num_inputs() { + return fwd_graph_.indexed_graph().input_nodes().size(); + } + uint32_t num_outputs() { + return fwd_graph_.outputs.size(); + } + uint32_t num_backward_inputs() { + return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size(); + } + std::vector& save_inputs() { + return save_inputs_; + } + std::vector& save_outputs() { + return save_outputs_; + } + const std::unordered_set& mutable_input_nodes() { + return fwd_graph_.indexed_graph().mutable_input_nodes(); + } + std::vector Gradient( + const nnvm::NodePtr& node, + const std::vector& ograds); + void Forward( + const std::shared_ptr& op_ptr, + const std::vector& inputs, + const std::vector& outputs); + void Backward( + const bool retain_graph, + const OpStatePtr& state, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs); + + private: + struct GraphInfo; + struct DynamicRuntime; + struct CachedOpState; + + OpStatePtr GetCachedOpState(const Context& ctx); + bool SetForwardGraph( + GraphInfo* info, + const bool recording, + const std::vector& inputs); + bool SetBackwardGraph( + GraphInfo* info, + const std::vector& reqs, + const std::vector& inputs, + bool detect_inplace_addto = false); + OpStatePtr DynamicForward( + const Context& default_ctx, + const std::vector& inputs, + const std::vector& outputs); + void DynamicBackward( + const bool retain_graph, + const OpStatePtr& op_state, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs); + void StaticAllocMemory( + const OpStatePtr& state_ptr, + bool recording, + bool keep_fwd); + void StaticInitExec( + const OpStatePtr& state_ptr, + bool recording, + bool keep_fwd); + void StaticRunOps( + const Context& default_ctx, + const nnvm::Graph& g, + const OpStatePtr& state_ptr, + size_t start_nid, + size_t end_nid); + OpStatePtr StaticForward( + const Context& default_ctx, + const std::vector& inputs, + const std::vector& outputs); + void StaticBackward( + const bool retain_graph, + const OpStatePtr& state_ptr, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs); + + CachedOpConfig config_; + nnvm::Graph fwd_graph_; + nnvm::Graph grad_graph_; + nnvm::Graph full_graph_; + bool inlining_; + std::vector ograd_entries_; + std::vector bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_; + std::unordered_map fwd_input_to_grad_output_; + std::vector save_inputs_, save_outputs_; + std::vector bwd_output_reqs_; + + std::mutex mutex_; + std::unordered_map > cached_op_states_; +}; + +using CachedOpPtr = std::shared_ptr; + +} // namespace mxnet +#endif // MXNET_IMPERATIVE_CACHED_OP_H_ diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index 7caf305eac75..e1654259a2fb 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -19,6 +19,7 @@ #include #include #include "./imperative_utils.h" +#include "./cached_op.h" namespace mxnet { #if DMLC_CXX11_THREAD_LOCAL @@ -266,95 +267,6 @@ void Imperative::RecordOp( } } -void Imperative::RunGraph( - const bool retain_graph, - const nnvm::IndexedGraph& idx, - const std::vector arrays, - size_t node_start, size_t node_end, - std::vector&& array_reqs, - std::vector&& ref_count, - std::vector *p_states, - const DispatchModeVector &dispatch_modes) { - using namespace nnvm; - using namespace imperative; - static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); - static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); - static const auto bwd_cached_op = Op::Get("_backward_CachedOp"); - - std::vector& states = *p_states; - bool recording = is_recording(); - - std::vector ndinputs, ndoutputs; - ShapeVector arg_shapes; - DTypeVector arg_dtypes; - std::vector req; - - for (size_t i = node_start; i < node_end; ++i) { - const nnvm::IndexedGraph::Node& node = idx[i]; - if (node.source->op() == nullptr) continue; - auto num_outputs = node.source->num_outputs(); - ndinputs.clear(); - ndinputs.reserve(node.inputs.size()); - for (const auto& j : node.inputs) { - ndinputs.emplace_back(arrays[idx.entry_id(j)]); - CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index; - } - ndoutputs.clear(); - ndoutputs.reserve(num_outputs); - req.clear(); - req.reserve(num_outputs); - for (size_t j = 0; j < num_outputs; ++j) { - size_t eid = idx.entry_id(i, j); - ndoutputs.emplace_back(arrays[eid]); - req.push_back(array_reqs[eid]); - CHECK(!ndoutputs.back()->is_none()); - } - const Context& ctx = ndoutputs[0]->ctx(); - const DispatchMode dispatch_mode = dispatch_modes[i]; - if (node.source->op() == bwd_cached_op) { - const auto& cached_op = dmlc::get(node.source->attrs.parsed); - nnvm::Node* fwd_node = node.source->control_deps[0].get(); - auto fwd_node_id = idx.node_id(fwd_node); - cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs); - } else if (createop.count(node.source->op())) { - arg_shapes.clear(); - arg_dtypes.clear(); - arg_shapes.reserve(ndinputs.size()); - arg_dtypes.reserve(ndinputs.size()); - for (size_t i = 0; i < ndinputs.size(); ++i) { - arg_shapes.emplace_back(ndinputs[i]->shape()); - arg_dtypes.emplace_back(ndinputs[i]->dtype()); - } - states[i] = createop[node.source->op()]( - node.source->attrs, ctx, arg_shapes, arg_dtypes); - InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]); - if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]); - } else if (is_layer_backward.get(node.source->op(), false)) { - nnvm::Node* fwd_node = node.source->control_deps[0].get(); - auto fwd_node_id = idx.node_id(fwd_node); - InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, - req, dispatch_mode, states[fwd_node_id]); - if (recording) { - RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]); - } - } else { - InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode); - if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs); - } - - for (const auto& j : node.inputs) { - size_t eid = idx.entry_id(j); - --ref_count[eid]; - if (ref_count[eid] == 0) arrays[eid]->ptr_.reset(); - } - for (size_t j = 0; j < ndoutputs.size(); ++j) { - size_t eid = idx.entry_id(i, j); - if (ref_count[eid] == 0) arrays[eid]->ptr_.reset(); - } - } -} - - std::vector Imperative::Backward( const std::vector& outputs, const std::vector& ograds, diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc new file mode 100644 index 000000000000..464aefc220de --- /dev/null +++ b/src/imperative/imperative_utils.cc @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./imperative_utils.h" +#include "./cached_op.h" + +namespace mxnet { +namespace imperative { +void RunGraph( + const bool retain_graph, + const nnvm::IndexedGraph& idx, + const std::vector arrays, + size_t node_start, size_t node_end, + std::vector&& array_reqs, + std::vector&& ref_count, + std::vector *p_states, + const DispatchModeVector &dispatch_modes) { + using namespace nnvm; + using namespace imperative; + static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); + static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); + static const auto bwd_cached_op = Op::Get("_backward_CachedOp"); + + const auto imp = Imperative::Get(); + + std::vector& states = *p_states; + bool recording = imp->is_recording(); + + std::vector ndinputs, ndoutputs; + ShapeVector arg_shapes; + DTypeVector arg_dtypes; + std::vector req; + + for (size_t i = node_start; i < node_end; ++i) { + const nnvm::IndexedGraph::Node& node = idx[i]; + if (node.source->op() == nullptr) continue; + auto num_outputs = node.source->num_outputs(); + ndinputs.clear(); + ndinputs.reserve(node.inputs.size()); + for (const auto& j : node.inputs) { + ndinputs.emplace_back(arrays[idx.entry_id(j)]); + CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index; + } + ndoutputs.clear(); + ndoutputs.reserve(num_outputs); + req.clear(); + req.reserve(num_outputs); + for (size_t j = 0; j < num_outputs; ++j) { + size_t eid = idx.entry_id(i, j); + ndoutputs.emplace_back(arrays[eid]); + req.push_back(array_reqs[eid]); + CHECK(array_reqs[eid] == kNullOp || !ndoutputs.back()->is_none()); + } + const Context& ctx = ndoutputs[0]->ctx(); + const DispatchMode dispatch_mode = dispatch_modes[i]; + if (node.source->op() == bwd_cached_op) { + const auto& cached_op = dmlc::get(node.source->attrs.parsed); + nnvm::Node* fwd_node = node.source->control_deps[0].get(); + auto fwd_node_id = idx.node_id(fwd_node); + cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs); + } else if (createop.count(node.source->op())) { + arg_shapes.clear(); + arg_dtypes.clear(); + arg_shapes.reserve(ndinputs.size()); + arg_dtypes.reserve(ndinputs.size()); + for (size_t i = 0; i < ndinputs.size(); ++i) { + arg_shapes.emplace_back(ndinputs[i]->shape()); + arg_dtypes.emplace_back(ndinputs[i]->dtype()); + } + states[i] = createop[node.source->op()]( + node.source->attrs, ctx, arg_shapes, arg_dtypes); + imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]); + if (recording) { + imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]); + } + } else if (is_layer_backward.get(node.source->op(), false)) { + nnvm::Node* fwd_node = node.source->control_deps[0].get(); + auto fwd_node_id = idx.node_id(fwd_node); + imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, + req, dispatch_mode, states[fwd_node_id]); + if (recording) { + imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]); + } + } else { + imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode); + if (recording) { + imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs); + } + } + + for (const auto& j : node.inputs) { + size_t eid = idx.entry_id(j); + --ref_count[eid]; + if (ref_count[eid] == 0) *arrays[eid] = NDArray(); + } + for (size_t j = 0; j < ndoutputs.size(); ++j) { + size_t eid = idx.entry_id(i, j); + if (ref_count[eid] == 0) *arrays[eid] = NDArray(); + } + } +} + +} // namespace imperative +} // namespace mxnet diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 06b7e058dd14..726531d02994 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "../executor/graph_executor.h" #include "../executor/exec_pass.h" @@ -38,11 +39,24 @@ namespace mxnet { namespace imperative { struct MemoryPlanInfo { - uint32_t sid; + int storage_id; + uint32_t root; size_t size; bool inplace; }; +struct EngineOprDeleter { + void operator()(engine::Opr* handle) { + Engine::Get()->DeleteOperator(handle); + } +}; + +struct EngineOprSeg { + bool skip; + size_t next_nid; + std::unique_ptr opr; +}; + using MemoryPlanVector = std::vector; inline Context GetContext(const nnvm::NodeAttrs& attrs, @@ -715,10 +729,12 @@ inline std::vector PlaceDevice(const nnvm::IndexedGraph& idx) { inline MemoryPlanVector PlanMemory( - nnvm::Graph* p_g, nnvm::StorageVector&& storage, + nnvm::Graph* p_g, + nnvm::StorageVector&& storage, const std::vector& ref_count, const std::pair& node_range = {0, 0}, - const std::pair& entry_range = {0, 0}) { + const std::pair& entry_range = {0, 0}, + bool detect_inplace_addto = false) { using namespace nnvm; nnvm::Graph& g = *p_g; const auto& idx = g.indexed_graph(); @@ -728,31 +744,31 @@ inline MemoryPlanVector PlanMemory( g.attrs["ref_count"] = std::make_shared(ref_count); g.attrs["storage"] = std::make_shared(std::move(storage)); g = nnvm::ApplyPass(g, "PlanMemory"); + if (detect_inplace_addto) g = exec::DetectInplaceAddTo(g); const auto& dtypes = g.GetAttr("dtype"); const auto& shapes = g.GetAttr("shape"); - const auto& stypes = g.GetAttr("storage_type"); - auto storage_ids = g.MoveCopyAttr("storage_id"); - auto storage_inplace = g.MoveCopyAttr >("storage_inplace_index"); + const auto& storage_inplace = g.GetAttr >("storage_inplace_index"); + const auto& storage_ids = g.GetAttr("storage_id"); uint32_t entry_start = entry_range.first; uint32_t entry_end = entry_range.second > entry_start ? entry_range.second : idx.num_node_entries(); MemoryPlanVector mem_plan(idx.num_node_entries()); - std::unordered_map sid_to_loc; + std::unordered_map sid_to_root; for (uint32_t i = entry_start; i < entry_end; ++i) { - if (stypes[i] != kDefaultStorage) continue; if (storage_ids[i] < 0) { - mem_plan[i] = {i, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), false}; - } else if (!sid_to_loc.count(storage_ids[i])) { + mem_plan[i] = {storage_ids[i], i, 0, false}; + } else if (!sid_to_root.count(storage_ids[i])) { CHECK_LT(storage_inplace[i], 0); - sid_to_loc[storage_ids[i]] = i; - mem_plan[i].sid = i; - mem_plan[i].size = mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(); + sid_to_root[storage_ids[i]] = i; + mem_plan[i] = {storage_ids[i], i, + mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), + false}; } else { - uint32_t loc = sid_to_loc[storage_ids[i]]; - mem_plan[i] = {loc, 0, storage_inplace[i] >= 0}; - mem_plan[loc].size = std::max(mem_plan[loc].size, + uint32_t root = sid_to_root[storage_ids[i]]; + mem_plan[i] = {storage_ids[i], root, 0, storage_inplace[i] >= 0}; + mem_plan[root].size = std::max(mem_plan[root].size, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size()); } } @@ -761,39 +777,213 @@ inline MemoryPlanVector PlanMemory( } -inline void AllocateMemory(const nnvm::Graph& g, - const nnvm::IndexedGraph& idx, - const Context& default_ctx, - const uint32_t entry_start, const uint32_t entry_end, - const MemoryPlanVector& mem_plan, - const std::vector& arrays, - std::vector *array_reqs) { +inline std::multimap AllocateMemory( + const nnvm::Graph& g, + const nnvm::IndexedGraph& idx, + const Context& default_ctx, + const uint32_t entry_start, const uint32_t entry_end, + const MemoryPlanVector& mem_plan, + const std::vector& arrays, + std::vector *array_reqs, + std::multimap&& pool = std::multimap()) { using namespace nnvm; const auto& dtypes = g.GetAttr("dtype"); const auto& shapes = g.GetAttr("shape"); const auto& stypes = g.GetAttr("storage_type"); + std::multimap new_pool; + for (uint32_t i = entry_start; i < entry_end; ++i) { - if (!arrays[i]->is_none()) continue; - if (stypes[i] == kDefaultStorage) { - if (mem_plan[i].sid == i) { - CHECK_GT(mem_plan[i].size, 0); + if (mem_plan[i].storage_id == exec::kExternalStorageID) continue; + CHECK(arrays[i]->is_none()); + if (mem_plan[i].storage_id == exec::kDynamicStorageID) { + *arrays[i] = NDArray(static_cast(stypes[i]), + shapes[i], default_ctx, true, dtypes[i]); + continue; + } + CHECK_EQ(stypes[i], kDefaultStorage); + if (mem_plan[i].root == i) { + CHECK_GT(mem_plan[i].size, 0); + auto iter = pool.lower_bound(mem_plan[i].size); + if (iter != pool.end()) { + *arrays[i] = iter->second.AsArray(shapes[i], dtypes[i]); + new_pool.insert(*iter); + pool.erase(iter); + } else { NDArray buff(TShape({static_cast(mem_plan[i].size)}), default_ctx, true, mshadow::kUint8); *arrays[i] = buff.AsArray(shapes[i], dtypes[i]); + new_pool.insert({mem_plan[i].size, buff}); + } + } else { + CHECK_GE(mem_plan[mem_plan[i].root].storage_id, 0); + *arrays[i] = arrays[mem_plan[i].root]->AsArray(shapes[i], dtypes[i]); + if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) { + array_reqs->at(i) = kWriteInplace; + } + } + } + + return new_pool; +} + +inline void SetupOpExec( + const nnvm::Graph& g, + size_t nid, + const std::shared_ptr& exec, + const std::vector arrays, + const std::vector array_reqs) { + const auto& idx = g.indexed_graph(); + const auto& inode = idx[nid]; + CHECK_EQ(exec->in_array.size(), 0U); + CHECK_EQ(exec->out_array.size(), 0U); + for (const auto& e : inode.inputs) { + CHECK(!arrays[idx.entry_id(e)]->is_none()) << inode.source->attrs.name; + exec->in_array.push_back(*arrays[idx.entry_id(e)]); + } + for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { + uint32_t eid = idx.entry_id(nid, index); + CHECK(!arrays[eid]->is_none()) << inode.source->attrs.name; + exec->out_array.push_back(*arrays[eid]); + exec->req.push_back(array_reqs[eid]); + } + + exec->Setup(); +} + +inline Engine::OprHandle CreateEngineOp( + const Context& default_ctx, + const std::vector >& execs) { + CHECK_GT(execs.size(), 0); + std::vector use_vars, mutate_vars; + + for (const auto& exec : execs) { + CHECK_GT(exec->out_array.size(), 0); + CHECK(execs.size() == 1 || exec->exec_type() == ExecType::kSync); + + // the variables + for (const auto& nd : exec->in_array) { + use_vars.push_back(nd.var()); + } + for (auto& r : exec->op_ctx.requested) { + mutate_vars.push_back(r.var); + } + for (auto& nd : exec->out_array) { + mutate_vars.push_back(nd.var()); + } + if (exec->var() != nullptr) { + mutate_vars.push_back(exec->var()); + } + } + + // dedup vars + Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars); + bool is_gpu = default_ctx.dev_mask() == gpu::kDevMask; + bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync; + + auto exec_fun = [execs, is_async, is_gpu] ( + RunContext ctx, Engine::CallbackOnComplete on_complete) { + if (is_async) { + execs[0]->op_ctx.async_on_complete = on_complete; + } + for (const auto& exec : execs) exec->Run(ctx, is_gpu); + // call on complete only if it is async op + if (!is_async) { + if (is_gpu) { + #if MXNET_USE_CUDA + // Wait GPU kernel to finish. + ctx.get_stream()->Wait(); + #else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; + #endif + } + on_complete(); + } + }; + + return Engine::Get()->NewOperator( + exec_fun, use_vars, mutate_vars, FnProperty::kNormal); +} + +inline void CreateEngineOpSeg( + const nnvm::IndexedGraph& idx, + const Context default_ctx, + const size_t start_nid, + const size_t end_nid, + const size_t bulk_size, + const std::unordered_set& excludes, + const std::vector >& execs, + const std::vector skip_plus_node, + std::vector *opr_segs) { + size_t seg_start = start_nid; + std::vector > seg_execs; + for (size_t nid = start_nid; nid < end_nid; ++nid) { + const auto& node = idx[nid]; + if (node.source->is_variable()) continue; + if (skip_plus_node.size() && skip_plus_node[nid]) continue; + auto& exec = execs[nid]; + bool is_async = exec->exec_type() != ExecType::kSync; + bool valid = exec->out_array.size() > 0; + + // Stop at async nodes and invalid node (due to input/output is not allocated) + bool stop = is_async || !valid || seg_execs.size() >= bulk_size; + for (size_t i = 0; i < node.inputs.size() && !stop; ++i) { + if (excludes.count(idx.entry_id(node.inputs[i]))) stop = true; + } + auto num_outputs = node.source->num_outputs(); + for (size_t i = 0; i < num_outputs && !stop; ++i) { + if (excludes.count(idx.entry_id(nid, i))) stop = true; + } + + // Create opr segment for previous nodes. + if (stop && nid > seg_start) { + auto& seg = (*opr_segs)[seg_start]; + if (seg_execs.size()) { + seg = EngineOprSeg{false, nid}; + seg.opr.reset(CreateEngineOp(default_ctx, seg_execs)); } else { - *arrays[i] = arrays[mem_plan[i].sid]->AsArray(shapes[i], dtypes[i]); - if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) { - array_reqs->at(i) = kWriteInplace; - } + seg = EngineOprSeg{true, nid, nullptr}; } + seg_start = nid; + seg_execs.clear(); + } + + seg_execs.push_back(exec); + + auto& seg = (*opr_segs)[nid]; + if (is_async) { + seg = EngineOprSeg{false, nid + 1}; + seg.opr.reset(CreateEngineOp(default_ctx, seg_execs)); + seg_execs.clear(); + seg_start = nid + 1; + } else if (!valid) { + seg = EngineOprSeg{false, nid + 1, nullptr}; + seg_execs.clear(); + seg_start = nid + 1; + } + } + // The last segment + if (end_nid > seg_start) { + auto& seg = (*opr_segs)[seg_start]; + if (seg_execs.size()) { + seg = EngineOprSeg{false, end_nid}; + seg.opr.reset(CreateEngineOp(default_ctx, seg_execs)); } else { - *arrays[i] = NDArray(static_cast(stypes[i]), - shapes[i], default_ctx, true, dtypes[i]); + seg = EngineOprSeg{true, end_nid, nullptr}; } } } + +void RunGraph(const bool retain_graph, + const nnvm::IndexedGraph& idx, + const std::vector arrays, + size_t node_start, size_t node_end, + std::vector&& array_reqs, + std::vector&& ref_count, + std::vector *p_states, + const DispatchModeVector &dispatch_modes); + } // namespace imperative } // namespace mxnet diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 451fde2eb867..5701a5df5a08 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -22,6 +22,7 @@ from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID from common import setup_module, with_seed, assertRaises, teardown import numpy as np +from numpy.testing import assert_array_equal from nose.tools import raises, assert_raises from copy import deepcopy import warnings @@ -1124,7 +1125,6 @@ def test_hybrid_multi_context(): net.hybridize() net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy() - @with_seed() def test_zero_grad(): data = mx.nd.random.uniform(shape=(3,3)) @@ -1137,6 +1137,60 @@ def test_zero_grad(): grad = net.collect_params()['test_zero_grad_weight'].grad() assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0) +def check_hybrid_static_memory(**kwargs): + x = mx.nd.random.uniform(shape=(2, 3, 32, 32)) + x.attach_grad() + + net1 = gluon.model_zoo.vision.get_resnet( + 1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context()) + net2 = gluon.model_zoo.vision.get_resnet( + 1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context()) + net2.hybridize(**kwargs) + net1(x) + net2(x) + + def test(net, x): + with mx.autograd.record(): + y = net(x) + net(x) + y.backward() + + grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'} + + return y, grads + + y1, grads1 = test(net1, x) + y2, grads2 = test(net2, x) + + assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5) + for key in grads1: + assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5) + +def test_hybrid_static_memory(): + check_hybrid_static_memory() + check_hybrid_static_memory(static_alloc=True) + check_hybrid_static_memory(static_alloc=True, static_shape=True) + +def check_hybrid_static_memory_switching(**kwargs): + net = gluon.model_zoo.vision.get_resnet( + 1, 18, pretrained=True, ctx=mx.context.current_context()) + net.hybridize(**kwargs) + + x = mx.nd.random.uniform(shape=(4, 3, 32, 32)) + net(x) + with mx.autograd.record(): + y = net(x) + y.backward() + x = mx.nd.random.uniform(shape=(2, 3, 32, 32)) + net(x) + with mx.autograd.record(): + y = net(x) + y.backward() + mx.nd.waitall() + +def test_hybrid_static_memory_switching(): + check_hybrid_static_memory_switching() + check_hybrid_static_memory_switching(static_alloc=True) + check_hybrid_static_memory_switching(static_alloc=True, static_shape=True) @with_seed() def test_hook(): @@ -1231,6 +1285,17 @@ def test_legacy_save_params(): model.load_params('test.params', ctx=mx.cpu()) +def test_hybrid_static_memory_recording(): + net = gluon.model_zoo.vision.get_resnet( + 1, 18, pretrained=True, ctx=mx.context.current_context()) + net.hybridize(static_alloc=True) + + x = mx.nd.random.uniform(shape=(1, 3, 32, 32)) + with mx.autograd.record(True): + net(x) + net(x) + + if __name__ == '__main__': import nose nose.runmodule()