Skip to content

Commit 5f727ea

Browse files
committed
Review comments fixed
1 parent 2a87444 commit 5f727ea

File tree

2 files changed

+118
-42
lines changed

2 files changed

+118
-42
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
1818
"""Tensorflow lite frontend."""
1919
import math
20+
import itertools
2021
import numpy as np
2122
import tvm
2223
from tvm.ir import IRModule
@@ -1026,6 +1027,8 @@ def convert_gather(self, op):
10261027
raise ImportError("The tflite package must be installed")
10271028

10281029
input_tensors = self.get_input_tensors(op)
1030+
assert len(input_tensors) == 2, "input tensors length should be 2"
1031+
10291032
data = self.get_expr(input_tensors[0].tensor_idx)
10301033

10311034
indices = input_tensors[1]
@@ -1041,18 +1044,87 @@ def convert_gather(self, op):
10411044
gather_options.Init(op_options.Bytes, op_options.Pos)
10421045
axis = gather_options.Axis()
10431046

1044-
out = _op.take(data, indices, axis=axis)
1047+
# Check the indices are with in bounds.
1048+
data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
1049+
data_dim = len(data_shape)
1050+
1051+
axis_n = axis
1052+
if axis_n < 0:
1053+
axis_n += axis_n + data_dim
1054+
assert axis_n >= 0, "Axis out of bounds"
1055+
assert axis_n < data_dim, "Axis out of bounds"
1056+
1057+
indices_val = self.get_tensor_value(input_tensors[1])
1058+
indices_shape = list(indices_val.shape)
1059+
indices_len = len(indices_shape)
1060+
1061+
out_shape = []
1062+
for i in range(data_dim):
1063+
if axis_n == i:
1064+
for j in range(indices_len):
1065+
out_shape.append(indices_shape[j])
1066+
else:
1067+
out_shape.append(data_shape[i])
1068+
1069+
loopover = [range(s) for s in out_shape]
1070+
for idx in list(itertools.product(*loopover)):
1071+
indices_position = [idx[j] for j in range(axis_n, axis_n+indices_len)]
1072+
1073+
real_indices = [idx[j] for j in range(axis_n)]
1074+
real_indices.append(indices_val[tuple(indices_position)])
1075+
real_indices.extend([idx[j] for j in range(axis_n + indices_len, len(idx))])
1076+
for r, d in zip(real_indices, data_shape):
1077+
if r >= d:
1078+
raise ValueError("TFLite out of bound indices are not supported.")
1079+
1080+
# Use mode 'fast' since indices are already checked within bounds.
1081+
out = _op.take(data, indices, axis=axis, mode="fast")
10451082
return out
10461083

10471084
def convert_strided_slice(self, op):
1048-
"""Method to Convert TFLite STRIDED_SLICE operator"""
1085+
"""Method to Convert TFLite STRIDED_SLICE operator.
1086+
NOTE: Eventhough tensorflow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask
1087+
and shrink_axis_mask, tflite doesn't support these and expect these values to be zero.
1088+
But in future, they may open up the mask implementation, so kept the implementation
1089+
same as tensorflow.
1090+
1091+
This op extracts a slice of size (end - begin) / stride from the given input tensor.
1092+
Starting at the location specified by begin the slice continues by adding stride to the
1093+
index until all dimensions are not less than end. Note that a stride can be negative,
1094+
which causes a reverse slice.
1095+
1096+
For slice input[val0, val1, ..., valn], begin/end/strides will be vectors of length n.
1097+
1098+
In each mask field(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
1099+
the ith bit will correspond to the ith val.
1100+
1101+
If the ith bit of begin_mask is set, begin[i] is ignored and the fullest possible range
1102+
in that dimension is used instead.
1103+
1104+
If the ith bit of ellipsis_mask is set, as many unspecified dimensions as needed will be
1105+
inserted between other dimensions. Only one non-zero bit is allowed in ellipsis_mask.
1106+
1107+
If the ith bit of new_axis_mask is set, then begin, end, and stride are ignored and a
1108+
new length 1 dimension is added at this point in the output tensor.
1109+
1110+
If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks
1111+
the dimensionality by 1, taking on the value at index begin[i]. end[i] and strides[i]
1112+
are ignored in this case.
1113+
begin and end are zero-indexed. strides entries must be non-zero.
1114+
1115+
TVM Relay implementation of doesn't support mask, so the mask values are processed in
1116+
this function and begin/end/strides are updated accordingly. If any mask is present, and
1117+
since tvm doesn't support mask computation directly, the output need a final reshape.
1118+
"""
10491119
try:
10501120
from tflite.BuiltinOptions import BuiltinOptions
10511121
from tflite.StridedSliceOptions import StridedSliceOptions
10521122
except ImportError:
10531123
raise ImportError("The tflite package must be installed")
10541124

10551125
input_tensors = self.get_input_tensors(op)
1126+
assert len(input_tensors) == 4, "input tensors length should be 4"
1127+
10561128
data_expr = self.get_expr(input_tensors[0].tensor_idx)
10571129

