Skip to content

Commit

Permalink
Static alloc for hybridblock (apache#11313)
Browse files Browse the repository at this point in the history
Thanks for the contribution, this is now merged
  • Loading branch information
piiswrong authored and tqchen committed Jun 16, 2018
1 parent 02e8a71 commit 5431e12
Show file tree
Hide file tree
Showing 18 changed files with 1,399 additions and 523 deletions.
5 changes: 0 additions & 5 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -987,11 +987,6 @@ MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
int num_flags,
const char** keys,
const char** vals,
int num_inputs,
const char** input_names,
int num_params,
const char** param_names,
NDArrayHandle* params,
CachedOpHandle *out);
/*!
* \brief free cached operator
Expand Down
89 changes: 0 additions & 89 deletions include/mxnet/imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,6 @@
#include "./ndarray.h"

namespace mxnet {
/*! \brief CachedOp Parameters */
struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
uint32_t inline_limit;
uint32_t forward_bulk_size;
uint32_t backward_bulk_size;
DMLC_DECLARE_PARAMETER(CachedOpConfig) {
DMLC_DECLARE_FIELD(inline_limit)
.set_default(2)
.describe("Maximum number of operators that can be inlined.");
DMLC_DECLARE_FIELD(forward_bulk_size)
.set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
.describe("Segment size of bulk execution during forward pass.");
DMLC_DECLARE_FIELD(backward_bulk_size)
.set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
.describe("Segment size of bulk execution during backward pass.");
}
};
/*! \brief runtime functions for NDArray */
class Imperative {
public:
Expand Down Expand Up @@ -94,67 +77,6 @@ class Imperative {
&& info.out_grads.size() == 1;
}
};
class CachedOp {
public:
CachedOp(
const nnvm::Symbol& sym,
const std::vector<std::pair<std::string, std::string> >& flags,
const std::vector<std::string> arg_names,
const std::unordered_map<std::string, std::vector<NDArray> >& params);
uint32_t num_inputs() {
return fwd_graph_.indexed_graph().input_nodes().size();
}
uint32_t num_outputs() {
return fwd_graph_.outputs.size();
}
uint32_t num_backward_inputs() {
return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
}
std::vector<bool>& save_inputs() {
return save_inputs_;
}
std::vector<bool>& save_outputs() {
return save_outputs_;
}
const std::unordered_set<uint32_t>& mutable_input_nodes() {
return fwd_graph_.indexed_graph().mutable_input_nodes();
}
nnvm::Graph GetForwardGraph(const bool recording,
const std::vector<NDArray*>& inputs);
nnvm::Graph GetBackwardGraph(const OpStatePtr& state,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& inputs);
std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
const std::vector<nnvm::NodeEntry>& ograds);
void Forward(const std::shared_ptr<CachedOp>& op_ptr,
const std::vector<NDArray*>& args,
const std::vector<NDArray*>& outputs);
void Backward(const bool retain_graph,
const OpStatePtr& state,
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);

private:
struct CachedOpState {
std::vector<NDArray> buff;
std::vector<OpStatePtr> states;
};
std::mutex mutex_;
CachedOpConfig config_;
nnvm::Graph fwd_graph_;
nnvm::Graph grad_graph_;
nnvm::Graph full_graph_;
std::unordered_map<Context, std::vector<NDArray> > params_;
bool inlining_;
std::vector<nnvm::NodeEntry> ograd_entries_;
std::vector<bool> curr_grad_req_;
std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
std::vector<uint32_t> fwd_args_idx_;
std::vector<uint32_t> fwd_params_idx_;
std::vector<uint32_t> bwd_input_eid_;
std::vector<bool> save_inputs_, save_outputs_;
};
/*! \brief whether operator recording is on. */
bool is_training() const {
return is_train_;
Expand Down Expand Up @@ -222,15 +144,6 @@ class Imperative {
uint32_t num_inputs, uint32_t num_outputs,
std::vector<bool> *p_save_inputs,
std::vector<bool> *p_save_outputs);
void RunGraph(
const bool retain_graph,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
std::vector<OpStatePtr> *p_states,
const DispatchModeVector& dispatch_modes);
/*! \brief indicate whether is training. */
#if DMLC_CXX11_THREAD_LOCAL
static thread_local bool is_train_;
Expand All @@ -247,7 +160,5 @@ class Imperative {
int backward_bulk_size_{0};
};

using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;

} // namespace mxnet
#endif // MXNET_IMPERATIVE_H_
8 changes: 8 additions & 0 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,14 @@ class NDArray {
return byte_offset_ > 0 || shape() != ptr_->storage_shape;
}

