Skip to content

Commit 4b2c8fc

Browse files
committed
Review comments fixed
1 parent d5bbcc1 commit 4b2c8fc

File tree

2 files changed

+94
-42
lines changed

2 files changed

+94
-42
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
"""Tensorflow lite frontend."""
2020
import math
21+
import itertools
2122
import numpy as np
2223
import tvm
2324
from tvm.ir import IRModule
@@ -864,6 +865,8 @@ def convert_gather(self, op):
864865
raise ImportError("The tflite package must be installed")
865866

866867
input_tensors = self.get_input_tensors(op)
868+
assert len(input_tensors) == 2, "input tensors length should be 2"
869+
867870
data = self.get_expr(input_tensors[0].tensor_idx)
868871

869872
indices = input_tensors[1]
@@ -879,18 +882,63 @@ def convert_gather(self, op):
879882
gather_options.Init(op_options.Bytes, op_options.Pos)
880883
axis = gather_options.Axis()
881884

882-
out = _op.take(data, indices, axis=axis)
885+
# Check the indices are oob, tflite is unpredictable in case of oob.
886+
data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
887+
data_dim = len(data_shape)
888+
889+
axis_n = axis
890+
if axis_n < 0:
891+
axis_n += axis_n + data_dim
892+
assert axis_n >= 0, "Axis out of bounds"
893+
assert axis_n < data_dim, "Axis out of bounds"
894+
895+
indices_val = self.get_tensor_value(input_tensors[1])
896+
indices_shape = list(indices_val.shape)
897+
indices_len = len(indices_shape)
898+
899+
out_shape = []
900+
for i in range(data_dim):
901+
if axis_n == i:
902+
for j in range(indices_len):
903+
out_shape.append(indices_shape[j])
904+
else:
905+
out_shape.append(data_shape[i])
906+
907+
loopover = [range(s) for s in out_shape]
908+
for idx in list(itertools.product(*loopover)):
909+
indices_position = [idx[j] for j in range(axis_n, axis_n+indices_len)]
910+
911+
real_indices = [idx[j] for j in range(axis_n)]
912+
real_indices.append(indices_val[tuple(indices_position)])
913+
real_indices.extend([idx[j] for j in range(axis_n + indices_len, len(idx))])
914+
for r, d in zip(real_indices, data_shape):
915+
if r >= d:
916+
raise ValueError("TFLite out of bound indices are not supported.")
917+
918+
# Use mode as fast since already checked for oob.
919+
out = _op.take(data, indices, axis=axis, mode="fast")
883920
return out
884921

885922
def convert_strided_slice(self, op):
886-
"""Method to Convert TFLite STRIDED_SLICE operator"""
923+
"""Method to Convert TFLite STRIDED_SLICE operator.
924+
Note: Eventhough tf2.0 supports begin_mask, end_mask, ellipsis_mask, new_axis_mask
925+
and shrink_axis_mask, tflite doesn't support these and expect these values to be zero.
926+
But in future, they may open up the mask implementation, so kept the implementation
927+
same as tensorflow.
928+
TVM Relay implementation of doesn't support mask, so the mask values are processed here
929+
and begin/end/strides are updated accordingly similar to tensorflow parsing. If mask is
930+
present, the values are sliced correctly in order by tvm relay implementation, but
931+
there may be a shape mismatch, which is fixed via the final reshape.
932+
"""
887933
try:
888934
from tflite.BuiltinOptions import BuiltinOptions
889935
from tflite.StridedSliceOptions import StridedSliceOptions
890936
except ImportError:
891937
raise ImportError("The tflite package must be installed")
892938

893939
input_tensors = self.get_input_tensors(op)
940+
assert len(input_tensors) == 4, "input tensors length should be 4"
941+
894942
data_expr = self.get_expr(input_tensors[0].tensor_idx)
895943

896944
begin = list(self.get_tensor_value(input_tensors[1]))
@@ -909,8 +957,7 @@ def convert_strided_slice(self, op):
909957

910958
data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
911959
data_dim = len(data_shape)
912-
stride_dim = len(list(input_tensors[3].tensor.ShapeAsNumpy()))
913-
960+
stride_dim = len(stride)
914961
def _transform_mask(stride_dim, ellipsis_mask):
915962
"""Handle mask inputs to create new begin, end, stride and output shape"""
916963
m_begin = [0] * data_dim

tests/python/frontend/tflite/test_forward.py

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -271,69 +271,74 @@ def test_forward_slice():
271271
# Gather
272272
# ------
273273

