Skip to content

Commit 47ebc7b

Browse files
committed
[FRONTEND][TFLITE]Gather, StridedSlice op added
1 parent 1b8522e commit 47ebc7b

File tree

2 files changed

+205
-0
lines changed

2 files changed

+205
-0
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def __init__(self, model, subgraph, exp_tab):
112112
'PRELU': self.convert_prelu,
113113
'TRANSPOSE_CONV': self.convert_transpose_conv,
114114
'SQUARED_DIFFERENCE': self.convert_squared_difference,
115+
'GATHER': self.convert_gather,
116+
'STRIDED_SLICE': self.convert_strided_slice,
115117
}
116118

117119
def check_unsupported_ops(self):
@@ -747,6 +749,156 @@ def convert_squared_difference(self, op):
747749
out = _op.power(difference, relay.const(2, exp_type))
748750
return out
749751

752+
def convert_gather(self, op):
753+
# Check if the input tensor is quantized, call QNN op
754+
if self.is_quantized(op):
755+
raise tvm.error.OpNotImplemented(
756+
'TFlite quantized gather operator is not supported yet.')
757+
input_tensors = self.get_input_tensors(op)
758+
759+
try:
760+
from tflite.BuiltinOptions import BuiltinOptions
761+
from tflite.GatherOptions import GatherOptions
762+
from tflite.TensorType import TensorType
763+
except ImportError:
764+
raise ImportError("The tflite package must be installed")
765+
766+
assert op.BuiltinOptionsType() == BuiltinOptions.GatherOptions
767+
op_options = op.BuiltinOptions()
768+
gather_options = GatherOptions()
769+
gather_options.Init(op_options.Bytes, op_options.Pos)
770+
axis = gather_options.Axis()
771+
772+
data = self.get_expr(input_tensors[0].tensor_idx)
773+
774+
indices = input_tensors[1]
775+
indices_type = indices.tensor.Type()
776+
777+
assert indices_type in (TensorType.INT32, TensorType.INT64)
778+
indices_type_str = self.get_tensor_type_str(indices_type)
779+
indices = self.exp_tab.new_const(self.get_tensor_value(indices),
780+
dtype=indices_type_str)
781+
out = _op.take(data, indices, axis=axis)
782+
return out
783+
784+
def convert_strided_slice(self, op):
785+
# Check if the input tensor is quantized, call QNN op
786+
if self.is_quantized(op):
787+
raise tvm.error.OpNotImplemented(
788+
'TFlite quantized strided slice operator is not supported yet.')
789+
input_tensors = self.get_input_tensors(op)
790+
791+
try:
792+
from tflite.BuiltinOptions import BuiltinOptions
793+
from tflite.StridedSliceOptions import StridedSliceOptions
794+
except ImportError:
795+
raise ImportError("The tflite package must be installed")
796+
797+
data_expr = self.get_expr(input_tensors[0].tensor_idx)
798+
799+
begin = list(self.get_tensor_value(input_tensors[1]))
800+
end = list(self.get_tensor_value(input_tensors[2]))
801+
stride = list(self.get_tensor_value(input_tensors[3]))
802+
803+
assert op.BuiltinOptionsType() == BuiltinOptions.StridedSliceOptions
804+
op_options = op.BuiltinOptions()
805+
options = StridedSliceOptions()
806+
options.Init(op_options.Bytes, op_options.Pos)
807+
begin_mask = options.BeginMask()
808+
end_mask = options.EndMask()
809+
ellipsis_mask = options.EllipsisMask()
810+
new_axis_mask = options.NewAxisMask()
811+
shrink_axis_mask = options.ShrinkAxisMask()
812+
813+
data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
814+
815+
data_dim = len(data_shape)
816+
stride_dim = len(list(input_tensors[3].tensor.ShapeAsNumpy()))
817+
818+
def _transform_mask(stride_dim, ellipsis_mask):
819+
"""Handle mask inputs to create new begin, end, stride and output shape"""
820+
m_begin = [0] * data_dim
821+
m_end = [0] * data_dim
822+
m_stride = [0] * data_dim
823+
fshape_indices = []
824+
#Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
825+
ellipsis_seen = False
826+
new_axes_after_ellipsis = 0
827+
for i in range(stride_dim):
828+
mask = 1 << i
829+
if ellipsis_seen and (mask & new_axis_mask) != 0:
830+
new_axes_after_ellipsis += 1
831+
if (mask & ellipsis_mask) != 0:
832+
ellipsis_seen = True
833+
if not ellipsis_seen:
834+
#Used later for extending the stride attributes in the below loop.
835+
ellipsis_mask |= (1 << stride_dim)
836+
stride_dim += 1
837+
final_index = 0
838+
for index in range(stride_dim):
839+
mask = 1 << index
840+
if mask & ellipsis_mask:
841+
#Identify the end index for applying ellipsis_mask
842+
to_index = min(((data_dim - (stride_dim-index)) + 1 \
843+
+ new_axes_after_ellipsis), data_dim)
844+
for i in range(final_index, to_index):
845+
m_begin[final_index] = 0
846+
m_end[final_index] = data_shape[final_index]
847+
m_stride[final_index] = 1
848+
fshape_indices.append(final_index)
849+
final_index += 1
850+
elif mask &new_axis_mask:
851+
fshape_indices.append(-1)
852+
elif not mask & new_axis_mask:
853+
if final_index == len(m_begin):
854+
break
855+
if mask & begin_mask:
856+
m_begin[final_index] = data_shape[final_index] \
857+
if stride[index] < 0 else 0
858+
elif begin[index]:
859+
m_begin[final_index] = begin[index]
860+
if mask & end_mask:
861+
m_end[final_index] = 0 if stride[index] < 0 \
862+
else data_shape[final_index]
863+
elif end[index]:
864+
m_end[final_index] = end[index]
865+
m_stride[final_index] = stride[index]
866+
if mask & shrink_axis_mask:
867+
#Tensorflow make axis with shrink_axis_mask as dimension 1
868+
m_begin[final_index] = data_shape[final_index] + begin[index] \
869+
if begin[index] < 0 else begin[index]
870+
m_end[final_index] = begin[index] + 1
871+
m_stride[final_index] = 1
872+
fshape_indices.append(-2)
873+
else:
874+
fshape_indices.append(final_index)
875+
876+
final_index += 1
877+
return m_begin, m_end, m_stride, fshape_indices
878+
879+
fshape_indices = None
880+
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
881+
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
882+
883+
out = _op.strided_slice(data_expr, begin=begin, end=end, strides=stride)
884+
out_shape = _infer_shape(out)
885+
if not fshape_indices:
886+
fshape_indices = range(len(out_shape))
887+
888+
#Create final output shape.
889+
final_output = []
890+
for gather_index in fshape_indices:
891+
if gather_index == -1:
892+
final_output.append(1)
893+
elif gather_index == -2:
894+
pass
895+
else:
896+
final_output.append(out_shape[gather_index])
897+
898+
if not final_output:
899+
return out
900+
return _op.reshape(out, newshape=tuple(final_output))
901+
750902
def convert_zeros_like(self, op):
751903
"""Convert TFLite ZEROS LIKE"""
752904
try:

