Skip to content

Commit

Permalink
[Relay][Frontend][TFLite] Add parser support for shape and range
Browse files Browse the repository at this point in the history
Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>
  • Loading branch information
dhruvaray committed Jun 1, 2020
1 parent 12cfe4a commit cc7e833
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 24 deletions.
60 changes: 60 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .common import infer_shape as _infer_shape
from .tflite_flexbuffer import FlexBufferDecoder


__all__ = ['from_tflite']

class TensorWrapper(object):
Expand Down Expand Up @@ -113,6 +114,7 @@ def __init__(self, model, subgraph, exp_tab):
'PAD': self.convert_pad,
'POW': self.convert_pow,
'PRELU': self.convert_prelu,
'RANGE': self.convert_range,
'QUANTIZE': self.convert_quantize,
'REDUCE_ANY': self.convert_reduce_any,
'REDUCE_MAX': self.convert_reduce_max,
Expand All @@ -125,6 +127,7 @@ def __init__(self, model, subgraph, exp_tab):
'ROUND': self.convert_round,
'RSQRT': self.convert_rsqrt,
'SELECT': self.convert_select,
'SHAPE': self.convert_shape,
'SIN': self.convert_sin,
'SLICE': self.convert_slice,
'SOFTMAX': self.convert_softmax,
Expand Down Expand Up @@ -607,6 +610,63 @@ def convert_tanh(self, op):

return out

def convert_range(self, op):
"""Convert TFLite Range"""
try:
from tflite.Operator import Operator
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized RANGE operator is not supported yet.')

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be 3"

start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]
expressions = []

for t in [start, limit, delta]:
if self.has_expr(t.tensor_idx):
expressions.append(self.get_expr(t.tensor_idx))
else:
tensor_type = self.get_tensor_type_str(t.tensor.Type())
tensor_value = self.get_tensor_value(t)
expressions.append(self.exp_tab.new_const(tensor_value, dtype=tensor_type))

#out type inference
if delta.tensor.Type() == TensorType.FLOAT32:
out_type = self.get_tensor_type_str(delta.tensor.Type())
else:
out_type = self.get_tensor_type_str(start.tensor.Type())

#put type here form op
out = _op.arange(expressions[0], expressions[1], expressions[2], out_type)

return out

