Skip to content

Commit 6321119

Browse files
committed
[FRONTEND][TFLITE]Gather, StridedSlice op added
1 parent 73a9e99 commit 6321119

File tree

2 files changed

+213
-0
lines changed

2 files changed

+213
-0
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ def __init__(self, model, subgraph, exp_tab):
117117
'PRELU': self.convert_prelu,
118118
'TRANSPOSE_CONV': self.convert_transpose_conv,
119119
'SQUARED_DIFFERENCE': self.convert_squared_difference,
120+
'GATHER': self.convert_gather,
121+
'STRIDED_SLICE': self.convert_strided_slice,
120122
}
121123

122124
def check_unsupported_ops(self):
@@ -792,6 +794,147 @@ def convert_not_equal(self, op):
792794
'TFlite quantized NOT_EQUAL operator is not supported yet.')
793795
return self._convert_elemwise(_op.not_equal, op)
794796

797+
def convert_gather(self, op):
798+
"""Method to Convert TFLite GATHER operator"""
799+
try:
800+
from tflite.BuiltinOptions import BuiltinOptions
801+
from tflite.GatherOptions import GatherOptions
802+
from tflite.TensorType import TensorType
803+
except ImportError:
804+
raise ImportError("The tflite package must be installed")
805+
806+
input_tensors = self.get_input_tensors(op)
807+
data = self.get_expr(input_tensors[0].tensor_idx)
808+
809+
indices = input_tensors[1]
810+
indices_type = indices.tensor.Type()
811+
assert indices_type in (TensorType.INT32, TensorType.INT64)
812+
indices_type_str = self.get_tensor_type_str(indices_type)
813+
indices = self.exp_tab.new_const(self.get_tensor_value(indices),
814+
dtype=indices_type_str)
815+
816+
assert op.BuiltinOptionsType() == BuiltinOptions.GatherOptions
817+
op_options = op.BuiltinOptions()
818+
gather_options = GatherOptions()
819+
gather_options.Init(op_options.Bytes, op_options.Pos)
820+
axis = gather_options.Axis()
821+
822+
out = _op.take(data, indices, axis=axis)
823+
return out
824+
825+
def convert_strided_slice(self, op):
826+
"""Method to Convert TFLite STRIDED_SLICE operator"""
827+
try:
828+
from tflite.BuiltinOptions import BuiltinOptions
829+
from tflite.StridedSliceOptions import StridedSliceOptions
830+
except ImportError:
831+
raise ImportError("The tflite package must be installed")
832+
833+
input_tensors = self.get_input_tensors(op)
834+
data_expr = self.get_expr(input_tensors[0].tensor_idx)
835+
836+
begin = list(self.get_tensor_value(input_tensors[1]))
837+
end = list(self.get_tensor_value(input_tensors[2]))
838+
stride = list(self.get_tensor_value(input_tensors[3]))
839+
840+
assert op.BuiltinOptionsType() == BuiltinOptions.StridedSliceOptions
841+
op_options = op.BuiltinOptions()
842+
options = StridedSliceOptions()
843+
options.Init(op_options.Bytes, op_options.Pos)
844+
begin_mask = options.BeginMask()
845+
end_mask = options.EndMask()
846+
ellipsis_mask = options.EllipsisMask()
847+
new_axis_mask = options.NewAxisMask()
848+
shrink_axis_mask = options.ShrinkAxisMask()
849+
850+
data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
851+
data_dim = len(data_shape)
852+
stride_dim = len(list(input_tensors[3].tensor.ShapeAsNumpy()))
853+
854+
def _transform_mask(stride_dim, ellipsis_mask):
855+
"""Handle mask inputs to create new begin, end, stride and output shape"""
856+
m_begin = [0] * data_dim
857+
m_end = [0] * data_dim
858+
m_stride = [0] * data_dim
859+
fshape_indices = []
860+
#Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
861+
ellipsis_seen = False
862+
new_axes_after_ellipsis = 0
863+
for i in range(stride_dim):
864+
mask = 1 << i
865+
if ellipsis_seen and (mask & new_axis_mask) != 0:
866+
new_axes_after_ellipsis += 1
867+
if (mask & ellipsis_mask) != 0:
868+
ellipsis_seen = True
869+
if not ellipsis_seen:
870+
#Used later for extending the stride attributes in the below loop.
871+
ellipsis_mask |= (1 << stride_dim)
872+
stride_dim += 1
873+
final_index = 0
874+
for index in range(stride_dim):
875+
mask = 1 << index
876+
if mask & ellipsis_mask:
877+
#Identify the end index for applying ellipsis_mask
878+
to_index = min(((data_dim - (stride_dim-index)) + 1 \
879+
+ new_axes_after_ellipsis), data_dim)
880+
for i in range(final_index, to_index):
881+
m_begin[final_index] = 0
882+
m_end[final_index] = data_shape[final_index]
883+
m_stride[final_index] = 1
884+
fshape_indices.append(final_index)
885+
final_index += 1
886+
elif mask &new_axis_mask:
887+
fshape_indices.append(-1)
888+
elif not mask & new_axis_mask:
889+
if final_index == len(m_begin):
890+
break
891+
if mask & begin_mask:
892+
m_begin[final_index] = data_shape[final_index] \
893+
if stride[index] < 0 else 0
894+
elif begin[index]:
895+
m_begin[final_index] = begin[index]
896+
if mask & end_mask:
897+
m_end[final_index] = 0 if stride[index] < 0 \
898+
else data_shape[final_index]
899+
elif end[index]:
900+
m_end[final_index] = end[index]
901+
m_stride[final_index] = stride[index]
902+
if mask & shrink_axis_mask:
903+
#Tensorflow make axis with shrink_axis_mask as dimension 1
904+
m_begin[final_index] = data_shape[final_index] + begin[index] \
905+
if begin[index] < 0 else begin[index]
906+
m_end[final_index] = begin[index] + 1
907+
m_stride[final_index] = 1
908+
fshape_indices.append(-2)
909+
else:
910+
fshape_indices.append(final_index)
911+
912+
final_index += 1
913+
return m_begin, m_end, m_stride, fshape_indices
914+
915+
fshape_indices = None
916+
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
917+
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
918+
919+
out = _op.strided_slice(data_expr, begin=begin, end=end, strides=stride)
920+
out_shape = _infer_shape(out)
921+
if not fshape_indices:
922+
fshape_indices = range(len(out_shape))
923+
924+
#Create final output shape.
925+
final_output = []
926+
for gather_index in fshape_indices:
927+
if gather_index == -1:
928+
final_output.append(1)
929+
elif gather_index == -2:
930+
pass
931+
else:
932+
final_output.append(out_shape[gather_index])
933+
934+
if not final_output:
935+
return out
936+
return _op.reshape(out, newshape=tuple(final_output))
937+
795938
def convert_zeros_like(self, op):
796939
"""Convert TFLite ZEROS LIKE"""
797940
try:

