4
4
from __future__ import print_function
5
5
6
6
import logging
7
+ import warnings
7
8
# Numpy support
8
9
import numpy as np
9
10
@@ -410,7 +411,7 @@ def _impl(inputs, attr, params):
410
411
def _decode_image ():
411
412
def _impl (inputs , attr , params ):
412
413
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
413
- print ("DecodeJpeg: It's a pass through, please handle preprocessing before input" )
414
+ warnings . warn ("DecodeJpeg: It's a pass through, please handle preprocessing before input" )
414
415
return inputs [0 ]
415
416
return _impl
416
417
@@ -1178,6 +1179,7 @@ class GraphProto(object):
1178
1179
def __init__ (self ):
1179
1180
self ._nodes = {}
1180
1181
self ._params = {}
1182
+ self ._input_shapes = {}
1181
1183
self ._output_shapes = {}
1182
1184
self ._num_param = 0
1183
1185
self ._num_rnn_layer = False
@@ -1229,36 +1231,55 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
1229
1231
raise NotImplementedError ( \
1230
1232
"The following operators are not implemented: {}" .format (missing_operators ))
1231
1233
1234
+ for node in graph .node :
1235
+ if node .op == 'Placeholder' :
1236
+ if shape and node .name in shape :
1237
+ self ._input_shapes [node .name ] = list (shape [node .name ])
1238
+ continue
1239
+ self ._input_shapes [node .name ] = \
1240
+ tensor_util .TensorShapeProtoToList (node .attr ['shape' ].shape )
1241
+ for idx , dim in enumerate (self ._input_shapes [node .name ]):
1242
+ if dim < 0 :
1243
+ self ._input_shapes [node .name ][idx ] = 1
1244
+ warnings .warn ("Use 1 instead of -1 in shape of operator %s."
1245
+ % node .name )
1246
+
1247
+ # Ignore user's input shape for Non placeholder
1248
+ elif node .op == 'Const' :
1249
+ tensor_value = node .attr ['value' ].tensor
1250
+ self ._input_shapes [node .name ] = \
1251
+ tensor_util .TensorShapeProtoToList (tensor_value .tensor_shape )
1252
+ if shape and node .name in shape :
1253
+ warnings .warn ("Ignore the passed shape. Shape in graphdef "
1254
+ "will be used for operator %s." % node .name )
1255
+
1232
1256
# Parse the nodes to re-create TF graph using Relay operators.
1233
1257
for node in graph .node :
1234
- # Tensorflow doesn't have seperate list for params extraction.
1258
+ # Tensorflow doesn't have separate list for params extraction.
1235
1259
# Operator name 'Const' is treated as a parameter to build params dict.
1236
1260
1237
1261
input_shapes = {}
1238
1262
attr = self ._parse_attr (node .attr )
1239
1263
1240
- #Variable converted to Const will not have only value attr
1264
+ # Variable converted to Const will not have only value attr
1241
1265
if 'value' in attr and node .op == 'Const' :
1242
- tensor_value = attr [ 'value' ]
1243
- self . _output_shapes [ node .name ] = \
1244
- [ tensor_util . TensorShapeProtoToList ( \
1245
- tensor_value . tensor_shape ) ]
1266
+ self . _output_shapes [ node . name ] = [ self . _input_shapes [ node . name ] ]
1267
+ elif shape and node .name in shape :
1268
+ # Give priority to user argument.
1269
+ self . _output_shapes [ node . name ] = [ shape [ node . name ] ]
1246
1270
elif '_output_shapes' in attr :
1247
1271
self ._output_shapes [node .name ] = \
1248
1272
[tensor_util .TensorShapeProtoToList (tshape ) \
1249
1273
for tshape in attr ['_output_shapes' ]]
1250
- elif shape :
1274
+ else :
1251
1275
# Keep the list indexable to avoid key error.
1252
1276
# Actual value will be filled after node creation.
1253
1277
self ._output_shapes [node .name ] = [None ]
1254
- else :
1255
- raise NotImplementedError ( \
1256
- "Please freeze the graph with add_shapes=True" )
1257
1278
1258
1279
if node .op == "Placeholder" :
1259
- self ._output_shapes [node .name ] = [shape [node .name ]]
1280
+ self ._output_shapes [node .name ] = [self . _input_shapes [node .name ]]
1260
1281
self ._nodes [node .name ] = [_expr .var (node .name ,
1261
- shape = self ._output_shapes [node .name ][ 0 ],
1282
+ shape = self ._input_shapes [node .name ],
1262
1283
dtype = attr ['dtype' ].name )]
1263
1284
1264
1285
elif node .op == "Const" :
@@ -1274,7 +1295,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
1274
1295
1275
1296
else :
1276
1297
# Pass the parsed shapes instead
1277
- attr ["_output_shapes" ] = self ._output_shapes [node .name ]
1298
+ attr ["_output_shapes" ] = output_shapes = self ._output_shapes [node .name ]
1278
1299
1279
1300
# Pass the node name too in attr
1280
1301
attr ["_node_name" ] = node .name
@@ -1301,7 +1322,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
1301
1322
1302
1323
op = self ._convert_operator (node .op , inputs , attr , graph )
1303
1324
1304
- # Check is op is converted to param
1325
+ # Check if op is converted to param
1305
1326
if isinstance (op , np .ndarray ):
1306
1327
self ._params [node .name ] = tvm .nd .array (op )
1307
1328
op = [_expr .var (node .name ,
@@ -1317,6 +1338,14 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
1317
1338
1318
1339
self ._nodes [node .name ] = op
1319
1340
1341
+ # Infer shapes even without specifying "add_shapes=True"
1342
+ if output_shapes == [None ]:
1343
+ out_type = ir_pass .infer_type (self ._nodes [node .name ][0 ])
1344
+ self ._output_shapes [node .name ] = [get_const_tuple (out_type .checked_type .shape )]
1345
+
1346
+ if self ._output_shapes [node .name ] and shape and node .name in shape :
1347
+ assert self ._output_shapes [node .name ] == list (shape [node .name ])
1348
+
1320
1349
# Infer shapes if passed explicitely
1321
1350
node_output = self ._nodes [node .name ]
1322
1351
out_type = ir_pass .infer_type (node_output [0 ])
0 commit comments