diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index a272fae4c3990..2a6f89fa916ff 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -18,6 +18,7 @@ """Tensorflow lite frontend.""" import math +import itertools import numpy as np import tvm from tvm.ir import IRModule @@ -895,6 +896,8 @@ def convert_gather(self, op): raise ImportError("The tflite package must be installed") input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + data = self.get_expr(input_tensors[0].tensor_idx) indices = input_tensors[1] @@ -910,11 +913,78 @@ def convert_gather(self, op): gather_options.Init(op_options.Bytes, op_options.Pos) axis = gather_options.Axis() - out = _op.take(data, indices, axis=axis) + # Check the indices are with in bounds. + data_shape = list(input_tensors[0].tensor.ShapeAsNumpy()) + data_dim = len(data_shape) + + axis_n = axis + if axis_n < 0: + axis_n += axis_n + data_dim + assert axis_n >= 0, "Axis out of bounds" + assert axis_n < data_dim, "Axis out of bounds" + + indices_val = self.get_tensor_value(input_tensors[1]) + indices_shape = list(indices_val.shape) + indices_len = len(indices_shape) + + out_shape = [] + for i in range(data_dim): + if axis_n == i: + for j in range(indices_len): + out_shape.append(indices_shape[j]) + else: + out_shape.append(data_shape[i]) + + loopover = [range(s) for s in out_shape] + for idx in list(itertools.product(*loopover)): + indices_position = [idx[j] for j in range(axis_n, axis_n+indices_len)] + + real_indices = [idx[j] for j in range(axis_n)] + real_indices.append(indices_val[tuple(indices_position)]) + real_indices.extend([idx[j] for j in range(axis_n + indices_len, len(idx))]) + for r, d in zip(real_indices, data_shape): + if r >= d: + raise ValueError("TFLite out of bound indices are not supported.") + + # Use mode 'fast' since indices are already checked within bounds. + out = _op.take(data, indices, axis=axis, mode="fast") return out def convert_strided_slice(self, op): - """Method to Convert TFLite STRIDED_SLICE operator""" + """Method to Convert TFLite STRIDED_SLICE operator. + NOTE: Eventhough tensorflow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask + and shrink_axis_mask, tflite doesn't support these and expect these values to be zero. + But in future, they may open up the mask implementation, so kept the implementation + same as tensorflow. + + This op extracts a slice of size (end - begin) / stride from the given input tensor. + Starting at the location specified by begin the slice continues by adding stride to the + index until all dimensions are not less than end. Note that a stride can be negative, + which causes a reverse slice. + + For slice input[val0, val1, ..., valn], begin/end/strides will be vectors of length n. + + In each mask field(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask) + the ith bit will correspond to the ith val. + + If the ith bit of begin_mask is set, begin[i] is ignored and the fullest possible range + in that dimension is used instead. + + If the ith bit of ellipsis_mask is set, as many unspecified dimensions as needed will be + inserted between other dimensions. Only one non-zero bit is allowed in ellipsis_mask. + + If the ith bit of new_axis_mask is set, then begin, end, and stride are ignored and a + new length 1 dimension is added at this point in the output tensor. + + If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks + the dimensionality by 1, taking on the value at index begin[i]. end[i] and strides[i] + are ignored in this case. + begin and end are zero-indexed. strides entries must be non-zero. + + TVM Relay implementation of doesn't support mask, so the mask values are processed in + this function and begin/end/strides are updated accordingly. If any mask is present, and + since tvm doesn't support mask computation directly, the output need a final reshape. + """ try: from tflite.BuiltinOptions import BuiltinOptions from tflite.StridedSliceOptions import StridedSliceOptions @@ -922,6 +992,8 @@ def convert_strided_slice(self, op): raise ImportError("The tflite package must be installed") input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 4, "input tensors length should be 4" + data_expr = self.get_expr(input_tensors[0].tensor_idx) begin = list(self.get_tensor_value(input_tensors[1])) @@ -940,8 +1012,7 @@ def convert_strided_slice(self, op): data_shape = list(input_tensors[0].tensor.ShapeAsNumpy()) data_dim = len(data_shape) - stride_dim = len(list(input_tensors[3].tensor.ShapeAsNumpy())) - + stride_dim = len(stride) def _transform_mask(stride_dim, ellipsis_mask): """Handle mask inputs to create new begin, end, stride and output shape""" m_begin = [0] * data_dim diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 0c0d156d79d57..d0add1f8f704d 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -271,34 +271,44 @@ def test_forward_slice(): # Gather # ------ -def _test_gather(dshape, indices, axis, dtype): +def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False): """ One iteration of Gather """ - data = np.random.uniform(1, 10, size=dshape).astype(dtype) indices = np.asarray(indices).astype('int32') - - with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - out = array_ops.gather(in_data, indices, axis=axis) - compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) - - #Test quantized input - data = np.random.uniform(1, 10, size=dshape).astype(np.uint8) + data = np.random.uniform(1, 10, size=dshape) + data = data.astype(np.uint8) if quantized else data.astype(dtype) with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data") - out = array_ops.gather(in_data, indices, axis=axis) - compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=True) + if axis: + out = array_ops.gather(in_data, indices, axis=axis) + else: + out = array_ops.gather(in_data, indices) #tflite conversion fails for None axis + input_range = {'in_data': (-100, 100)} if quantized else None + try: + compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], + quantized=quantized, input_range=input_range) + except ValueError as e: + if not oob: + raise e + except Exception as e: + raise e def test_forward_gather(): """ GATHER """ - _test_gather((4,), [1], 0, 'float32') - _test_gather((1, 4), [0], 0, 'int32') - _test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32') - _test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'int32') - _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32') - _test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'float32') - _test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32') - _test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32') - _test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32') + for quantized in [False, True]: + _test_gather((4,), [1], 0, 'float32', quantized) + _test_gather((4,), [1], None, 'int32', quantized) + _test_gather((1, 4), [0], 0, 'int32', quantized) + _test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized) + _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized) + _test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized) + _test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized) + _test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized) + _test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized) + _test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized) + _test_gather((4,), [16], 0, 'float32', quantized, oob=True) + _test_gather((1, 3, 3), [12], 0, 'int32', quantized, oob=True) + _test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True) + _test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True) ####################################################################### # StridedSlice @@ -306,34 +316,29 @@ def test_forward_gather(): def _test_stridedslice(ip_shape, begin, end, stride, dtype, begin_mask=0, end_mask=0, new_axis_mask=0, - shrink_axis_mask=0, ellipsis_mask=0): + shrink_axis_mask=0, ellipsis_mask=0, quantized=False): """ One iteration of a Stridedslice """ data = np.random.uniform(size=ip_shape).astype(dtype) + data = data.astype(np.uint8) if quantized else data.astype(dtype) with tf.Graph().as_default(): in_data = tf.placeholder(dtype, ip_shape, name="in_data") out = array_ops.strided_slice(in_data, begin, end, stride, begin_mask=begin_mask, - end_mask=end_mask, new_axis_mask=new_axis_mask, - shrink_axis_mask=shrink_axis_mask, - ellipsis_mask=ellipsis_mask) - compare_tflite_with_tvm(data, 'in_data:0', [in_data], [out]) - - #Test with quantized inputs - data = np.random.uniform(size=ip_shape).astype(np.uint8) - with tf.Graph().as_default(): - in_data = tf.placeholder(dtype, ip_shape, name="in_data") - out = array_ops.strided_slice(in_data, begin, end, stride, - begin_mask=begin_mask, - end_mask=end_mask, new_axis_mask=new_axis_mask, + end_mask=end_mask, + new_axis_mask=new_axis_mask, shrink_axis_mask=shrink_axis_mask, ellipsis_mask=ellipsis_mask) - compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=True) + input_range = {'in_data': (-100, 100)} if quantized else None + compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=quantized, + input_range=input_range) def test_forward_stridedslice(): '''test StridedSlice''' - _test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1) - _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32') - _test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=1) + for quantized in [False, True]: + _test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1, quantized=quantized) + _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32', quantized=quantized) + _test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=0, quantized=quantized) + _test_stridedslice((4, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2, quantized=quantized) ####################################################################### # transpose