Skip to content

Commit

Permalink
Update frontend for keras 2.1.3 compatibility (apache#314)
Browse files Browse the repository at this point in the history
* Keras keeps renaming properties. Update frontend for keras 2.1.3 compatibility

* Add error message when inbound_nodes is not found
  • Loading branch information
thefiddler authored and tqchen committed May 29, 2018
1 parent d626eab commit 53361fa
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions nnvm/python/nnvm/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,14 @@ def _convert_concat(insym, keras_layer, _):


def _convert_reshape(insym, keras_layer, _):
return _sym.reshape(insym, keras_layer.shape)
shape = keras_layer.shape if hasattr(keras_layer, 'shape') else \
keras_layer.target_shape if hasattr(keras_layer, 'target_shape') else\
None

if shape is None:
raise TypeError("No shape attribute in reshape layer: {}".format(keras_layer))

return _sym.reshape(insym, shape=shape)


def _default_skip(insym, keras_layer, _): # pylint: disable=unused-argument
Expand Down Expand Up @@ -477,7 +484,15 @@ def from_keras(model):
symtab.get_var(keras_layer.name, must_contain=False)
else:
predecessors = []
for node in keras_layer.inbound_nodes:
inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \
else None

if inbound_nodes is None:
raise TypeError("Unknown layer type or unsupported Keras version : {}"
.format(keras_layer))

for node in inbound_nodes:
for pred in node.inbound_layers:
predecessors.append(pred.name)
if len(predecessors) == 1:
Expand Down

0 comments on commit 53361fa

Please sign in to comment.