tests/python/frontend/tflite/test_forward.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,57 @@ def test_forward_slice():
244244
_test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1])
245245
_test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1])
246246

247+
#######################################################################
248+
# Gather
249+
# ------
250+
251+
def _test_gather(dshape, indices, axis, dtype):
252+
""" One iteration of Gather """
253+
data = np.random.uniform(1, 10, size=dshape).astype(dtype)
254+
indices = np.asarray(indices).astype('int32')
255+
256+
with tf.Graph().as_default():
257+
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
258+
out = array_ops.gather(in_data, indices, axis=axis)
259+
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
260+
261+
def test_forward_gather():
262+
""" GATHER """
263+
_test_gather((4,), [1], None, 'float32')
264+
_test_gather((1, 4), [0], 0, 'int32')
265+
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32')
266+
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'int32')
267+
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
268+
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
269+
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32')
270+
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32')
271+
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
272+
273+
#######################################################################
274+
# StridedSlice
275+
# ------------
276+
277+
def _test_stridedslice(ip_shape, begin, end, stride, dtype,
278+
begin_mask=0, end_mask=0, new_axis_mask=0,
279+
shrink_axis_mask=0, ellipsis_mask=0):
280+
""" One iteration of a Stridedslice """
281+
data = np.random.uniform(size=ip_shape).astype(dtype)
282+
283+
with tf.Graph().as_default():
284+
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
285+
out = array_ops.strided_slice(in_data, begin, end, stride,
286+
begin_mask=begin_mask,
287+
end_mask=end_mask, new_axis_mask=new_axis_mask,
288+
shrink_axis_mask=shrink_axis_mask,
289+
ellipsis_mask=ellipsis_mask)
290+
compare_tflite_with_tvm(data, 'in_data:0', [in_data], [out])
291+
292+
def test_forward_stridedslice():
293+
'''test StridedSlice'''
294+
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
295+
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
296+
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=1)
297+
247298
#######################################################################
248299
# transpose
249300
# ---------
@@ -1456,6 +1507,8 @@ def test_forward_mediapipe_hand_landmark():
14561507
test_all_resize()
14571508
test_forward_squeeze()
14581509
test_forward_slice()
1510+
test_forward_gather()
1511+
test_forward_stridedslice()
14591512

14601513
# NN
14611514
test_forward_convolution()

0 commit comments

Comments
 (0)