Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FRONTEND][TFLITE]Gather, StridedSlice op support added #4788

Merged
merged 2 commits into from
Apr 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 214 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
"""Tensorflow lite frontend."""
import math
import itertools
import numpy as np
import tvm
from tvm.ir import IRModule
Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(self, model, subgraph, exp_tab):
'FLOOR_MOD': self.convert_floor_mod,
'FLOOR': self.convert_floor,
'FULLY_CONNECTED': self.convert_fully_connected,
'GATHER': self.convert_gather,
'GREATER_EQUAL': self.convert_greater_equal,
'GREATER': self.convert_greater,
'L2_NORMALIZATION': self.convert_l2_normalization,
Expand Down Expand Up @@ -124,6 +126,7 @@ def __init__(self, model, subgraph, exp_tab):
'SQUARE': self.convert_square,
'SQUARED_DIFFERENCE': self.convert_squared_difference,
'SQUEEZE': self.convert_squeeze,
'STRIDED_SLICE': self.convert_strided_slice,
'SUB': self.convert_sub,
'SUM': self._convert_reduce_sum,
'TAN': self.convert_tan,
Expand Down Expand Up @@ -1014,6 +1017,217 @@ def convert_logical_or(self, op):
"""Convert tflite LOGICAL_OR"""
return self._convert_logical_binary(_op.logical_or, op)

def convert_gather(self, op):
"""Method to Convert TFLite GATHER operator"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.GatherOptions import GatherOptions
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
siju-samuel marked this conversation as resolved.
Show resolved Hide resolved
assert len(input_tensors) == 2, "input tensors length should be 2"

data = self.get_expr(input_tensors[0].tensor_idx)

indices = input_tensors[1]
indices_type = indices.tensor.Type()
assert indices_type in (TensorType.INT32, TensorType.INT64)
indices_type_str = self.get_tensor_type_str(indices_type)
indices = self.exp_tab.new_const(self.get_tensor_value(indices),
dtype=indices_type_str)

assert op.BuiltinOptionsType() == BuiltinOptions.GatherOptions
op_options = op.BuiltinOptions()
gather_options = GatherOptions()
gather_options.Init(op_options.Bytes, op_options.Pos)
axis = gather_options.Axis()

# Check the indices are with in bounds.
data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
data_dim = len(data_shape)

axis_n = axis
if axis_n < 0:
axis_n += axis_n + data_dim
assert axis_n >= 0, "Axis out of bounds"
assert axis_n < data_dim, "Axis out of bounds"

indices_val = self.get_tensor_value(input_tensors[1])
indices_shape = list(indices_val.shape)
indices_len = len(indices_shape)

out_shape = []
for i in range(data_dim):
if axis_n == i:
for j in range(indices_len):
out_shape.append(indices_shape[j])
else:
out_shape.append(data_shape[i])

loopover = [range(s) for s in out_shape]
for idx in list(itertools.product(*loopover)):
indices_position = [idx[j] for j in range(axis_n, axis_n+indices_len)]

real_indices = [idx[j] for j in range(axis_n)]
real_indices.append(indices_val[tuple(indices_position)])
real_indices.extend([idx[j] for j in range(axis_n + indices_len, len(idx))])
for r, d in zip(real_indices, data_shape):
if r >= d:
raise ValueError("TFLite out of bound indices are not supported.")

# Use mode 'fast' since indices are already checked within bounds.
out = _op.take(data, indices, axis=axis, mode="fast")
return out

siju-samuel marked this conversation as resolved.
Show resolved Hide resolved
def convert_strided_slice(self, op):
"""Method to Convert TFLite STRIDED_SLICE operator.
NOTE: Eventhough tensorflow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask
and shrink_axis_mask, tflite doesn't support these and expect these values to be zero.
But in future, they may open up the mask implementation, so kept the implementation
same as tensorflow.

