Skip to content

Commit

Permalink
fix symbolblock (#8050)
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong authored Sep 27, 2017
1 parent f65da2c commit 494a642
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
6 changes: 3 additions & 3 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,8 @@ def __init__(self, outputs, inputs, params=None):
self._params = ParameterDict('', params)
if isinstance(inputs, symbol.Symbol) and len(inputs.list_outputs()) == 1:
inputs = [inputs]
if isinstance(outputs, symbol.Symbol) and len(outputs.list_outputs()) == 1:
outputs = [outputs]
if isinstance(outputs, (list, tuple)) and len(outputs) == 1:
outputs = outputs[0]

syms, self._in_format = _flatten(inputs)
out, self._out_format = _flatten(outputs)
Expand Down Expand Up @@ -523,7 +523,7 @@ def forward(self, x, *args):
assert in_fmt == self._in_format, "Invalid input format"
ret = copy.copy(self._cached_graph[1])
ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
return _regroup(ret, self._out_format)[0]
return _regroup(list(ret), self._out_format)[0]

def hybrid_forward(self, F, x, *args, **kwargs):
raise NotImplementedError
22 changes: 21 additions & 1 deletion tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,27 @@ def test_symbol_block():
assert len(smodel(mx.nd.zeros((16, 10)))) == 14

out = smodel(mx.sym.var('in'))
assert len(out.get_internals().list_outputs()) == len(outputs.list_outputs())
assert len(out) == len(outputs.list_outputs())

class Net(nn.HybridBlock):
def __init__(self, model):
super(Net, self).__init__()
self.model = model

def hybrid_forward(self, F, x):
out = self.model(x)
return F.add_n(*[i.sum() for i in out])

net = Net(smodel)
net.hybridize()
assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)

inputs = mx.sym.var('data')
outputs = model(inputs)
smodel = gluon.SymbolBlock(outputs, inputs, params=model.collect_params())
net = Net(smodel)
net.hybridize()
assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)


def check_layer_forward(layer, dshape):
Expand Down

0 comments on commit 494a642

Please sign in to comment.