Commit f7eff09

yongww authored and srkreddy1238 committed
[relay][frontend] TensorFlow saved model support (#2586)

* [relay][frontend] TensorFlow saved model support
* Add Examples section
* keep one copy of tensorflow_parser in relay
1 parent: 19194e9

3 files changed: 62 additions and 25 deletions

nnvm/python/nnvm/frontend/util/__init__.py

Whitespace-only changes.

python/tvm/relay/frontend/tensorflow.py

Lines changed: 44 additions & 15 deletions
@@ -4,6 +4,7 @@
 from __future__ import print_function
 
 import logging
+import warnings
 # Numpy support
 import numpy as np
 
@@ -410,7 +411,7 @@ def _impl(inputs, attr, params):
 def _decode_image():
     def _impl(inputs, attr, params):
         # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
-        print("DecodeJpeg: It's a pass through, please handle preprocessing before input")
+        warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input")
         return inputs[0]
     return _impl
 
@@ -1178,6 +1179,7 @@ class GraphProto(object):
     def __init__(self):
         self._nodes = {}
         self._params = {}
+        self._input_shapes = {}
         self._output_shapes = {}
         self._num_param = 0
         self._num_rnn_layer = False
@@ -1229,36 +1231,55 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             raise NotImplementedError( \
                 "The following operators are not implemented: {}".format(missing_operators))
 
+        for node in graph.node:
+            if node.op == 'Placeholder':
+                if shape and node.name in shape:
+                    self._input_shapes[node.name] = list(shape[node.name])
+                    continue
+                self._input_shapes[node.name] = \
+                    tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
+                for idx, dim in enumerate(self._input_shapes[node.name]):
+                    if dim < 0:
+                        self._input_shapes[node.name][idx] = 1
+                        warnings.warn("Use 1 instead of -1 in shape of operator %s."
+                                      % node.name)
+
+            # Ignore user's input shape for Non placeholder
+            elif node.op == 'Const':
+                tensor_value = node.attr['value'].tensor
+                self._input_shapes[node.name] = \
+                    tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
+                if shape and node.name in shape:
+                    warnings.warn("Ignore the passed shape. Shape in graphdef "
+                                  "will be used for operator %s." % node.name)
+
         # Parse the nodes to re-create TF graph using Relay operators.
         for node in graph.node:
-            # Tensorflow doesn't have seperate list for params extraction.
+            # Tensorflow doesn't have separate list for params extraction.
             # Operator name 'Const' is treated as a parameter to build params dict.
 
             input_shapes = {}
             attr = self._parse_attr(node.attr)
 
-            #Variable converted to Const will not have only value attr
+            # Variable converted to Const will not have only value attr
            if 'value' in attr and node.op == 'Const':
-                tensor_value = attr['value']
-                self._output_shapes[node.name] = \
-                    [tensor_util.TensorShapeProtoToList( \
-                        tensor_value.tensor_shape)]
+                self._output_shapes[node.name] = [self._input_shapes[node.name]]
+            elif shape and node.name in shape:
+                # Give priority to user argument.
+                self._output_shapes[node.name] = [shape[node.name]]
            elif '_output_shapes' in attr:
                 self._output_shapes[node.name] = \
                     [tensor_util.TensorShapeProtoToList(tshape) \
                     for tshape in attr['_output_shapes']]
-            elif shape:
+            else:
                 # Keep the list indexable to avoid key error.
                 # Actual value will be filled after node creation.
                 self._output_shapes[node.name] = [None]
-            else:
-                raise NotImplementedError( \
-                    "Please freeze the graph with add_shapes=True")
 
             if node.op == "Placeholder":
-                self._output_shapes[node.name] = [shape[node.name]]
+                self._output_shapes[node.name] = [self._input_shapes[node.name]]
                 self._nodes[node.name] = [_expr.var(node.name,
-                                                    shape=self._output_shapes[node.name][0],
+                                                    shape=self._input_shapes[node.name],
                                                     dtype=attr['dtype'].name)]
 
             elif node.op == "Const":
@@ -1274,7 +1295,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
 
             else:
                 # Pass the parsed shapes instead
-                attr["_output_shapes"] = self._output_shapes[node.name]
+                attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]
 
                 # Pass the node name too in attr
                 attr["_node_name"] = node.name
@@ -1301,7 +1322,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
 
                 op = self._convert_operator(node.op, inputs, attr, graph)
 
-                # Check is op is converted to param
+                # Check if op is converted to param
                 if isinstance(op, np.ndarray):
                     self._params[node.name] = tvm.nd.array(op)
                     op = [_expr.var(node.name,
@@ -1317,6 +1338,14 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
 
                 self._nodes[node.name] = op
 
+                # Infer shapes even without specifying "add_shapes=True"
+                if output_shapes == [None]:
+                    out_type = ir_pass.infer_type(self._nodes[node.name][0])
+                    self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]
+
+                if self._output_shapes[node.name] and shape and node.name in shape:
+                    assert self._output_shapes[node.name] == list(shape[node.name])
+
                 # Infer shapes if passed explicitely
                 node_output = self._nodes[node.name]
                 out_type = ir_pass.infer_type(node_output[0])
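
The net effect of these changes is that a GraphDef no longer has to be frozen with add_shapes=True: missing output shapes are inferred, unknown (-1) Placeholder dimensions default to 1 with a warning, and a user-supplied shape dict takes priority. A rough usage sketch, not part of the commit: the graph, tensor names, and shape dict below are illustrative, it assumes a TF 1.x GraphDef, and the exact return form of from_tensorflow depends on the TVM version.

# Illustrative only: a GraphDef built without add_shapes=True and with an
# unknown batch dimension in its Placeholder.
import tensorflow as tf
import tvm.relay as relay

with tf.Graph().as_default() as graph:
    inp = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name="input")
    tf.nn.relu(inp, name="output")
    graph_def = graph.as_graph_def(add_shapes=False)

# A shape entry for the Placeholder takes priority over the shape stored in
# the graph; without it the frontend substitutes 1 for the -1 dimension and
# warns instead of raising "Please freeze the graph with add_shapes=True".
net, params = relay.frontend.from_tensorflow(graph_def,
                                             layout="NHWC",
                                             shape={"input": (1, 224, 224, 3)})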

nnvm/python/nnvm/frontend/util/tensorflow_parser.py renamed to python/tvm/relay/frontend/tensorflow_parser.py

Lines changed: 18 additions & 10 deletions
@@ -7,16 +7,21 @@
 
 
 class TFParser(object):
-    """A Wrapper to handle tensorflow models parsing
-       TensorFlow is needed
-    ```
-    parser = TfParser(model_dir)
-    graph = parser.parse()
-    ```
+    """
+    A Wrapper to handle tensorflow models parsing, TensorFlow is needed
+
     Parameters
     ----------
     model_dir : tensorflow frozen pb file or a directory that contains saved
         model or checkpoints.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        parser = TfParser(model_dir)
+        graph = parser.parse()
+        # graph is related graphdef of the model
     """
 
     def __init__(self, model_dir):
@@ -115,13 +120,16 @@ def _load_ckpt(self):
         """TODO: Load checkpoint model."""
         raise RuntimeError("InputConfiguration: Loading tf checkpoint model is "
                            "not supported yet.")
-        # pylint: disable=unreachable
-        return 0
 
     def parse(self):
-        """Parse tensorflow models: checkpoints, saved models, and single pb
-        file.
         """
+        Parse tensorflow models: checkpoints, saved models, and single frozen pb file.
+
+        Returns
+        -------
+        GraphDef of the passed model
+        """
+
         graph = None
 
         if os.path.isdir(self._model_dir):
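
With the parser now shipped under relay, loading a saved model directory or a frozen pb and handing the result to the frontend could look roughly like the sketch below (not part of the commit; the path is a placeholder, the import path simply mirrors the new file location, and checkpoint loading still raises RuntimeError as shown in the diff above).

# Illustrative only: parse a saved model directory or a single frozen .pb
# file into a GraphDef, then convert it with the Relay TensorFlow frontend.
import tvm.relay as relay
from tvm.relay.frontend.tensorflow_parser import TFParser

parser = TFParser("/path/to/saved_model_or_frozen_pb")   # placeholder path
graph_def = parser.parse()   # GraphDef of the passed model

# The exact return form of from_tensorflow depends on the TVM version.
net, params = relay.frontend.from_tensorflow(graph_def, layout="NHWC")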
