Skip to content

Commit ca92d17

Browse files
inadobzhiics
authored andcommitted
[Relay][Frontend][TFLite] Add parser support for logical operators (apache#4642)
* [Relay][Frontend][TFLite] Add parser support for logical operators * Add parser support for logical_and, logical_or * Add boolean dtype as a valid tensor type * BOOLEAN dtype is supported only from tf 1.15 so logical ops work only in that and newer versions * Logical_not is ommited since tflite can't convert it --> throws errors for addv2 * Add TFLite vesion check in tests for logical ops * Check is added because of boolean dtype lack of support
1 parent cbf425b commit ca92d17

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ def __init__(self, model, subgraph, exp_tab):
117117
'PRELU': self.convert_prelu,
118118
'TRANSPOSE_CONV': self.convert_transpose_conv,
119119
'SQUARED_DIFFERENCE': self.convert_squared_difference,
120+
'LOGICAL_AND': self.convert_logical_and,
121+
'LOGICAL_OR': self.convert_logical_or,
120122
}
121123

122124
def check_unsupported_ops(self):
@@ -222,6 +224,9 @@ def get_tensor_value(self, tensor_wrapper):
222224
if tensor_wrapper.tensor.Type() == TensorType.INT64:
223225
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape(
224226
tensor_wrapper.tensor.ShapeAsNumpy())
227+
if tensor_wrapper.tensor.Type() == TensorType.BOOL:
228+
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.bool_).reshape(
229+
tensor_wrapper.tensor.ShapeAsNumpy())
225230
raise NotImplementedError("Tensor type {} is currently not supported"
226231
.format(str(tensor_wrapper.tensor.Type())))
227232

@@ -240,6 +245,8 @@ def get_tensor_type_str(self, tensor_type):
240245
return "int32"
241246
if tensor_type == TensorType.INT64:
242247
return "int64"
248+
if tensor_type == TensorType.BOOL:
249+
return "bool"
243250
raise NotImplementedError("Tensor type {} is currently not supported"
244251
.format(str(tensor_type)))
245252

@@ -792,6 +799,33 @@ def convert_not_equal(self, op):
792799
'TFlite quantized NOT_EQUAL operator is not supported yet.')
793800
return self._convert_elemwise(_op.not_equal, op)
794801

802+
def _convert_logical_binary(self, relay_op, op):
803+
"""Generic method to convert logical binary ops"""
804+
try:
805+
from tflite.Operator import Operator
806+
except ImportError:
807+
raise ImportError("The tflite package must be installed")
808+
809+
assert isinstance(op, Operator)
810+
input_tensors = self.get_input_tensors(op)
811+
assert len(input_tensors) == 2, "input tensors length should be 2"
812+
813+
lhs_tensor = input_tensors[0]
814+
lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
815+
rhs_tensor = input_tensors[1]
816+
rhs_expr = self.get_expr(rhs_tensor.tensor_idx)
817+
out = relay_op(lhs_expr, rhs_expr)
818+
819+
return out
820+
821+
def convert_logical_and(self, op):
822+
"""Convert tflite LOGICAL_AND"""
823+
return self._convert_logical_binary(_op.logical_and, op)
824+
825+
def convert_logical_or(self, op):
826+
"""Convert tflite LOGICAL_OR"""
827+
return self._convert_logical_binary(_op.logical_or, op)
828+
795829
def convert_zeros_like(self, op):
796830
"""Convert TFLite ZEROS LIKE"""
797831
try:

tests/python/frontend/tflite/test_forward.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,34 @@ def test_all_elemwise():
965965
_test_forward_elemwise(_test_equal)
966966
_test_forward_elemwise(_test_not_equal)
967967

968+
#######################################################################
969+
# Logical operators
970+
# -----------------
971+
972+
def _test_logical_binary(logical_bin_op, data):
973+
974+
with tf.Graph().as_default():
975+
in_data = [array_ops.placeholder(shape=data[0].shape, dtype='bool', name='in_0'),
976+
array_ops.placeholder(shape=data[1].shape, dtype='bool', name='in_1')]
977+
out = logical_bin_op(in_data[0], in_data[1], name='out')
978+
compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
979+
980+
def _test_forward_logical_and(data):
981+
""" One iteration of logical and """
982+
return _test_logical_binary(math_ops.logical_and, data)
983+
984+
def _test_forward_logical_or(data):
985+
""" One iteration of logical or """
986+
return _test_logical_binary(math_ops.logical_or, data)
987+
988+
def test_all_logical():
989+
data = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool'),
990+
np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')]
991+
# boolean dtype is not supported by older versions than TFLite 1.15.0
992+
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
993+
_test_forward_logical_and(data)
994+
_test_forward_logical_or(data)
995+
968996
#######################################################################
969997
# Zeros like
970998
# --------
@@ -1530,6 +1558,9 @@ def test_forward_mediapipe_hand_landmark():
15301558
# Reduce
15311559
test_all_reduce()
15321560

1561+
# Logical
1562+
test_all_logical()
1563+
15331564
# End to End
15341565
test_forward_mobilenet_v1()
15351566
test_forward_mobilenet_v2()

0 commit comments

Comments
 (0)