10581130
begin = list(self.get_tensor_value(input_tensors[1]))
@@ -1071,8 +1143,7 @@ def convert_strided_slice(self, op):
10711143

10721144
data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
10731145
data_dim = len(data_shape)
1074-
stride_dim = len(list(input_tensors[3].tensor.ShapeAsNumpy()))
1075-
1146+
stride_dim = len(stride)
10761147
def _transform_mask(stride_dim, ellipsis_mask):
10771148
"""Handle mask inputs to create new begin, end, stride and output shape"""
10781149
m_begin = [0] * data_dim

tests/python/frontend/tflite/test_forward.py

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -294,69 +294,74 @@ def test_forward_topk():
294294
# Gather
295295
# ------
296296

297-
def _test_gather(dshape, indices, axis, dtype):
297+
def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False):
298298
""" One iteration of Gather """
299-
data = np.random.uniform(1, 10, size=dshape).astype(dtype)
300299
indices = np.asarray(indices).astype('int32')
301-
302-
with tf.Graph().as_default():
303-
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
304-
out = array_ops.gather(in_data, indices, axis=axis)
305-
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
306-
307-
#Test quantized input
308-
data = np.random.uniform(1, 10, size=dshape).astype(np.uint8)
300+
data = np.random.uniform(1, 10, size=dshape)
301+
data = data.astype(np.uint8) if quantized else data.astype(dtype)
309302
with tf.Graph().as_default():
310303
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data")
311-
out = array_ops.gather(in_data, indices, axis=axis)
312-
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=True)
304+
if axis:
305+
out = array_ops.gather(in_data, indices, axis=axis)
306+
else:
307+
out = array_ops.gather(in_data, indices) #tflite conversion fails for None axis
308+
input_range = {'in_data': (-100, 100)} if quantized else None
309+
try:
310+
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out],
311+
quantized=quantized, input_range=input_range)
312+
except ValueError as e:
313+
if not oob:
314+
raise e
315+
except Exception as e:
316+
raise e
313317

314318
def test_forward_gather():
315319
""" GATHER """
316-
_test_gather((4,), [1], 0, 'float32')
317-
_test_gather((1, 4), [0], 0, 'int32')
318-
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32')
319-
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'int32')
320-
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
321-
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
322-
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32')
323-
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32')
324-
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
320+
for quantized in [False, True]:
321+
_test_gather((4,), [1], 0, 'float32', quantized)
322+
_test_gather((4,), [1], None, 'int32', quantized)
323+
_test_gather((1, 4), [0], 0, 'int32', quantized)
324+
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized)
325+
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized)
326+
_test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized)
327+
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized)
328+
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized)
329+
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized)
330+
_test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized)
331+
_test_gather((4,), [16], 0, 'float32', quantized, oob=True)
332+
_test_gather((1, 3, 3), [12], 0, 'int32', quantized, oob=True)
333+
_test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True)
334+
_test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True)
325335

326336
#######################################################################
327337
# StridedSlice
328338
# ------------
329339

330340
def _test_stridedslice(ip_shape, begin, end, stride, dtype,
331341
begin_mask=0, end_mask=0, new_axis_mask=0,
332-
shrink_axis_mask=0, ellipsis_mask=0):
342+
shrink_axis_mask=0, ellipsis_mask=0, quantized=False):
333343
""" One iteration of a Stridedslice """
334344
data = np.random.uniform(size=ip_shape).astype(dtype)
345+
data = data.astype(np.uint8) if quantized else data.astype(dtype)
335346
with tf.Graph().as_default():
336347
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
337348
out = array_ops.strided_slice(in_data, begin, end, stride,
338349
begin_mask=begin_mask,
339-
end_mask=end_mask, new_axis_mask=new_axis_mask,
340-
shrink_axis_mask=shrink_axis_mask,
341-
ellipsis_mask=ellipsis_mask)
342-
compare_tflite_with_tvm(data, 'in_data:0', [in_data], [out])
343-
344-
#Test with quantized inputs
345-
data = np.random.uniform(size=ip_shape).astype(np.uint8)
346-
with tf.Graph().as_default():
347-
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
348-
out = array_ops.strided_slice(in_data, begin, end, stride,
349-
begin_mask=begin_mask,
350-
end_mask=end_mask, new_axis_mask=new_axis_mask,
350+
end_mask=end_mask,
351+
new_axis_mask=new_axis_mask,
351352
shrink_axis_mask=shrink_axis_mask,
352353
ellipsis_mask=ellipsis_mask)
353-
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=True)
354+
input_range = {'in_data': (-100, 100)} if quantized else None
355+
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=quantized,
356+
input_range=input_range)
354357

355358
def test_forward_stridedslice():
356359
'''test StridedSlice'''
357-
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
358-
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
359-
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=1)
360+
for quantized in [False, True]:
361+
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1, quantized=quantized)
362+
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32', quantized=quantized)
363+
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=0, quantized=quantized)
364+
_test_stridedslice((4, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2, quantized=quantized)
360365

361366
#######################################################################
362367
# transpose

0 commit comments

Comments
 (0)