Skip to content

Commit b604453

Browse files
marcoabreu and zheng-da
authored and committed
Revert "Static alloc for hybridblock (apache#11313)" (apache#11318)
This reverts commit 5431e12.
1 parent 68489b6 commit b604453

18 files changed

+523
-1399
lines changed

include/mxnet/c_api.h

+5
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,11 @@ MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
987987
int num_flags,
988988
const char** keys,
989989
const char** vals,
990+
int num_inputs,
991+
const char** input_names,
992+
int num_params,
993+
const char** param_names,
994+
NDArrayHandle* params,
990995
CachedOpHandle *out);
991996
/*!
992997
* \brief free cached operator

include/mxnet/imperative.h

+89
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,23 @@
3535
#include "./ndarray.h"
3636

3737
namespace mxnet {
38+
/*! \brief CachedOp Parameters */
39+
struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
40+
uint32_t inline_limit;
41+
uint32_t forward_bulk_size;
42+
uint32_t backward_bulk_size;
43+
DMLC_DECLARE_PARAMETER(CachedOpConfig) {
44+
DMLC_DECLARE_FIELD(inline_limit)
45+
.set_default(2)
46+
.describe("Maximum number of operators that can be inlined.");
47+
DMLC_DECLARE_FIELD(forward_bulk_size)
48+
.set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
49+
.describe("Segment size of bulk execution during forward pass.");
50+
DMLC_DECLARE_FIELD(backward_bulk_size)
51+
.set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
52+
.describe("Segment size of bulk execution during backward pass.");
53+
}
54+
};
3855
/*! \brief runtime functions for NDArray */
3956
class Imperative {
4057
public:
@@ -77,6 +94,67 @@ class Imperative {
7794
&& info.out_grads.size() == 1;
7895
}
7996
};
97+
class CachedOp {
98+
public:
99+
CachedOp(
100+
const nnvm::Symbol& sym,
101+
const std::vector<std::pair<std::string, std::string> >& flags,
102+
const std::vector<std::string> arg_names,
103+
const std::unordered_map<std::string, std::vector<NDArray> >& params);
104+
uint32_t num_inputs() {
105+
return fwd_graph_.indexed_graph().input_nodes().size();
106+
}
107+
uint32_t num_outputs() {
108+
return fwd_graph_.outputs.size();
109+
}
110+
uint32_t num_backward_inputs() {
111+
return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
112+
}
113+
std::vector<bool>& save_inputs() {
114+
return save_inputs_;
115+
}
116+
std::vector<bool>& save_outputs() {
117+
return save_outputs_;
118+
}
119+
const std::unordered_set<uint32_t>& mutable_input_nodes() {
120+
return fwd_graph_.indexed_graph().mutable_input_nodes();
121+
}
122+
nnvm::Graph GetForwardGraph(const bool recording,
123+
const std::vector<NDArray*>& inputs);
124+
nnvm::Graph GetBackwardGraph(const OpStatePtr& state,
125+
const std::vector<OpReqType>& reqs,
126+
const std::vector<NDArray*>& inputs);
127+
std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
128+
const std::vector<nnvm::NodeEntry>& ograds);
129+
void Forward(const std::shared_ptr<CachedOp>& op_ptr,
130+
const std::vector<NDArray*>& args,
131+
const std::vector<NDArray*>& outputs);
132+
void Backward(const bool retain_graph,
133+
const OpStatePtr& state,
134+
const std::vector<NDArray*>& inputs,
135+
const std::vector<OpReqType>& reqs,
136+
const std::vector<NDArray*>& outputs);
137+
138+
private:
139+
struct CachedOpState {
140+
std::vector<NDArray> buff;
141+
std::vector<OpStatePtr> states;
142+
};
143+
std::mutex mutex_;
144+
CachedOpConfig config_;
145+
nnvm::Graph fwd_graph_;
146+
nnvm::Graph grad_graph_;
147+
nnvm::Graph full_graph_;
148+
std::unordered_map<Context, std::vector<NDArray> > params_;
149+
bool inlining_;
150+
std::vector<nnvm::NodeEntry> ograd_entries_;
151+
std::vector<bool> curr_grad_req_;
152+
std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
153+
std::vector<uint32_t> fwd_args_idx_;
154+
std::vector<uint32_t> fwd_params_idx_;
155+
std::vector<uint32_t> bwd_input_eid_;
156+
std::vector<bool> save_inputs_, save_outputs_;
157+
};
80158
/*! \brief whether operator recording is on. */
81159
bool is_training() const {
82160
return is_train_;
@@ -144,6 +222,15 @@ class Imperative {
144222
uint32_t num_inputs, uint32_t num_outputs,
145223
std::vector<bool> *p_save_inputs,
146224
std::vector<bool> *p_save_outputs);
225+
void RunGraph(
226+
const bool retain_graph,
227+
const nnvm::IndexedGraph& idx,
228+
const std::vector<NDArray*> arrays,
229+
size_t node_start, size_t node_end,
230+
std::vector<OpReqType>&& array_reqs,
231+
std::vector<uint32_t>&& ref_count,
232+
std::vector<OpStatePtr> *p_states,
233+
const DispatchModeVector& dispatch_modes);
147234
/*! \brief indicate whether is training. */
148235
#if DMLC_CXX11_THREAD_LOCAL
149236
static thread_local bool is_train_;
@@ -160,5 +247,7 @@ class Imperative {
160247
int backward_bulk_size_{0};
161248
};
162249

250+
using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;
251+
163252
} // namespace mxnet
164253
#endif // MXNET_IMPERATIVE_H_

