Skip to content

Commit e13576b

Browse files
committed
[FRONTEND][TFLITE]Gather, StridedSlice op added
1 parent 1b8522e commit e13576b

File tree

2 files changed

+207
-0
lines changed

2 files changed

+207
-0
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def __init__(self, model, subgraph, exp_tab):
112112
'PRELU': self.convert_prelu,
113113
'TRANSPOSE_CONV': self.convert_transpose_conv,
114114
'SQUARED_DIFFERENCE': self.convert_squared_difference,
115+
'GATHER': self.convert_gather,
116+
'STRIDED_SLICE': self.convert_strided_slice,
115117
}
116118

117119
def check_unsupported_ops(self):
@@ -747,6 +749,158 @@ def convert_squared_difference(self, op):
747749
out = _op.power(difference, relay.const(2, exp_type))
748750
return out
749751

752+
def convert_gather(self, op):
753+
"""Method to Convert TFLite Gather operator"""
754+
# Check if the input tensor is quantized, call QNN op
755+
if self.is_quantized(op):
756+
raise tvm.error.OpNotImplemented(
757+
'TFlite quantized gather operator is not supported yet.')
758+
input_tensors = self.get_input_tensors(op)
759+
760+
try:
761+
from tflite.BuiltinOptions import BuiltinOptions
762+
from tflite.GatherOptions import GatherOptions
763+
from tflite.TensorType import TensorType
764+
except ImportError:
765+
raise ImportError("The tflite package must be installed")
766+
767+
assert op.BuiltinOptionsType() == BuiltinOptions.GatherOptions
768+
op_options = op.BuiltinOptions()
769+
gather_options = GatherOptions()
770+
gather_options.Init(op_options.Bytes, op_options.Pos)
771+
axis = gather_options.Axis()
772+
773+
data = self.get_expr(input_tensors[0].tensor_idx)
774+
775+
indices = input_tensors[1]
776+
indices_type = indices.tensor.Type()
777+
778+
assert indices_type in (TensorType.INT32, TensorType.INT64)
779+
indices_type_str = self.get_tensor_type_str(indices_type)
780+
indices = self.exp_tab.new_const(self.get_tensor_value(indices),
781+
dtype=indices_type_str)
782+
out = _op.take(data, indices, axis=axis)
783+
return out
784+
785+
def convert_strided_slice(self, op):
786+
"""Method to Convert TFLite Strided Slice operator"""
787+
# Check if the input tensor is quantized, call QNN op
788+
if self.is_quantized(op):
789+
raise tvm.error.OpNotImplemented(
790+
'TFlite quantized strided slice operator is not supported yet.')
791+
input_tensors = self.get_input_tensors(op)
792+
793+
try:
794+
from tflite.BuiltinOptions import BuiltinOptions
795+
from tflite.StridedSliceOptions import StridedSliceOptions
796+
except ImportError:
797+
raise ImportError("The tflite package must be installed")
798+
799+
data_expr = self.get_expr(input_tensors[0].tensor_idx)
800+
801+
begin = list(self.get_tensor_value(input_tensors[1]))
802+
end = list(self.get_tensor_value(input_tensors[2]))
803+
stride = list(self.get_tensor_value(input_tensors[3]))
804+
805+
assert op.BuiltinOptionsType() == BuiltinOptions.StridedSliceOptions
806+
op_options = op.BuiltinOptions()
807+
options = StridedSliceOptions()
808+
options.Init(op_options.Bytes, op_options.Pos)
809+
begin_mask = options.BeginMask()
810+
end_mask = options.EndMask()
811+
ellipsis_mask = options.EllipsisMask()
812+
new_axis_mask = options.NewAxisMask()
813+
shrink_axis_mask = options.ShrinkAxisMask()
814+
815+
data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
816+
817+
data_dim = len(data_shape)
818+
stride_dim = len(list(input_tensors[3].tensor.ShapeAsNumpy()))
819+
820+
def _transform_mask(stride_dim, ellipsis_mask):
821+
"""Handle mask inputs to create new begin, end, stride and output shape"""
822+
m_begin = [0] * data_dim
823+
m_end = [0] * data_dim
824+
m_stride = [0] * data_dim
825+
fshape_indices = []
826+
#Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
827+
ellipsis_seen = False
828+
new_axes_after_ellipsis = 0
829+
for i in range(stride_dim):
830+
mask = 1 << i
831+
if ellipsis_seen and (mask & new_axis_mask) != 0:
832+
new_axes_after_ellipsis += 1
833+
if (mask & ellipsis_mask) != 0:
834+
ellipsis_seen = True
835+
if not ellipsis_seen:
836+
#Used later for extending the stride attributes in the below loop.
837+
ellipsis_mask |= (1 << stride_dim)
838+
stride_dim += 1
839+
final_index = 0
840+
for index in range(stride_dim):
841+
mask = 1 << index
842+
if mask & ellipsis_mask:
843+
#Identify the end index for applying ellipsis_mask
844+
to_index = min(((data_dim - (stride_dim-index)) + 1 \
845+
+ new_axes_after_ellipsis), data_dim)
846+
for i in range(final_index, to_index):
847+
m_begin[final_index] = 0
848+
m_end[final_index] = data_shape[final_index]
849+
m_stride[final_index] = 1
850+
fshape_indices.append(final_index)
851+
final_index += 1
852+
elif mask &new_axis_mask:
853+
fshape_indices.append(-1)
854+
elif not mask & new_axis_mask:
855+
if final_index == len(m_begin):
856+
break
857+
if mask & begin_mask:
858+
m_begin[final_index] = data_shape[final_index] \
859+
if stride[index] < 0 else 0
860+
elif begin[index]:
861+
m_begin[final_index] = begin[index]
862+
if mask & end_mask:
863+
m_end[final_index] = 0 if stride[index] < 0 \
864+
else data_shape[final_index]
865+
elif end[index]:
866+
m_end[final_index] = end[index]
867+
m_stride[final_index] = stride[index]
868+
if mask & shrink_axis_mask:
869+
#Tensorflow make axis with shrink_axis_mask as dimension 1
870+
m_begin[final_index] = data_shape[final_index] + begin[index] \
871+
if begin[index] < 0 else begin[index]
872+
m_end[final_index] = begin[index] + 1
873+
m_stride[final_index] = 1
874+
fshape_indices.append(-2)
875+
else:
876+
fshape_indices.append(final_index)
877+
878+
final_index += 1
879+
return m_begin, m_end, m_stride, fshape_indices
880+
881+
fshape_indices = None
882+
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
883+
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
884+
885+
out = _op.strided_slice(data_expr, begin=begin, end=end, strides=stride)
886+
out_shape = _infer_shape(out)
887+
if not fshape_indices:
888+
fshape_indices = range(len(out_shape))
889+
890+
#Create final output shape.
891+
final_output = []
892+
for gather_index in fshape_indices:
893+
if gather_index == -1:
894+
final_output.append(1)
895+
elif gather_index == -2:
896+
pass
897+
else:
898+
final_output.append(out_shape[gather_index])
899+
900+
if not final_output:
901+
return out
902+
return _op.reshape(out, newshape=tuple(final_output))
903+
750904
def convert_zeros_like(self, op):
751905
"""Convert TFLite ZEROS LIKE"""
752906
try:

