Skip to content

Commit 6ad7ce8

Browse files
ariwaranosaisrkreddy1238
authored andcommitted
Add CONCATENATION to tflite frontend, support Inception V3 (#2643)
* Add CONCATENATION to tflite frontend * fix typo * Fix codestyle * Fix code style * simplify convert map * Update
1 parent a1b8610 commit 6ad7ce8

File tree

2 files changed

+119
-2
lines changed

2 files changed

+119
-2
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def __init__(self, model, subgraph, exp_tab):
3535
self.builtin_op_code = build_str_map(BuiltinOperator())
3636
self.activation_fn_type = build_str_map(ActivationFunctionType())
3737
self.builtin_options = build_str_map(BuiltinOptions())
38+
39+
# Add more operators
3840
self.convert_map = {
3941
'CONV_2D': self.convert_conv2d,
4042
'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
@@ -43,7 +45,7 @@ def __init__(self, model, subgraph, exp_tab):
4345
'SOFTMAX': self.convert_softmax,
4446
'SQUEEZE': self.convert_squeeze,
4547
'MAX_POOL_2D': self.convert_max_pool2d,
46-
# Add more operators
48+
"CONCATENATION": self.convert_concatenation
4749
}
4850

4951
def check_unsupported_ops(self):
@@ -245,6 +247,48 @@ def convert_softmax(self, op):
245247

246248
return out
247249

250+
def convert_concatenation(self, op):
251+
""" convert TFLite concatenation"""
252+
try:
253+
from tflite.Operator import Operator
254+
from tflite.ConcatenationOptions import ConcatenationOptions
255+
from tflite.BuiltinOptions import BuiltinOptions
256+
from tflite.ActivationFunctionType import ActivationFunctionType
257+
except ImportError:
258+
raise ImportError("The tflite package must be installed")
259+
260+
assert isinstance(op, Operator)
261+
input_tensors = self.get_input_tensors(op)
262+
assert len(input_tensors) >= 1, "input tensors should greater than 1"
263+
in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]
264+
265+
output_tensors = self.get_output_tensors(op)
266+
assert len(output_tensors) == 1, "output tensors should be 1"
267+
268+
assert op.BuiltinOptionsType() == BuiltinOptions.ConcatenationOptions
269+
op_options = op.BuiltinOptions()
270+
concatenation_options = ConcatenationOptions()
271+
concatenation_options.Init(op_options.Bytes, op_options.Pos)
272+
concatenation_axis = concatenation_options.Axis()
273+
fused_activation_fn = concatenation_options.FusedActivationFunction()
274+
input_shape_length = len(input_tensors[0].tensor.ShapeAsNumpy())
275+
276+
# TFLite is N H W C, our layout is N C H W
277+
if input_shape_length <= 4:
278+
axis_convert_map = [0] + list(range(2, input_shape_length)) + [1]
279+
concatenation_axis = axis_convert_map[concatenation_axis]
280+
else:
281+
raise NotImplementedError("Not support input shape length {} of concatenatio : "
282+
.format(str(input_shape_length)))
283+
284+
# with axis in N H W C
285+
out = _op.concatenate(in_exprs, axis=concatenation_axis)
286+
287+
# if we have activation fn
288+
if fused_activation_fn != ActivationFunctionType.NONE:
289+
out = self.convert_fused_activation_function(out, fused_activation_fn)
290+
return out
291+
248292
def convert_squeeze(self, op):
249293
"""Convert TFLite squeeze"""
250294
try:

tests/python/frontend/tflite/test_forward.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,53 @@ def test_forward_reshape():
283283
_test_reshape(np.arange(6), [-1])
284284

285285

286+
#######################################################################
287+
# Concatenation
288+
# -------------
289+
290+
def _test_concatenation(data, axis):
291+
""" One iteration of concatenation """
292+
293+
assert len(data) >= 1
294+
need_transpose = False
295+
if len(data[0].shape) == 1 or len(data[0].shape) == 2:
296+
tvm_data = data
297+
elif len(data[0].shape) == 3:
298+
#need_transpose = True
299+
tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data]
300+
elif len(data[0].shape) == 4:
301+
need_transpose = True
302+
tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data]
303+
else:
304+
raise NotImplementedError("Not support input shape {} of reshape : ".
305+
format(str(len(data))))
306+
307+
with tf.Graph().as_default():
308+
in_data = [
309+
array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name="in_{}".format(idx))
310+
for idx, tensor in enumerate(data)]
311+
out = array_ops.concat(in_data, axis=axis)
312+
name = ["in_{}:0".format(idx) for idx in range(len(data))]
313+
314+
compare_tflite_with_tvm(data, tvm_data, name, in_data, [out], need_transpose)
315+
316+
317+
def test_forward_concatenation():
318+
319+
_test_concatenation(
320+
[np.arange(6).reshape((1, 2, 1, 3)),
321+
np.arange(6).reshape((1, 2, 1, 3))], 1)
322+
323+
_test_concatenation(
324+
[np.arange(6).reshape((3, 2)),
325+
np.arange(6).reshape((3, 2))], 1)
326+
327+
_test_concatenation(
328+
[np.arange(6).reshape((2, 1, 1, 3)),
329+
np.arange(6).reshape((2, 1, 1, 3)),
330+
np.arange(6).reshape((2, 1, 1, 3))], 1)
331+
332+
286333
#######################################################################
287334
# Squeeze
288335
# -------
@@ -340,26 +387,51 @@ def test_forward_softmax():
340387
#######################################################################
341388
# Mobilenet
342389
# ---------
390+
343391
def test_forward_mobilenet():
344392
'''test mobilenet v1 tflite model'''
345393
# MobilenetV1
346394
temp = util.tempdir()
347395
tflite_model_file = tf_testing.get_workload_official(
348396
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
349397
"mobilenet_v1_1.0_224.tflite", temp)
350-
tflite_model_buf = open(tflite_model_file, "rb").read()
398+
with open(tflite_model_file, "rb") as f:
399+
tflite_model_buf = f.read()
351400
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
352401
tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
353402
tflite_output = run_tflite_graph(tflite_model_buf, data)
354403
tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
355404
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
356405
rtol=1e-5, atol=1e-5)
406+
temp.remove()
407+
408+
#######################################################################
409+
# Inception V3
410+
# ------------
411+
412+
def test_forward_inception_v3_net():
413+
'''test inception v3 tflite model'''
414+
# InceptionV3
415+
temp = util.tempdir()
416+
tflite_model_file = tf_testing.get_workload_official(
417+
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz",
418+
"inception_v3.tflite", temp)
419+
with open(tflite_model_file, "rb") as f:
420+
tflite_model_buf = f.read()
421+
data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
422+
tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
423+
tflite_output = run_tflite_graph(tflite_model_buf, data)
424+
tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
425+
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
426+
rtol=1e-5, atol=1e-5)
427+
temp.remove()
357428

358429
#######################################################################
359430
# Main
360431
# ----
361432
if __name__ == '__main__':
362433
# Transforms
434+
test_forward_concatenation()
363435
test_forward_reshape()
364436
test_forward_squeeze()
365437

@@ -370,3 +442,4 @@ def test_forward_mobilenet():
370442

371443
# End to End
372444
test_forward_mobilenet()
445+
test_forward_inception_v3_net()

0 commit comments

Comments
 (0)