This op extracts a slice of size (end - begin) / stride from the given input tensor.
Starting at the location specified by begin the slice continues by adding stride to the
index until all dimensions are not less than end. Note that a stride can be negative,
which causes a reverse slice.

For slice input[val0, val1, ..., valn], begin/end/strides will be vectors of length n.

In each mask field(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
the ith bit will correspond to the ith val.

If the ith bit of begin_mask is set, begin[i] is ignored and the fullest possible range
in that dimension is used instead.

If the ith bit of ellipsis_mask is set, as many unspecified dimensions as needed will be
inserted between other dimensions. Only one non-zero bit is allowed in ellipsis_mask.

If the ith bit of new_axis_mask is set, then begin, end, and stride are ignored and a
new length 1 dimension is added at this point in the output tensor.

If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks
the dimensionality by 1, taking on the value at index begin[i]. end[i] and strides[i]
are ignored in this case.
begin and end are zero-indexed. strides entries must be non-zero.

TVM Relay implementation of doesn't support mask, so the mask values are processed in
this function and begin/end/strides are updated accordingly. If any mask is present, and
since tvm doesn't support mask computation directly, the output need a final reshape.
"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.StridedSliceOptions import StridedSliceOptions
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
siju-samuel marked this conversation as resolved.
Show resolved Hide resolved
assert len(input_tensors) == 4, "input tensors length should be 4"

data_expr = self.get_expr(input_tensors[0].tensor_idx)

begin = list(self.get_tensor_value(input_tensors[1]))
end = list(self.get_tensor_value(input_tensors[2]))
stride = list(self.get_tensor_value(input_tensors[3]))

assert op.BuiltinOptionsType() == BuiltinOptions.StridedSliceOptions
op_options = op.BuiltinOptions()
options = StridedSliceOptions()
options.Init(op_options.Bytes, op_options.Pos)
begin_mask = options.BeginMask()
end_mask = options.EndMask()
ellipsis_mask = options.EllipsisMask()
new_axis_mask = options.NewAxisMask()
shrink_axis_mask = options.ShrinkAxisMask()

data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
data_dim = len(data_shape)
stride_dim = len(stride)
def _transform_mask(stride_dim, ellipsis_mask):
"""Handle mask inputs to create new begin, end, stride and output shape"""
m_begin = [0] * data_dim
m_end = [0] * data_dim
m_stride = [0] * data_dim
fshape_indices = []
#Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
ellipsis_seen = False
new_axes_after_ellipsis = 0
for i in range(stride_dim):
mask = 1 << i
if ellipsis_seen and (mask & new_axis_mask) != 0:
new_axes_after_ellipsis += 1
if (mask & ellipsis_mask) != 0:
ellipsis_seen = True
if not ellipsis_seen:
#Used later for extending the stride attributes in the below loop.
ellipsis_mask |= (1 << stride_dim)
stride_dim += 1
final_index = 0
for index in range(stride_dim):
mask = 1 << index
if mask & ellipsis_mask:
#Identify the end index for applying ellipsis_mask
to_index = min(((data_dim - (stride_dim-index)) + 1 \
+ new_axes_after_ellipsis), data_dim)
for i in range(final_index, to_index):
m_begin[final_index] = 0
m_end[final_index] = data_shape[final_index]
m_stride[final_index] = 1
fshape_indices.append(final_index)
final_index += 1
elif mask &new_axis_mask:
fshape_indices.append(-1)
elif not mask & new_axis_mask:
if final_index == len(m_begin):
break
if mask & begin_mask:
m_begin[final_index] = data_shape[final_index] \
if stride[index] < 0 else 0
elif begin[index]:
m_begin[final_index] = begin[index]
if mask & end_mask:
m_end[final_index] = 0 if stride[index] < 0 \
else data_shape[final_index]
elif end[index]:
m_end[final_index] = end[index]
m_stride[final_index] = stride[index]
if mask & shrink_axis_mask:
#Tensorflow make axis with shrink_axis_mask as dimension 1
m_begin[final_index] = data_shape[final_index] + begin[index] \
if begin[index] < 0 else begin[index]
m_end[final_index] = begin[index] + 1
m_stride[final_index] = 1
fshape_indices.append(-2)
else:
fshape_indices.append(final_index)

final_index += 1
return m_begin, m_end, m_stride, fshape_indices

fshape_indices = None
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)