/* \brief Check whether the two arrays are the same array */
inline bool IsSame(const NDArray& other) {
return ptr_ == other.ptr_ &&
shape_ == other.shape_ &&
byte_offset_ == other.byte_offset_ &&
dtype_ == other.dtype_;
}

/*!
* \return the shape of current NDArray.
*/
Expand Down
33 changes: 20 additions & 13 deletions include/mxnet/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,25 +126,36 @@ class OpStatePtr {
template<typename T, typename... Args>
static OpStatePtr Create(Args&&... args) {
OpStatePtr ret;
ret.ptr_ = std::make_shared<OpState>();
ret.ptr_->var_ = Engine::Get()->NewVariable();
ret.ptr_->state_.construct<T>(std::forward<Args>(args)...);
auto state = new T(std::forward<Args>(args)...);
auto var = Engine::Get()->NewVariable();
ret.ptr_.reset(
new OpState(var, state),
[](OpState* p) {
Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
delete reinterpret_cast<T*>(p->state);
delete p;
});

return ret;
}
/* \brief Get engine variable associated with this state */
engine::VarHandle get_var() const {
return ptr_->var_;
return ptr_->var;
}
/* \brief Get state of type T */
template<typename T>
T& get_state() const {
return dmlc::get<T>(ptr_->state_);
return *reinterpret_cast<T*>(ptr_->state);
}
/* \brief clear state */
void reset() {
ptr_.reset();
}
/* \brief checks whether the managed object is managed only by the current
OpStatePtr instance */
bool unique() const {
return ptr_.unique();
}
/* \brief Whether state is empty */
explicit operator bool() const {
return ptr_ ? true : false;
Expand All @@ -153,16 +164,12 @@ class OpStatePtr {
private:
/* \brief state structure */
struct OpState {
OpState() {}
engine::VarHandle var;
void* state;

OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
OpState(const OpState& other) = delete;
OpState& operator=(const OpState& other) = delete;

~OpState() {
Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_);
}

engine::VarHandle var_;
dmlc::any state_;
};
/* \brief shared pointer to state */
std::shared_ptr<OpState> ptr_;
Expand Down
16 changes: 1 addition & 15 deletions python/mxnet/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,28 +105,14 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle"]
def __init__(self, sym, flags=(), inputs=None, params=None):
def __init__(self, sym, flags=()):
self.handle = CachedOpHandle()
param_names = []
param_arrays = []
if inputs is None:
assert params is None, "When inputs is None params must also be None."
inputs = sym.list_inputs()
elif params is not None:
for name, arrs in params.items():
param_arrays.extend(arrs)
param_names.extend([name] * len(arrs))

check_call(_LIB.MXCreateCachedOpEx(
sym.handle,
len(flags),
c_str_array([key for key, _ in flags]),
c_str_array([str(val) for _, val in flags]),
len(inputs),
c_str_array(inputs),
len(param_names),
c_str_array(param_names),
c_handle_array(param_arrays),
ctypes.byref(self.handle)))

def __del__(self):
Expand Down
74 changes: 49 additions & 25 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,16 @@ def hybridize(self, active=True, **kwargs):
----------
active : bool, default True
Whether to turn hybrid on or off.
**kwargs : string
Additional flags for hybridized operator.
static_alloc : bool, default False
Statically allocate memory to improve speed. Memory usage may increase.
static_shape : bool, default False
Optimize for invariant input shapes between iterations. Must also
set static_alloc to True. Change of input shapes is still allowed
but slower.
forward_bulk_size : int, default 15
Segment size of bulk execution during forward pass.
backward_bulk_size : int, default 15
Segment size of bulk execution during backward pass.
"""
for cld in self._children.values():
cld.hybridize(active, **kwargs)
Expand Down Expand Up @@ -696,7 +704,7 @@ def __init__(self, prefix=None, params=None):
self._out_format = None
self._in_format = None
self._active = False
self._flags = {}
self._flags = []

def __setattr__(self, name, value):
"""Registers parameters."""
Expand All @@ -723,39 +731,43 @@ def _get_graph(self, *args):
return self._cached_graph

def _build_cache(self, *args):
inputs, out = self._get_graph(*args)
input_names = [i.name for i in inputs]

data, out = self._get_graph(*args)
data_names = {data.name : i for i, data in enumerate(data)}
params = self.collect_params()
input_names = out.list_inputs()

param_names = set(params.keys())
expected_names = set(out.list_inputs())
expected_names = set(input_names)
for name in expected_names:
assert name in param_names or name in input_names, \
assert name in param_names or name in data_names, \
"Unknown input to HybridBlock: %s"%name

used_input_names = [i for i in input_names if i in expected_names]
if len(used_input_names) != len(input_names):
unused = ', '.join(['%d-th'%i for i, name in enumerate(input_names)
used_data_names = [i for i in data_names if i in expected_names]
if len(used_data_names) != len(data_names):
unused = ', '.join(['%d-th'%i for name, i in data_names.items()
if name not in expected_names])
warnings.warn("The %s input to HybridBlock is not used by any "
"computation. Is this intended?"%unused, stacklevel=4)

used_param_names = set(i for i in param_names if i in expected_names)
used_param_names = [i for i in param_names if i in expected_names]
if len(used_param_names) != len(param_names):
unused = ', '.join(list(param_names - used_param_names))
unused = ', '.join(list(param_names - set(used_param_names)))
warnings.warn("Parameter %s is not used by any computation. "
"Is this intended?"%unused, stacklevel=4)

used_params = {k: params[k] for k in used_param_names}
try:
param_dict = {k: v.list_data() for k, v in used_params.items()}
except DeferredInitializationError:
self._deferred_infer_shape(*args)
for i in used_params.values():
i._finish_deferred_init()
param_dict = {k: v.list_data() for k, v in used_params.items()}

self._cached_op = ndarray.CachedOp(out, self._flags, input_names, param_dict)
data_indices = []
param_indices = []
self._cached_op_args = []
for i, name in enumerate(input_names):
if name in data_names:
data_indices.append(i)
self._cached_op_args.append((True, data_names[name]))
else:
param_indices.append(i)
self._cached_op_args.append((False, params[name]))
flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
self._flags
self._cached_op = ndarray.CachedOp(out, flags)

def _deferred_infer_shape(self, *args):
try:
Expand All @@ -771,7 +783,19 @@ def _call_cached_op(self, *args):

args, fmt = _flatten(args, "input")
assert fmt == self._in_format, "Invalid input format"
out = self._cached_op(*args)
try:
cargs = [args[i] if is_arg else i.data()
for is_arg, i in self._cached_op_args]
except DeferredInitializationError:
self._deferred_infer_shape(*args)
cargs = []
for is_arg, i in self._cached_op_args:
if is_arg:
cargs.append(args[i])
else:
i._finish_deferred_init()
cargs.append(i.data())
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
return _regroup(out, self._out_format)[0]
Expand All @@ -792,7 +816,7 @@ def register_child(self, block, name=None):

def hybridize(self, active=True, **kwargs):
self._active = active
self._flags = kwargs.items()
self._flags = list(kwargs.items())
self._clear_cached_op()
if active and self._forward_hooks or self._forward_pre_hooks:
warnings.warn('"{}" is being hybridized while still having forward hook/pre-hook. '
Expand Down
Loading

0 comments on commit 5431e12

Please sign in to comment.