Skip to content

Commit

Permalink
[TFLITE]GATHER_ND (apache#5508)
Browse files Browse the repository at this point in the history
Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>
  • Loading branch information
dhruvaray authored May 18, 2020
1 parent a8e4471 commit 8a63b7f
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
26 changes: 26 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(self, model, subgraph, exp_tab):
'FLOOR': self.convert_floor,
'FULLY_CONNECTED': self.convert_fully_connected,
'GATHER': self.convert_gather,
'GATHER_ND' : self.convert_gather_nd,
'GREATER_EQUAL': self.convert_greater_equal,
'GREATER': self.convert_greater,
'HARD_SWISH': self.convert_hard_swish,
Expand Down Expand Up @@ -1113,6 +1114,31 @@ def convert_gather(self, op):
out = _op.take(data, indices, axis=axis, mode="fast")
return out

def convert_gather_nd(self, op):
"""Method to Convert TFLite GATHER_ND operator"""
try:
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

for t in input_tensors:
assert not t.qnn_params, "Quantized input is not expected."

data = self.get_tensor_expr(input_tensors[0])
indices = self.get_tensor_expr(input_tensors[1])

indices_type = input_tensors[1].tensor.Type()
assert indices_type in (TensorType.INT32, TensorType.INT64)

indices_dims = len(_infer_shape(indices))
indices_t = _op.transpose(indices, axes=[-1] + list(range(indices_dims-1)))

out = _op.gather_nd(data, indices_t)
return out

def convert_strided_slice(self, op):
"""Method to Convert TFLite STRIDED_SLICE operator.
NOTE: Eventhough tensorflow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask
Expand Down
31 changes: 31 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,36 @@ def test_forward_gather():
_test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True)
_test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True)

#######################################################################
# Gather_ND
# ---------

def _test_gather_nd(data, indices):
""" One iteration of GATHER_ND """
with tf.Graph().as_default():
in_data = tf.placeholder(shape=data.shape, dtype=data.dtype, name="data")
indices_data = tf.placeholder(shape=indices.shape, dtype=indices.dtype,
name="indices")
out = tf.gather_nd(in_data, indices_data)

compare_tflite_with_tvm([data, indices], ['data:0', 'indices:0'],
[in_data, indices_data], [out])

def test_forward_gather_nd():
""" GATHER_ND """
_test_gather_nd(
np.array([[[1.2, 2.0], [3.1, 4.1]], [[5.1, 6.1], [7.1, 8.1]]]).astype('float32'),
np.asarray([[0, 1], [1, 0]]).astype('int32')
)
_test_gather_nd(
np.reshape(np.arange(30), [5, 6]).astype('int32'),
np.asarray([[1, 2]]).astype('int32')
)
_test_gather_nd(
np.reshape(np.arange(12), [2, 3, 2]).astype('int32'),
np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype('int32')
)

#######################################################################
# StridedSlice
# ------------
Expand Down Expand Up @@ -2217,6 +2247,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_slice()
test_forward_topk()
test_forward_gather()
test_forward_gather_nd()
test_forward_stridedslice()
test_forward_depthtospace()
test_forward_spacetodepth()
Expand Down

0 comments on commit 8a63b7f

Please sign in to comment.