out = _op.strided_slice(data_expr, begin=begin, end=end, strides=stride)
out_shape = _infer_shape(out)
if not fshape_indices:
fshape_indices = range(len(out_shape))

#Create final output shape.
final_output = []
for gather_index in fshape_indices:
if gather_index == -1:
final_output.append(1)
elif gather_index == -2:
pass
else:
final_output.append(out_shape[gather_index])

if not final_output:
return out
return _op.reshape(out, newshape=tuple(final_output))

def convert_zeros_like(self, op):
"""Convert TFLite ZEROS LIKE"""
try:
Expand Down
75 changes: 75 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,79 @@ def test_forward_topk():
_test_topk((3, 5, 7), 3)
_test_topk((3, 5, 7), 3)

#######################################################################
# Gather
# ------

def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False):
""" One iteration of Gather """
indices = np.asarray(indices).astype('int32')
data = np.random.uniform(1, 10, size=dshape)
data = data.astype(np.uint8) if quantized else data.astype(dtype)
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data")
if axis:
out = array_ops.gather(in_data, indices, axis=axis)
else:
out = array_ops.gather(in_data, indices) #tflite conversion fails for None axis
input_range = {'in_data': (-100, 100)} if quantized else None
try:
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out],
quantized=quantized, input_range=input_range)
except ValueError as e:
if not oob:
siju-samuel marked this conversation as resolved.
Show resolved Hide resolved
raise e
except Exception as e:
raise e

def test_forward_gather():
siju-samuel marked this conversation as resolved.
Show resolved Hide resolved
""" GATHER """
for quantized in [False, True]:
_test_gather((4,), [1], 0, 'float32', quantized)
_test_gather((4,), [1], None, 'int32', quantized)
_test_gather((1, 4), [0], 0, 'int32', quantized)
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized)
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized)
_test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized)
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized)
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized)
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized)
_test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized)
_test_gather((4,), [16], 0, 'float32', quantized, oob=True)
_test_gather((1, 3, 3), [12], 0, 'int32', quantized, oob=True)
_test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True)
_test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True)

#######################################################################
# StridedSlice
# ------------

def _test_stridedslice(ip_shape, begin, end, stride, dtype,
begin_mask=0, end_mask=0, new_axis_mask=0,
shrink_axis_mask=0, ellipsis_mask=0, quantized=False):
""" One iteration of a Stridedslice """
data = np.random.uniform(size=ip_shape).astype(dtype)
data = data.astype(np.uint8) if quantized else data.astype(dtype)
with tf.Graph().as_default():
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
out = array_ops.strided_slice(in_data, begin, end, stride,
begin_mask=begin_mask,
end_mask=end_mask,
new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask,
ellipsis_mask=ellipsis_mask)
input_range = {'in_data': (-100, 100)} if quantized else None
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=quantized,
input_range=input_range)

def test_forward_stridedslice():
'''test StridedSlice'''
siju-samuel marked this conversation as resolved.
Show resolved Hide resolved
for quantized in [False, True]:
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1, quantized=quantized)
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32', quantized=quantized)
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=0, quantized=quantized)
_test_stridedslice((4, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2, quantized=quantized)

#######################################################################
# transpose
# ---------
Expand Down Expand Up @@ -1794,6 +1867,8 @@ def test_forward_mediapipe_hand_landmark():
test_forward_squeeze()
test_forward_slice()
test_forward_topk()
test_forward_gather()
test_forward_stridedslice()
test_forward_depthtospace()
test_forward_spacetodepth()

Expand Down