include/mxnet/ndarray.h

-8
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,6 @@ class NDArray {
155155
return byte_offset_ > 0 || shape() != ptr_->storage_shape;
156156
}
157157

158-
/* \brief Check whether the two arrays are the same array */
159-
inline bool IsSame(const NDArray& other) {
160-
return ptr_ == other.ptr_ &&
161-
shape_ == other.shape_ &&
162-
byte_offset_ == other.byte_offset_ &&
163-
dtype_ == other.dtype_;
164-
}
165-
166158
/*!
167159
* \return the shape of current NDArray.
168160
*/

include/mxnet/op_attr_types.h

+13-20
Original file line numberDiff line numberDiff line change
@@ -126,36 +126,25 @@ class OpStatePtr {
126126
template<typename T, typename... Args>
127127
static OpStatePtr Create(Args&&... args) {
128128
OpStatePtr ret;
129-
auto state = new T(std::forward<Args>(args)...);
130-
auto var = Engine::Get()->NewVariable();
131-
ret.ptr_.reset(
132-
new OpState(var, state),
133-
[](OpState* p) {
134-
Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
135-
delete reinterpret_cast<T*>(p->state);
136-
delete p;
137-
});
129+
ret.ptr_ = std::make_shared<OpState>();
130+
ret.ptr_->var_ = Engine::Get()->NewVariable();
131+
ret.ptr_->state_.construct<T>(std::forward<Args>(args)...);
138132

139133
return ret;
140134
}
141135
/* \brief Get engine variable associated with this state */
142136
engine::VarHandle get_var() const {
143-
return ptr_->var;
137+
return ptr_->var_;
144138
}
145139
/* \brief Get state of type T */
146140
template<typename T>
147141
T& get_state() const {
148-
return *reinterpret_cast<T*>(ptr_->state);
142+
return dmlc::get<T>(ptr_->state_);
149143
}
150144
/* \brief clear state */
151145
void reset() {
152146
ptr_.reset();
153147
}
154-
/* \brief checks whether the managed object is managed only by the current
155-
OpStatePtr instance */
156-
bool unique() const {
157-
return ptr_.unique();
158-
}
159148
/* \brief Whether state is empty */
160149
explicit operator bool() const {
161150
return ptr_ ? true : false;
@@ -164,12 +153,16 @@ class OpStatePtr {
164153
private:
165154
/* \brief state structure */
166155
struct OpState {
167-
engine::VarHandle var;
168-
void* state;
169-
170-
OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
156+
OpState() {}
171157
OpState(const OpState& other) = delete;
172158
OpState& operator=(const OpState& other) = delete;
159+
160+
~OpState() {
161+
Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_);
162+
}
163+
164+
engine::VarHandle var_;
165+
dmlc::any state_;
173166
};
174167
/* \brief shared pointer to state */
175168
std::shared_ptr<OpState> ptr_;

python/mxnet/_ctypes/ndarray.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,28 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
105105
class CachedOp(object):
106106
"""Cached operator handle."""
107107
__slots__ = ["handle"]
108-
def __init__(self, sym, flags=()):
108+
def __init__(self, sym, flags=(), inputs=None, params=None):
109109
self.handle = CachedOpHandle()
110+
param_names = []
111+
param_arrays = []
112+
if inputs is None:
113+
assert params is None, "When inputs is None params must also be None."
114+
inputs = sym.list_inputs()
115+
elif params is not None:
116+
for name, arrs in params.items():
117+
param_arrays.extend(arrs)
118+
param_names.extend([name] * len(arrs))
110119

111120
check_call(_LIB.MXCreateCachedOpEx(
112121
sym.handle,
113122
len(flags),
114123
c_str_array([key for key, _ in flags]),
115124
c_str_array([str(val) for _, val in flags]),
125+
len(inputs),
126+
c_str_array(inputs),
127+
len(param_names),
128+
c_str_array(param_names),
129+
c_handle_array(param_arrays),
116130
ctypes.byref(self.handle)))
117131

118132
def __del__(self):

python/mxnet/gluon/block.py

+25-49
Original file line numberDiff line numberDiff line change
@@ -502,16 +502,8 @@ def hybridize(self, active=True, **kwargs):
502502
----------
503503
active : bool, default True
504504
Whether to turn hybrid on or off.
505-
static_alloc : bool, default False
506-
Statically allocate memory to improve speed. Memory usage may increase.
507-
static_shape : bool, default False
508-
Optimize for invariant input shapes between iterations. Must also
509-
set static_alloc to True. Change of input shapes is still allowed
510-
but slower.
511-
forward_bulk_size : int, default 15
512-
Segment size of bulk execution during forward pass.
513-
backward_bulk_size : int, default 15
514-
Segment size of bulk execution during backward pass.
505+
**kwargs : string
506+
Additional flags for hybridized operator.
515507
"""
516508
for cld in self._children.values():
517509
cld.hybridize(active, **kwargs)
@@ -704,7 +696,7 @@ def __init__(self, prefix=None, params=None):
704696
self._out_format = None
705697
self._in_format = None
706698
self._active = False
707-
self._flags = []
699+
self._flags = {}
708700

709701
def __setattr__(self, name, value):
710702
"""Registers parameters."""
@@ -731,43 +723,39 @@ def _get_graph(self, *args):
731723
return self._cached_graph
732724

733725
def _build_cache(self, *args):
734-
data, out = self._get_graph(*args)
735-
data_names = {data.name : i for i, data in enumerate(data)}
736-
params = self.collect_params()
737-
input_names = out.list_inputs()
726+
inputs, out = self._get_graph(*args)
727+
input_names = [i.name for i in inputs]
738728

729+
params = self.collect_params()
739730
param_names = set(params.keys())
740-
expected_names = set(input_names)
731+
expected_names = set(out.list_inputs())
741732
for name in expected_names:
742-
assert name in param_names or name in data_names, \
733+
assert name in param_names or name in input_names, \
743734
"Unknown input to HybridBlock: %s"%name
744735

745-
used_data_names = [i for i in data_names if i in expected_names]
746-
if len(used_data_names) != len(data_names):
747-
unused = ', '.join(['%d-th'%i for name, i in data_names.items()
736+
used_input_names = [i for i in input_names if i in expected_names]
737+
if len(used_input_names) != len(input_names):
738+
unused = ', '.join(['%d-th'%i for i, name in enumerate(input_names)
748739
if name not in expected_names])
749740
warnings.warn("The %s input to HybridBlock is not used by any "
750741
"computation. Is this intended?"%unused, stacklevel=4)
751742

752-
used_param_names = [i for i in param_names if i in expected_names]
743+
used_param_names = set(i for i in param_names if i in expected_names)
753744
if len(used_param_names) != len(param_names):
754-
unused = ', '.join(list(param_names - set(used_param_names)))
745+
unused = ', '.join(list(param_names - used_param_names))
755746
warnings.warn("Parameter %s is not used by any computation. "
756747
"Is this intended?"%unused, stacklevel=4)
757748

758-
data_indices = []
759-
param_indices = []
760-
self._cached_op_args = []
761-
for i, name in enumerate(input_names):
762-
if name in data_names:
763-
data_indices.append(i)
764-
self._cached_op_args.append((True, data_names[name]))
765-
else:
766-
param_indices.append(i)
767-
self._cached_op_args.append((False, params[name]))
768-
flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
769-
self._flags
770-
self._cached_op = ndarray.CachedOp(out, flags)
749+
used_params = {k: params[k] for k in used_param_names}
750+
try:
751+
param_dict = {k: v.list_data() for k, v in used_params.items()}
752+
except DeferredInitializationError:
753+
self._deferred_infer_shape(*args)
754+
for i in used_params.values():
755+
i._finish_deferred_init()
756+
param_dict = {k: v.list_data() for k, v in used_params.items()}
757+
758+
self._cached_op = ndarray.CachedOp(out, self._flags, input_names, param_dict)
771759

772760
def _deferred_infer_shape(self, *args):
773761
try:
@@ -783,19 +771,7 @@ def _call_cached_op(self, *args):
783771

784772
args, fmt = _flatten(args, "input")
785773
assert fmt == self._in_format, "Invalid input format"
786-
try:
787-
cargs = [args[i] if is_arg else i.data()
788-
for is_arg, i in self._cached_op_args]
789-
except DeferredInitializationError:
790-
self._deferred_infer_shape(*args)
791-
cargs = []
792-
for is_arg, i in self._cached_op_args:
793-
if is_arg:
794-
cargs.append(args[i])
795-
else:
796-
i._finish_deferred_init()
797-
cargs.append(i.data())
798-
out = self._cached_op(*cargs)
774+
out = self._cached_op(*args)
799775
if isinstance(out, NDArray):
800776
out = [out]
801777
return _regroup(out, self._out_format)[0]
@@ -816,7 +792,7 @@ def register_child(self, block, name=None):
816792

817793
def hybridize(self, active=True, **kwargs):
818794
self._active = active
819-
self._flags = list(kwargs.items())
795+
self._flags = kwargs.items()
820796
self._clear_cached_op()
821797
if active and self._forward_hooks or self._forward_pre_hooks:
822798
warnings.warn('"{}" is being hybridized while still having forward hook/pre-hook. '

0 commit comments

Comments (0)