gluon with multiple data type (#8522)
* gluon with multiple data type

* fix

* fix
piiswrong committed Nov 8, 2017
1 parent 70b68b1 commit 8dc09ec
Showing 8 changed files with 125 additions and 68 deletions.
1 change: 1 addition & 0 deletions include/mxnet/imperative.h
@@ -120,6 +120,7 @@ class Imperative {
nnvm::Graph fwd_graph_;
nnvm::Graph grad_graph_;
nnvm::Graph full_graph_;
std::vector<nnvm::NodeEntry> ograd_entries_;
std::vector<bool> curr_grad_req_;
std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
std::vector<uint32_t> bwd_input_eid_;
95 changes: 62 additions & 33 deletions python/mxnet/gluon/block.py
@@ -21,6 +21,7 @@
__all__ = ['Block', 'HybridBlock', 'SymbolBlock']

import copy
import warnings

from .. import symbol, ndarray, initializer
from ..symbol import Symbol
@@ -325,7 +326,7 @@ def __init__(self, prefix=None, params=None):
self._reg_params = {}
self._cached_graph = ()
self._cached_op = None
self._cached_params = None
self._cached_op_args = None
self._out_format = None
self._in_format = None
self._active = False
@@ -363,34 +364,47 @@ def _get_graph(self, *args):

def _build_cache(self, *args):
inputs, out = self._get_graph(*args)
input_idx = {var.name: i for i, var in enumerate(inputs)}
self._cached_op = ndarray.CachedOp(out)

params = dict(self.collect_params().items())
self._cached_params = [params.get(name, None) for name in out.list_inputs()]
assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \
"Wrong number of inputs."

name2pos = {var.name: i for i, var in enumerate(inputs)}
self._in_idx = [(i, name2pos[name]) for i, name in enumerate(out.list_inputs())
if name not in params]
# verify graph inputs
expected_inputs = set(out.list_inputs())
for name in expected_inputs:
assert name in params or name in input_idx, \
"Unknown input to HybridBlock: %s"%name
for name, i in input_idx.items():
if name not in expected_inputs:
warnings.warn("The %d-th input to HybridBlock is not used by any "
"computation. Is this intended?"%i)
for name in params:
if name not in expected_inputs:
warnings.warn("Parameter %s is not used by any computation. "
"Is this intended?"%name)

self._cached_op_args = [(False, params[name]) if name in params
else (True, input_idx[name])
for name in out.list_inputs()]

def _finish_deferred_init(self, hybrid, *args):
self.infer_shape(*args)
self.infer_type(*args)
if hybrid:
for is_arg, i in self._cached_op_args:
if not is_arg:
i._finish_deferred_init()
else:
for _, i in self.params.items():
i._finish_deferred_init()

def _call_cached_op(self, *args):
if self._cached_op is None:
self._build_cache(*args)

try:
cargs = [i.data() if i else None for i in self._cached_params]
except DeferredInitializationError:
self.infer_shape(*args)
for i in self._cached_params:
if i is not None:
i._finish_deferred_init()
cargs = [i.data() if i else None for i in self._cached_params]

args, fmt = _flatten(args)
assert fmt == self._in_format, "Invalid input format"
for i, j in self._in_idx:
cargs[i] = args[j]
cargs = [args[i] if is_arg else i.data()
for is_arg, i in self._cached_op_args]
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
@@ -399,6 +413,7 @@ def _call_cached_op(self, *args):
def _clear_cached_op(self):
self._cached_graph = ()
self._cached_op = None
self._cached_op_args = None

def register_child(self, block):
if not isinstance(block, HybridBlock):
@@ -414,17 +429,25 @@ def hybridize(self, active=True):
self._active = active
super(HybridBlock, self).hybridize(active)

def infer_shape(self, *args):
"""Infers shape of Parameters from inputs."""
def _infer_attrs(self, infer_fn, attr, *args):
"""Generic infer attributes."""
inputs, out = self._get_graph(*args)
args, _ = _flatten(args)
arg_shapes, _, aux_shapes = out.infer_shape(
**{i.name: j.shape for i, j in zip(inputs, args)})
sdict = {i: j for i, j in zip(out.list_arguments(), arg_shapes)}
sdict.update({name : shape for name, shape in \
zip(out.list_auxiliary_states(), aux_shapes)})
arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
**{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
sdict = {i: j for i, j in zip(out.list_arguments(), arg_attrs)}
sdict.update({name : attr for name, attr in \
zip(out.list_auxiliary_states(), aux_attrs)})
for i in self.collect_params().values():
i.shape = sdict[i.name]
setattr(i, attr, sdict[i.name])

def infer_shape(self, *args):
"""Infers shape of Parameters from inputs."""
self._infer_attrs('infer_shape', 'shape', *args)

def infer_type(self, *args):
"""Infers data type of Parameters from inputs."""
self._infer_attrs('infer_type', 'dtype', *args)

def export(self, path):
"""Export HybridBlock to json format that can be loaded by `mxnet.mod.Module`
@@ -462,15 +485,16 @@ def forward(self, x, *args):
:py:class:`NDArray` or :py:class:`Symbol`."""
if isinstance(x, NDArray):
with x.context as ctx:
if self._active:
return self._call_cached_op(x, *args)
try:
if self._active:
return self._call_cached_op(x, *args)
params = {i: j.data(ctx) for i, j in self._reg_params.items()}
except DeferredInitializationError:
self.infer_shape(x, *args)
for i in self.collect_params().values():
i._finish_deferred_init()
params = {i: j.data(ctx) for i, j in self._reg_params.items()}
self._finish_deferred_init(self._active, x, *args)

if self._active:
return self._call_cached_op(x, *args)
params = {i: j.data(ctx) for i, j in self._reg_params.items()}
return self.hybrid_forward(ndarray, x, *args, **params)

assert isinstance(x, Symbol), \
@@ -559,6 +583,11 @@ def __init__(self, outputs, inputs, params=None):
def forward(self, x, *args):
if isinstance(x, NDArray):
with x.context:
try:
return self._call_cached_op(x, *args)
except DeferredInitializationError:
self._finish_deferred_init(True, x, *args)

return self._call_cached_op(x, *args)

assert isinstance(x, Symbol), \
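Note: the block.py changes above add dtype inference alongside shape inference — _finish_deferred_init now runs both infer_shape and infer_type, and _call_cached_op feeds cached parameters and positional inputs through a single _cached_op_args list. A minimal sketch of the resulting user-facing behavior (the layer, shapes and dtype below are illustrative, not taken from the commit):

```python
import mxnet as mx
from mxnet import gluon, nd

net = gluon.nn.Dense(4)
net.initialize()   # dtype (and in_units) unknown here, so initialization is deferred
net.hybridize()

x = nd.ones((2, 3), dtype='float64')
y = net(x)         # first call triggers infer_shape/infer_type, then the CachedOp path

print(net.weight.data().dtype)   # expected to follow the input dtype (float64)
```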
18 changes: 9 additions & 9 deletions python/mxnet/gluon/nn/basic_layers.py
@@ -185,11 +185,11 @@ def __init__(self, units, activation=None, use_bias=True, flatten=True,
self._units = units
self._in_units = in_units
self.weight = self.params.get('weight', shape=(units, in_units),
init=weight_initializer,
dtype=None, init=weight_initializer,
allow_deferred_init=True)
if use_bias:
self.bias = self.params.get('bias', shape=(units,),
init=bias_initializer,
dtype=None, init=bias_initializer,
allow_deferred_init=True)
else:
self.bias = None
@@ -336,20 +336,20 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
self.in_channels = in_channels

self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
shape=(in_channels,), init=gamma_initializer,
allow_deferred_init=True,
shape=(in_channels,), dtype=None,
init=gamma_initializer, allow_deferred_init=True,
differentiable=scale)
self.beta = self.params.get('beta', grad_req='write' if center else 'null',
shape=(in_channels,), init=beta_initializer,
allow_deferred_init=True,
shape=(in_channels,), dtype=None,
init=beta_initializer, allow_deferred_init=True,
differentiable=center)
self.running_mean = self.params.get('running_mean', grad_req='null',
shape=(in_channels,),
shape=(in_channels,), dtype=None,
init=running_mean_initializer,
allow_deferred_init=True,
differentiable=False)
self.running_var = self.params.get('running_var', grad_req='null',
shape=(in_channels,),
shape=(in_channels,), dtype=None,
init=running_variance_initializer,
allow_deferred_init=True,
differentiable=False)
@@ -437,7 +437,7 @@ def __init__(self, input_dim, output_dim, dtype='float32',
self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
'dtype': dtype}
self.weight = self.params.get('weight', shape=(input_dim, output_dim),
init=weight_initializer,
dtype=None, init=weight_initializer,
allow_deferred_init=True)

def hybrid_forward(self, F, x, weight):
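Note: the pattern above — params.get(..., dtype=None, allow_deferred_init=True) — is what lets Dense, BatchNorm and Embedding defer their parameter dtype until it can be inferred from the input. A custom layer can opt in the same way; this sketch is illustrative only (the layer name and initializer are made up, not part of the commit):

```python
from mxnet import gluon

class Scale(gluon.HybridBlock):
    """Toy layer whose parameter dtype is inferred from the first input."""
    def __init__(self, units, **kwargs):
        super(Scale, self).__init__(**kwargs)
        with self.name_scope():
            # dtype=None + allow_deferred_init=True: allocation waits until
            # infer_shape/infer_type have run, mirroring the built-in layers.
            self.scale = self.params.get('scale', shape=(units,), dtype=None,
                                         init='ones', allow_deferred_init=True)

    def hybrid_forward(self, F, x, scale):
        return F.broadcast_mul(x, scale)
```

Feeding such a layer a float16 or float64 batch should then materialize scale in that dtype.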
4 changes: 2 additions & 2 deletions python/mxnet/gluon/nn/conv_layers.py
@@ -113,11 +113,11 @@ def __init__(self, channels, kernel_size, strides, padding, dilation,
dshape[layout.find('C')] = in_channels
wshapes = _infer_weight_shape(op_name, dshape, self._kwargs)
self.weight = self.params.get('weight', shape=wshapes[1],
init=weight_initializer,
dtype=None, init=weight_initializer,
allow_deferred_init=True)
if use_bias:
self.bias = self.params.get('bias', shape=wshapes[2],
init=bias_initializer,
dtype=None, init=bias_initializer,
allow_deferred_init=True)
else:
self.bias = None
2 changes: 1 addition & 1 deletion python/mxnet/gluon/parameter.py
@@ -306,7 +306,7 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
ctx = [ctx]
if init is None:
init = default_init if self.init is None else self.init
if not self.shape or np.prod(self.shape) <= 0:
if self.dtype is None or not self.shape or np.prod(self.shape) <= 0:
if self._allow_deferred_init:
self._deferred_init = (init, ctx, default_init)
return
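Note: the one-line change above makes an unknown dtype behave like an unknown shape — initialize() records the initializer and contexts and defers the actual allocation until the dtype is known. A hedged sketch in isolation (constructing a bare Parameter like this is illustrative, not from the commit):

```python
from mxnet import gluon

w = gluon.Parameter('w', shape=(3,), dtype=None, allow_deferred_init=True)
w.initialize()   # dtype is None, so initialization is deferred rather than performed
# w.data() at this point raises DeferredInitializationError; the array is only
# allocated once infer_type/infer_shape (see block.py above) supply the dtype.
```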
24 changes: 12 additions & 12 deletions python/mxnet/gluon/rnn/rnn_cell.py
@@ -326,16 +326,16 @@ def __init__(self, hidden_size, activation='tanh',
self._activation = activation
self._input_size = input_size
self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size),
init=i2h_weight_initializer,
dtype=None, init=i2h_weight_initializer,
allow_deferred_init=True)
self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size),
init=h2h_weight_initializer,
dtype=None, init=h2h_weight_initializer,
allow_deferred_init=True)
self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,),
init=i2h_bias_initializer,
dtype=None, init=i2h_bias_initializer,
allow_deferred_init=True)
self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,),
init=h2h_bias_initializer,
dtype=None, init=h2h_bias_initializer,
allow_deferred_init=True)

def state_info(self, batch_size=0):
@@ -434,16 +434,16 @@ def __init__(self, hidden_size,
self._hidden_size = hidden_size
self._input_size = input_size
self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
init=i2h_weight_initializer,
dtype=None, init=i2h_weight_initializer,
allow_deferred_init=True)
self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
init=h2h_weight_initializer,
dtype=None, init=h2h_weight_initializer,
allow_deferred_init=True)
self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
init=i2h_bias_initializer,
dtype=None, init=i2h_bias_initializer,
allow_deferred_init=True)
self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
init=h2h_bias_initializer,
dtype=None, init=h2h_bias_initializer,
allow_deferred_init=True)

def state_info(self, batch_size=0):
@@ -541,16 +541,16 @@ def __init__(self, hidden_size,
self._hidden_size = hidden_size
self._input_size = input_size
self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size),
init=i2h_weight_initializer,
dtype=None, init=i2h_weight_initializer,
allow_deferred_init=True)
self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size),
init=h2h_weight_initializer,
dtype=None, init=h2h_weight_initializer,
allow_deferred_init=True)
self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,),
init=i2h_bias_initializer,
dtype=None, init=i2h_bias_initializer,
allow_deferred_init=True)
self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,),
init=h2h_bias_initializer,
dtype=None, init=h2h_bias_initializer,
allow_deferred_init=True)

def state_info(self, batch_size=0):
38 changes: 27 additions & 11 deletions src/imperative/cached_op.cc
@@ -62,11 +62,10 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
}

// construct backward graph
std::vector<NodeEntry> ograd_entries;
{
ograd_entries.reserve(fwd_graph_.outputs.size());
ograd_entries_.reserve(fwd_graph_.outputs.size());
for (size_t i = 0; i < fwd_graph_.outputs.size(); ++i) {
ograd_entries.emplace_back(NodeEntry{Node::Create(), 0, 0});
ograd_entries_.emplace_back(NodeEntry{Node::Create(), 0, 0});
}

std::vector<NodeEntry> xs;
@@ -77,7 +76,7 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
<< "There are no inputs in computation graph that require gradients.";

grad_graph_ = pass::Gradient(
fwd_graph_, fwd_graph_.outputs, xs, ograd_entries,
fwd_graph_, fwd_graph_.outputs, xs, ograd_entries_,
exec::AggregateGradient, nullptr, nullptr,
zero_ops, "_copy");
}
@@ -105,12 +104,12 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
std::make_shared<dmlc::any>(std::move(full_ref_count));

size_t num_forward_inputs = num_inputs();
for (uint32_t i = 0; i < ograd_entries.size(); ++i) {
if (!idx.exist(ograd_entries[i].node.get())) continue;
auto eid = idx.entry_id(ograd_entries[i]);
size_t num_forward_outputs = num_outputs();
for (uint32_t i = 0; i < ograd_entries_.size(); ++i) {
if (!idx.exist(ograd_entries_[i].node.get())) continue;
auto eid = idx.entry_id(ograd_entries_[i]);
if (ref_count[eid] > 0) {
bwd_ograd_dep_.push_back(i);
bwd_input_eid_.push_back(eid);
}
}
save_inputs_.resize(num_forward_inputs, false);
@@ -119,16 +118,14 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
if (ref_count[eid] > 0) {
save_inputs_[i] = true;
bwd_in_dep_.push_back(i);
bwd_input_eid_.push_back(eid);
}
}
save_outputs_.resize(idx.outputs().size(), false);
for (uint32_t i = 0; i < idx.outputs().size(); ++i) {
for (uint32_t i = 0; i < num_forward_outputs; ++i) {
auto eid = idx.entry_id(idx.outputs()[i]);
if (ref_count[eid] > 0) {
save_outputs_[i] = true;
bwd_out_dep_.push_back(i);
bwd_input_eid_.push_back(eid);
}
}
}
@@ -242,9 +239,28 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]);
}
bwd_input_eid_.clear();
}

const auto& idx = g.indexed_graph();

if (bwd_input_eid_.size() != inputs.size()) {
bwd_input_eid_.clear();
for (const auto& i : bwd_ograd_dep_) {
auto eid = idx.entry_id(ograd_entries_[i]);
bwd_input_eid_.push_back(eid);
}
for (const auto& i : bwd_in_dep_) {
auto eid = idx.entry_id(idx.input_nodes()[i], 0);
bwd_input_eid_.push_back(eid);
}
for (const auto& i : bwd_out_dep_) {
auto eid = idx.entry_id(idx.outputs()[i]);
bwd_input_eid_.push_back(eid);
}
CHECK_EQ(inputs.size(), bwd_input_eid_.size());
}

size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries();

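Note: in cached_op.cc the constructor now keeps ograd_entries_ as a member and records only the dependency indices (bwd_ograd_dep_, bwd_in_dep_, bwd_out_dep_); bwd_input_eid_ is rebuilt lazily in GetBackwardGraph, since the backward graph's entry ids depend on the current grad_req pattern. The Python-facing object this code backs is ndarray.CachedOp, used by HybridBlock above; a small illustrative call pattern (the symbols and values are made up):

```python
from mxnet import ndarray as nd, symbol as sym

a = sym.var('a')
b = sym.var('b')
out = a * b + 1

op = nd.CachedOp(out)        # build the cached graph once
x = nd.ones((2, 2))
y = nd.full((2, 2), 3.0)
print(op(x, y))              # positional args follow out.list_inputs() order
```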