diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index d26e86f4..88a9f4d5 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -120,6 +120,7 @@ class Imperative {
     nnvm::Graph fwd_graph_;
     nnvm::Graph grad_graph_;
     nnvm::Graph full_graph_;
+    std::vector<nnvm::NodeEntry> ograd_entries_;
     std::vector<bool> curr_grad_req_;
     std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
     std::vector<uint32_t> bwd_input_eid_;
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 73dbfc10..25467112 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -21,6 +21,7 @@
 __all__ = ['Block', 'HybridBlock', 'SymbolBlock']

 import copy
+import warnings

 from .. import symbol, ndarray, initializer
 from ..symbol import Symbol
@@ -325,7 +326,7 @@ def __init__(self, prefix=None, params=None):
         self._reg_params = {}
         self._cached_graph = ()
         self._cached_op = None
-        self._cached_params = None
+        self._cached_op_args = None
         self._out_format = None
         self._in_format = None
         self._active = False
@@ -363,34 +364,47 @@ def _get_graph(self, *args):

     def _build_cache(self, *args):
         inputs, out = self._get_graph(*args)
+        input_idx = {var.name: i for i, var in enumerate(inputs)}
         self._cached_op = ndarray.CachedOp(out)
         params = dict(self.collect_params().items())
-        self._cached_params = [params.get(name, None) for name in out.list_inputs()]
-        assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \
-            "Wrong number of inputs."
-        name2pos = {var.name: i for i, var in enumerate(inputs)}
-        self._in_idx = [(i, name2pos[name]) for i, name in enumerate(out.list_inputs())
-                        if name not in params]
+        # verify graph inputs
+        expected_inputs = set(out.list_inputs())
+        for name in expected_inputs:
+            assert name in params or name in input_idx, \
+                "Unknown input to HybridBlock: %s"%name
+        for name, i in input_idx.items():
+            if name not in expected_inputs:
+                warnings.warn("The %d-th input to HybridBlock is not used by any "
+                              "computation. Is this intended?"%i)
+        for name in params:
+            if name not in expected_inputs:
+                warnings.warn("Parameter %s is not used by any computation. "
+                              "Is this intended?"%name)
" + "Is this intended?"%name) + + self._cached_op_args = [(False, params[name]) if name in params + else (True, input_idx[name]) + for name in out.list_inputs()] + + def _finish_deferred_init(self, hybrid, *args): + self.infer_shape(*args) + self.infer_type(*args) + if hybrid: + for is_arg, i in self._cached_op_args: + if not is_arg: + i._finish_deferred_init() + else: + for _, i in self.params.items(): + i._finish_deferred_init() def _call_cached_op(self, *args): if self._cached_op is None: self._build_cache(*args) - try: - cargs = [i.data() if i else None for i in self._cached_params] - except DeferredInitializationError: - self.infer_shape(*args) - for i in self._cached_params: - if i is not None: - i._finish_deferred_init() - cargs = [i.data() if i else None for i in self._cached_params] - args, fmt = _flatten(args) assert fmt == self._in_format, "Invalid input format" - for i, j in self._in_idx: - cargs[i] = args[j] + cargs = [args[i] if is_arg else i.data() + for is_arg, i in self._cached_op_args] out = self._cached_op(*cargs) if isinstance(out, NDArray): out = [out] @@ -399,6 +413,7 @@ def _call_cached_op(self, *args): def _clear_cached_op(self): self._cached_graph = () self._cached_op = None + self._cached_op_args = None def register_child(self, block): if not isinstance(block, HybridBlock): @@ -414,17 +429,25 @@ def hybridize(self, active=True): self._active = active super(HybridBlock, self).hybridize(active) - def infer_shape(self, *args): - """Infers shape of Parameters from inputs.""" + def _infer_attrs(self, infer_fn, attr, *args): + """Generic infer attributes.""" inputs, out = self._get_graph(*args) args, _ = _flatten(args) - arg_shapes, _, aux_shapes = out.infer_shape( - **{i.name: j.shape for i, j in zip(inputs, args)}) - sdict = {i: j for i, j in zip(out.list_arguments(), arg_shapes)} - sdict.update({name : shape for name, shape in \ - zip(out.list_auxiliary_states(), aux_shapes)}) + arg_attrs, _, aux_attrs = getattr(out, infer_fn)( + **{i.name: getattr(j, attr) for i, j in zip(inputs, args)}) + sdict = {i: j for i, j in zip(out.list_arguments(), arg_attrs)} + sdict.update({name : attr for name, attr in \ + zip(out.list_auxiliary_states(), aux_attrs)}) for i in self.collect_params().values(): - i.shape = sdict[i.name] + setattr(i, attr, sdict[i.name]) + + def infer_shape(self, *args): + """Infers shape of Parameters from inputs.""" + self._infer_attrs('infer_shape', 'shape', *args) + + def infer_type(self, *args): + """Infers data type of Parameters from inputs.""" + self._infer_attrs('infer_type', 'dtype', *args) def export(self, path): """Export HybridBlock to json format that can be loaded by `mxnet.mod.Module` @@ -462,15 +485,16 @@ def forward(self, x, *args): :py:class:`NDArray` or :py:class:`Symbol`.""" if isinstance(x, NDArray): with x.context as ctx: - if self._active: - return self._call_cached_op(x, *args) try: + if self._active: + return self._call_cached_op(x, *args) params = {i: j.data(ctx) for i, j in self._reg_params.items()} except DeferredInitializationError: - self.infer_shape(x, *args) - for i in self.collect_params().values(): - i._finish_deferred_init() - params = {i: j.data(ctx) for i, j in self._reg_params.items()} + self._finish_deferred_init(self._active, x, *args) + + if self._active: + return self._call_cached_op(x, *args) + params = {i: j.data(ctx) for i, j in self._reg_params.items()} return self.hybrid_forward(ndarray, x, *args, **params) assert isinstance(x, Symbol), \ @@ -559,6 +583,11 @@ def __init__(self, outputs, inputs, params=None): 
     def forward(self, x, *args):
         if isinstance(x, NDArray):
             with x.context:
+                try:
+                    return self._call_cached_op(x, *args)
+                except DeferredInitializationError:
+                    self._finish_deferred_init(True, x, *args)
+
                 return self._call_cached_op(x, *args)

         assert isinstance(x, Symbol), \
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 906f03ec..8034ab84 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -185,11 +185,11 @@ def __init__(self, units, activation=None, use_bias=True, flatten=True,
             self._units = units
             self._in_units = in_units
             self.weight = self.params.get('weight', shape=(units, in_units),
-                                          init=weight_initializer,
+                                          dtype=None, init=weight_initializer,
                                           allow_deferred_init=True)
             if use_bias:
                 self.bias = self.params.get('bias', shape=(units,),
-                                            init=bias_initializer,
+                                            dtype=None, init=bias_initializer,
                                             allow_deferred_init=True)
             else:
                 self.bias = None
@@ -336,20 +336,20 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
             self.in_channels = in_channels

         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
-                                     shape=(in_channels,), init=gamma_initializer,
-                                     allow_deferred_init=True,
+                                     shape=(in_channels,), dtype=None,
+                                     init=gamma_initializer, allow_deferred_init=True,
                                      differentiable=scale)
         self.beta = self.params.get('beta', grad_req='write' if center else 'null',
-                                    shape=(in_channels,), init=beta_initializer,
-                                    allow_deferred_init=True,
+                                    shape=(in_channels,), dtype=None,
+                                    init=beta_initializer, allow_deferred_init=True,
                                     differentiable=center)
         self.running_mean = self.params.get('running_mean', grad_req='null',
-                                            shape=(in_channels,),
+                                            shape=(in_channels,), dtype=None,
                                             init=running_mean_initializer,
                                             allow_deferred_init=True,
                                             differentiable=False)
         self.running_var = self.params.get('running_var', grad_req='null',
-                                           shape=(in_channels,),
+                                           shape=(in_channels,), dtype=None,
                                            init=running_variance_initializer,
                                            allow_deferred_init=True,
                                            differentiable=False)
@@ -437,7 +437,7 @@ def __init__(self, input_dim, output_dim, dtype='float32',
         self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
                         'dtype': dtype}
         self.weight = self.params.get('weight', shape=(input_dim, output_dim),
-                                      init=weight_initializer,
+                                      dtype=None, init=weight_initializer,
                                       allow_deferred_init=True)

     def hybrid_forward(self, F, x, weight):
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 645de98e..0dd70697 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -113,11 +113,11 @@ def __init__(self, channels, kernel_size, strides, padding, dilation,
             dshape[layout.find('C')] = in_channels
             wshapes = _infer_weight_shape(op_name, dshape, self._kwargs)
             self.weight = self.params.get('weight', shape=wshapes[1],
-                                          init=weight_initializer,
+                                          dtype=None, init=weight_initializer,
                                           allow_deferred_init=True)
             if use_bias:
                 self.bias = self.params.get('bias', shape=wshapes[2],
-                                            init=bias_initializer,
+                                            dtype=None, init=bias_initializer,
                                             allow_deferred_init=True)
             else:
                 self.bias = None
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index c42fbaa1..27297b53 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -306,7 +306,7 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
             ctx = [ctx]
         if init is None:
             init = default_init if self.init is None else self.init
-        if not self.shape or np.prod(self.shape) <= 0:
+        if self.dtype is None or not self.shape or np.prod(self.shape) <= 0:
             if self._allow_deferred_init:
                 self._deferred_init = (init, ctx, default_init)
                 return
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index ea0e32fa..80bb8e3f 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -326,16 +326,16 @@ def __init__(self, hidden_size, activation='tanh',
         self._activation = activation
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size),
-                                          init=i2h_weight_initializer,
+                                          dtype=None, init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size),
-                                          init=h2h_weight_initializer,
+                                          dtype=None, init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,),
-                                        init=i2h_bias_initializer,
+                                        dtype=None, init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,),
-                                        init=h2h_bias_initializer,
+                                        dtype=None, init=h2h_bias_initializer,
                                         allow_deferred_init=True)

     def state_info(self, batch_size=0):
@@ -434,16 +434,16 @@ def __init__(self, hidden_size,
         self._hidden_size = hidden_size
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
-                                          init=i2h_weight_initializer,
+                                          dtype=None, init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
-                                          init=h2h_weight_initializer,
+                                          dtype=None, init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
-                                        init=i2h_bias_initializer,
+                                        dtype=None, init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
-                                        init=h2h_bias_initializer,
+                                        dtype=None, init=h2h_bias_initializer,
                                         allow_deferred_init=True)

     def state_info(self, batch_size=0):
@@ -541,16 +541,16 @@ def __init__(self, hidden_size,
         self._hidden_size = hidden_size
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size),
-                                          init=i2h_weight_initializer,
+                                          dtype=None, init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size),
-                                          init=h2h_weight_initializer,
+                                          dtype=None, init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,),
-                                        init=i2h_bias_initializer,
+                                        dtype=None, init=i2h_bias_initializer,
                                         allow_deferred_init=True)
        self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,),
-                                        init=h2h_bias_initializer,
+                                        dtype=None, init=h2h_bias_initializer,
                                         allow_deferred_init=True)

     def state_info(self, batch_size=0):
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index eb99aabf..60d66db4 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -62,11 +62,10 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
   }

   // construct backward graph
-  std::vector<NodeEntry> ograd_entries;
   {
-    ograd_entries.reserve(fwd_graph_.outputs.size());
+    ograd_entries_.reserve(fwd_graph_.outputs.size());
     for (size_t i = 0; i < fwd_graph_.outputs.size(); ++i) {
-      ograd_entries.emplace_back(NodeEntry{Node::Create(), 0, 0});
+      ograd_entries_.emplace_back(NodeEntry{Node::Create(), 0, 0});
     }

     std::vector<NodeEntry> xs;
@@ -77,7 +76,7 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
         << "There are no inputs in computation graph that require gradients.";

     grad_graph_ = pass::Gradient(
-        fwd_graph_, fwd_graph_.outputs, xs, ograd_entries,
+        fwd_graph_, fwd_graph_.outputs, xs, ograd_entries_,
         exec::AggregateGradient, nullptr, nullptr,
         zero_ops, "_copy");
   }
@@ -105,12 +104,12 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
       std::make_shared<dmlc::any>(std::move(full_ref_count));

   size_t num_forward_inputs = num_inputs();
-  for (uint32_t i = 0; i < ograd_entries.size(); ++i) {
-    if (!idx.exist(ograd_entries[i].node.get())) continue;
-    auto eid = idx.entry_id(ograd_entries[i]);
+  size_t num_forward_outputs = num_outputs();
+  for (uint32_t i = 0; i < ograd_entries_.size(); ++i) {
+    if (!idx.exist(ograd_entries_[i].node.get())) continue;
+    auto eid = idx.entry_id(ograd_entries_[i]);
     if (ref_count[eid] > 0) {
       bwd_ograd_dep_.push_back(i);
-      bwd_input_eid_.push_back(eid);
     }
   }
   save_inputs_.resize(num_forward_inputs, false);
@@ -119,16 +118,14 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
     if (ref_count[eid] > 0) {
       save_inputs_[i] = true;
       bwd_in_dep_.push_back(i);
-      bwd_input_eid_.push_back(eid);
     }
   }
   save_outputs_.resize(idx.outputs().size(), false);
-  for (uint32_t i = 0; i < idx.outputs().size(); ++i) {
+  for (uint32_t i = 0; i < num_forward_outputs; ++i) {
     auto eid = idx.entry_id(idx.outputs()[i]);
     if (ref_count[eid] > 0) {
       save_outputs_[i] = true;
       bwd_out_dep_.push_back(i);
-      bwd_input_eid_.push_back(eid);
     }
   }
 }
@@ -242,9 +239,28 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
     for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
       if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]);
     }
+    bwd_input_eid_.clear();
   }

   const auto& idx = g.indexed_graph();
+
+  if (bwd_input_eid_.size() != inputs.size()) {
+    bwd_input_eid_.clear();
+    for (const auto& i : bwd_ograd_dep_) {
+      auto eid = idx.entry_id(ograd_entries_[i]);
+      bwd_input_eid_.push_back(eid);
+    }
+    for (const auto& i : bwd_in_dep_) {
+      auto eid = idx.entry_id(idx.input_nodes()[i], 0);
+      bwd_input_eid_.push_back(eid);
+    }
+    for (const auto& i : bwd_out_dep_) {
+      auto eid = idx.entry_id(idx.outputs()[i]);
+      bwd_input_eid_.push_back(eid);
+    }
+    CHECK_EQ(inputs.size(), bwd_input_eid_.size());
+  }
+
   size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
   size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries();
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index df0af34d..751f1fbd 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -567,6 +567,17 @@ def test_fill_shape_deferred():
     assert net[2].weight.shape[1] == 3072, net[2].weight.shape[1]


+def test_dtype():
+    net = mx.gluon.model_zoo.vision.resnet18_v1()
+    net.initialize()
+    net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+
+    net = mx.gluon.model_zoo.vision.resnet18_v1()
+    net.initialize()
+    net.hybridize()
+    net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+
+
 def test_fill_shape_load():
     ctx = mx.context.current_context()
     net1 = nn.HybridSequential()
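
Usage note (not part of the patch): a minimal sketch of the behaviour exercised by the new test_dtype() above, assuming a build that includes this change. Parameters created with dtype=None stay deferred through initialize(); the first forward call then infers both shape and dtype from the input, so a float64 input produces float64 parameters.

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.HybridSequential()
    with net.name_scope():
        net.add(nn.Dense(10))
    net.initialize()   # deferred: shape and dtype are still unknown here
    net.hybridize()
    out = net(mx.nd.ones((2, 5), dtype='float64'))  # first call triggers shape/dtype inference
    print(out.dtype)                     # numpy.float64
    print(net[0].weight.data().dtype)    # numpy.float64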