tests/python/frontend/tflite/test_forward.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,74 @@ def test_forward_slice():
244244
_test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1])
245245
_test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1])
246246

247+
#######################################################################
248+
# Gather
249+
# ------
250+
251+
def _test_gather(dshape, indices, axis, dtype):
252+
""" One iteration of Gather """
253+
data = np.random.uniform(1, 10, size=dshape).astype(dtype)
254+
indices = np.asarray(indices).astype('int32')
255+
256+
with tf.Graph().as_default():
257+
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
258+
out = array_ops.gather(in_data, indices, axis=axis)
259+
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
260+
261+
#Test quantized input
262+
data = np.random.uniform(1, 10, size=dshape).astype(np.uint8)
263+
with tf.Graph().as_default():
264+
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data")
265+
out = array_ops.gather(in_data, indices, axis=axis)
266+
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=True)
267+
268+
def test_forward_gather():
269+
""" GATHER """
270+
_test_gather((4,), [1], 0, 'float32')
271+
_test_gather((1, 4), [0], 0, 'int32')
272+
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32')
273+
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'int32')
274+
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
275+
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
276+
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32')
277+
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32')
278+
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
279+
280+
#######################################################################
281+
# StridedSlice
282+
# ------------
283+
284+
def _test_stridedslice(ip_shape, begin, end, stride, dtype,
285+
begin_mask=0, end_mask=0, new_axis_mask=0,
286+
shrink_axis_mask=0, ellipsis_mask=0):
287+
""" One iteration of a Stridedslice """
288+
data = np.random.uniform(size=ip_shape).astype(dtype)
289+
with tf.Graph().as_default():
290+
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
291+
out = array_ops.strided_slice(in_data, begin, end, stride,
292+
begin_mask=begin_mask,
293+
end_mask=end_mask, new_axis_mask=new_axis_mask,
294+
shrink_axis_mask=shrink_axis_mask,
295+
ellipsis_mask=ellipsis_mask)
296+
compare_tflite_with_tvm(data, 'in_data:0', [in_data], [out])
297+
298+
#Test with quantized inputs
299+
data = np.random.uniform(size=ip_shape).astype(np.uint8)
300+
with tf.Graph().as_default():
301+
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
302+
out = array_ops.strided_slice(in_data, begin, end, stride,
303+
begin_mask=begin_mask,
304+
end_mask=end_mask, new_axis_mask=new_axis_mask,
305+
shrink_axis_mask=shrink_axis_mask,
306+
ellipsis_mask=ellipsis_mask)
307+
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=True)
308+
309+
def test_forward_stridedslice():
310+
'''test StridedSlice'''
311+
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
312+
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
313+
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=1)
314+
247315
#######################################################################
248316
# transpose
249317
# ---------
@@ -1495,6 +1563,8 @@ def test_forward_mediapipe_hand_landmark():
14951563
test_all_resize()
14961564
test_forward_squeeze()
14971565
test_forward_slice()
1566+
test_forward_gather()
1567+
test_forward_stridedslice()
14981568

14991569
# NN
15001570
test_forward_convolution()

0 commit comments

Comments
 (0)