tests/python/frontend/tflite/test_forward.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,57 @@ def test_forward_slice():
244244
_test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1])
245245
_test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1])
246246

247+
#######################################################################
248+
# Gather
249+
# ------
250+
251+
def _test_gather(dshape, indices, axis, dtype):
252+
""" One iteration of Gather """
253+
data = np.random.uniform(1, 10, size=dshape).astype(dtype)
254+
indices = np.asarray(indices).astype('int32')
255+
256+
with tf.Graph().as_default():
257+
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
258+
out = array_ops.gather(in_data, indices, axis=axis)
259+
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
260+
261+
def test_forward_gather():
262+
""" GATHER """
263+
_test_gather((4,), [1], None, 'float32')
264+
_test_gather((1, 4), [0], 0, 'int32')
265+
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32')
266+
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'int32')
267+
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
268+
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
269+
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32')
270+
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32')
271+
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
272+
273+
#######################################################################
274+
# StridedSlice
275+
# ------------
276+
277+
def _test_stridedslice(ip_shape, begin, end, stride, dtype,
278+
begin_mask=0, end_mask=0, new_axis_mask=0,
279+
shrink_axis_mask=0, ellipsis_mask=0):
280+
""" One iteration of a Stridedslice """
281+
data = np.random.uniform(size=ip_shape).astype(dtype)
282+
283+
with tf.Graph().as_default():
284+
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
285+
out = array_ops.strided_slice(in_data, begin, end, stride,
286+
begin_mask=begin_mask,
287+
end_mask=end_mask, new_axis_mask=new_axis_mask,
288+
shrink_axis_mask=shrink_axis_mask,
289+
ellipsis_mask=ellipsis_mask)
290+
compare_tflite_with_tvm(data, 'in_data:0', [in_data], [out])
291+
292+
def test_forward_stridedslice():
293+
'''test StridedSlice'''
294+
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
295+
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
296+
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=1)
297+
247298
#######################################################################
248299
# transpose
249300
# ---------
@@ -1456,6 +1507,8 @@ def test_forward_mediapipe_hand_landmark():
14561507
test_all_resize()
14571508
test_forward_squeeze()
14581509
test_forward_slice()
1510+
test_forward_gather()
1511+
test_forward_stridedslice()
14591512

14601513
# NN
14611514
test_forward_convolution()

0 commit comments

Comments
 (0)