274-
def _test_gather(dshape, indices, axis, dtype):
274+
def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False):
275275
""" One iteration of Gather """
276-
data = np.random.uniform(1, 10, size=dshape).astype(dtype)
277276
indices = np.asarray(indices).astype('int32')
278-
279-
with tf.Graph().as_default():
280-
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
281-
out = array_ops.gather(in_data, indices, axis=axis)
282-
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
283-
284-
#Test quantized input
285-
data = np.random.uniform(1, 10, size=dshape).astype(np.uint8)
277+
data = np.random.uniform(1, 10, size=dshape)
278+
data = data.astype(np.uint8) if quantized else data.astype(dtype)
286279
with tf.Graph().as_default():
287280
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data")
288-
out = array_ops.gather(in_data, indices, axis=axis)
289-
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=True)
281+
if axis:
282+
out = array_ops.gather(in_data, indices, axis=axis)
283+
else:
284+
out = array_ops.gather(in_data, indices) #tflite conversion fails for None axis
285+
input_range = {'in_data': (-100, 100)} if quantized else None
286+
try:
287+
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out],
288+
quantized=quantized, input_range=input_range)
289+
except ValueError as e:
290+
if not oob:
291+
raise e
292+
except Exception as e:
293+
raise e
290294

291295
def test_forward_gather():
292296
""" GATHER """
293-
_test_gather((4,), [1], 0, 'float32')
294-
_test_gather((1, 4), [0], 0, 'int32')
295-
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32')
296-
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'int32')
297-
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
298-
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
299-
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32')
300-
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32')
301-
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
297+
for quantized in [False, True]:
298+
_test_gather((4,), [1], 0, 'float32', quantized)
299+
_test_gather((4,), [1], None, 'int32', quantized)
300+
_test_gather((1, 4), [0], 0, 'int32', quantized)
301+
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized)
302+
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized)
303+
_test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized)
304+
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized)
305+
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized)
306+
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized)
307+
_test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized)
308+
_test_gather((4,), [16], 0, 'float32', quantized, oob=True)
309+
_test_gather((1, 3, 3), [12], 0, 'int32', quantized, oob=True)
310+
_test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True)
311+
_test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True)
302312

303313
#######################################################################
304314
# StridedSlice
305315
# ------------
306316

307317
def _test_stridedslice(ip_shape, begin, end, stride, dtype,
308318
begin_mask=0, end_mask=0, new_axis_mask=0,
309-
shrink_axis_mask=0, ellipsis_mask=0):
319+
shrink_axis_mask=0, ellipsis_mask=0, quantized=False):
310320
""" One iteration of a Stridedslice """
311321
data = np.random.uniform(size=ip_shape).astype(dtype)
322+
data = data.astype(np.uint8) if quantized else data.astype(dtype)
312323
with tf.Graph().as_default():
313324
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
314325
out = array_ops.strided_slice(in_data, begin, end, stride,
315326
begin_mask=begin_mask,
316-
end_mask=end_mask, new_axis_mask=new_axis_mask,
317-
shrink_axis_mask=shrink_axis_mask,
318-
ellipsis_mask=ellipsis_mask)
319-
compare_tflite_with_tvm(data, 'in_data:0', [in_data], [out])
320-
321-
#Test with quantized inputs
322-
data = np.random.uniform(size=ip_shape).astype(np.uint8)
323-
with tf.Graph().as_default():
324-
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
325-
out = array_ops.strided_slice(in_data, begin, end, stride,
326-
begin_mask=begin_mask,
327-
end_mask=end_mask, new_axis_mask=new_axis_mask,
327+
end_mask=end_mask,
328+
new_axis_mask=new_axis_mask,
328329
shrink_axis_mask=shrink_axis_mask,
329330
ellipsis_mask=ellipsis_mask)
330-
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=True)
331+
input_range = {'in_data': (-100, 100)} if quantized else None
332+
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=quantized,
333+
input_range=input_range)
331334

332335
def test_forward_stridedslice():
333336
'''test StridedSlice'''
334-
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
335-
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
336-
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=1)
337+
for quantized in [False, True]:
338+
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1, quantized=quantized)
339+
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32', quantized=quantized)
340+
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=0, quantized=quantized)
341+
_test_stridedslice((4, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2, quantized=quantized)
337342

338343
#######################################################################
339344
# transpose

0 commit comments

Comments
 (0)