def convert_shape(self, op):
"""Convert TFLite Shape"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized SHAPE operator is not supported yet.')

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

out = _op.shape_of(self.get_expr(input_tensors[0].tensor_idx))

return out

def convert_relu(self, op):
"""Convert TFLite ReLU"""
input_tensors = self.get_input_tensors(op)
Expand Down
170 changes: 146 additions & 24 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,34 @@ def get_real_image_object_detection(im_height, im_width):
data = np.reshape(x, (1, im_height, im_width, 3))
return data

def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.runtime.container.ADT):
result = []
for f in o:
result.extend(vmobj_to_list(f))
return result
elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if o.constructor.name_hint == 'Cons':
tl = vmobj_to_list(o.fields[1])
hd = vmobj_to_list(o.fields[0])
hd.extend(tl)
return hd
elif o.constructor.name_hint == 'Nil':
return []
elif 'tensor_nil' in o.constructor.name_hint:
return [0]
elif 'tensor' in o.constructor.name_hint:
return [o.fields[0].asnumpy()]
else:
raise RuntimeError("Unknown object type: %s" %
o.constructor.name_hint)
else:
raise RuntimeError("Unknown object type: %s" % type(o))

def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
out_names=None):
out_names=None, mode='graph_runtime'):
""" Generic function to compile on relay and execute on tvm """
# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
Expand All @@ -109,27 +135,43 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
shape_dict=shape_dict,
dtype_dict=dtype_dict)

with tvm.transform.PassContext(opt_level=3):
graph, lib, params = relay.build(mod, target, params=params)

ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
for i, e in enumerate(input_node):
m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))

m.set_input(**params)
# execute
m.run()
# get outputs
assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format(
out_names, num_output)
tvm_output_list = []
for i in range(0, num_output):
tvm_output = m.get_output(i)
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list
if mode in ['debug', 'vm']:
ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm")
inputs = []
for param in mod['main'].params:
found = False
for i, n in enumerate(input_node):
if n == param.name_hint:
found = True
inputs.append(tvm.nd.array(input_data[i]))
break
# Interpreter doesn't bind constants, so still need to find in params
if not found:
inputs.append(tvm.nd.array(params[param.name_hint]))
result = ex.evaluate()(*inputs)
return vmobj_to_list(result)
else:
with tvm.transform.PassContext(opt_level=3):
graph, lib, params = relay.build(mod, target, params=params)

ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
for i, e in enumerate(input_node):
m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))

m.set_input(**params)
# execute
m.run()
# get outputs
assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format(
out_names, num_output)
tvm_output_list = []
for i in range(0, num_output):
tvm_output = m.get_output(i)
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list


def run_tflite_graph(tflite_model_buf, input_data):
Expand Down Expand Up @@ -160,7 +202,7 @@ def run_tflite_graph(tflite_model_buf, input_data):

def compare_tflite_with_tvm(in_data, in_name, input_tensors,
output_tensors, init_global_variables=False,
out_names=None, quantized=False, input_range=None):
out_names=None, quantized=False, input_range=None, mode='graph_runtime'):
"""Generic function to generate and compare TFLite and TVM output"""
in_data = convert_to_list(in_data)
in_name = convert_to_list(in_name)
Expand Down Expand Up @@ -202,7 +244,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
continue

tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device,
num_output=len(out_names), out_names=out_names)
num_output=len(out_names), out_names=out_names,mode=mode)

# WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output
# range for the specific operator. While adding test ensure that we aren't getting only clipped values
Expand Down Expand Up @@ -860,6 +902,82 @@ def test_all_resize():
_test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False)


#######################################################################
# Range
# -----
def _test_range(start, limit, delta):
# tflite 1.13 convert method does not accept empty shapes
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
tf.reset_default_graph()
with tf.Graph().as_default():
start_scalar, limit_scalar, delta_scalar = \
tf.placeholder(dtype=start.dtype, shape=(), name="start"), \
tf.placeholder(dtype=limit.dtype, shape=(), name="limit"), \
tf.placeholder(dtype=delta.dtype, shape=(), name="delta")

out = tf.range(start_scalar, limit_scalar, delta_scalar, name="range")

compare_tflite_with_tvm(
[start, limit, delta],
["start", "limit", "delta"],
[start_scalar, limit_scalar, delta_scalar],
[out],
mode="vm",
quantized=False
)

def _test_range_default():
# tflite 1.13 convert method does not accept empty shapes
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
tf.reset_default_graph()
with tf.Graph().as_default():

inputs = [
tf.placeholder(dtype=tf.int32, shape=(), name="p1"),
tf.placeholder(dtype=tf.int32, shape=(), name="p2")
]
leaves = [
tf.range(start = inputs[0], limit = inputs[1]), #use default delta
tf.range(start = inputs[1]) #use start as limit with 0 as the first item in the range
]

compare_tflite_with_tvm(
[np.int32(1), np.int32(18)],
["p1", "p2"],
inputs,
leaves,
mode="vm",
quantized=False
)

def test_forward_range():
_test_range(np.int32(1), np.int32(18), np.int32(3))
_test_range(np.int32(1), np.int32(18), np.float32(3.1)) # increment is of type float
_test_range(np.float32(1.0), np.int32(18), np.int32(3.1)) # start is of type float
_test_range_default()

#######################################################################
# Shape
# -----
def test_forward_shape():
# tflite 1.13 convert method does not accept empty shapes
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
tf.reset_default_graph()
with tf.Graph().as_default():
data = np.array([1, 18, 3], dtype=np.int32)
start = tf.placeholder(dtype=tf.int32, shape=[], name="start")
limit = tf.placeholder(dtype=tf.int32, shape=[], name="limit")
delta = tf.placeholder(dtype=tf.int32, shape=[], name="delta")
r = tf.range(start, limit, delta, tf.int32, name="range")
out = tf.shape(r, out_type=tf.dtypes.int32)
compare_tflite_with_tvm(
[x for x in np.nditer(data)],
["start", "limit", "delta"],
[start, limit, delta],
[out],
mode="vm",
quantized=False
)
#######################################################################
# Concatenation
# -------------
Expand Down Expand Up @@ -2290,13 +2408,17 @@ def test_forward_mediapipe_hand_landmark():
# Tile
test_forward_tile()

# Query
test_forward_shape()

# Transforms
test_forward_concatenation()
test_forward_pad()
test_forward_pack()
test_forward_unpack()
test_forward_reshape()
test_all_resize()
test_forward_range()
test_forward_squeeze()
test_forward_slice()
test_forward_topk()
Expand Down

0 comments on commit cc7e833

Please sign in to comment.