Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QNN][TFLite] Parsing TFLite quantized models. #3900

Merged
merged 1 commit into from
Oct 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 159 additions & 16 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from .. import qnn as _qnn
from ... import nd as _nd
from .common import ExprTable
from .common import infer_shape as _infer_shape
Expand All @@ -32,10 +33,11 @@

class TensorWrapper(object):
"""Tensor wrapper for TFLite Tensor"""
def __init__(self, tensor_idx, tensor, buffer):
def __init__(self, tensor_idx, tensor, buffer, qnn_params=None):
self.tensor_idx = tensor_idx
self.tensor = tensor
self.buffer = buffer
self.qnn_params = qnn_params

class OperatorConverter(object):
"""Operator Converted for converting TFLite ops to Relay ops"""
Expand Down Expand Up @@ -160,7 +162,19 @@ def get_tensors(self, tensors_idx_list):
tensor = self.subgraph.Tensors(tensor_idx)
buffer_idx = tensor.Buffer()
buffer = self.model.Buffers(buffer_idx)
return_list.append(TensorWrapper(tensor_idx, tensor, buffer))

# Check if the tensors are quantized. Parse if yes.
qnn_params = None
tflite_qnn_params = tensor.Quantization()
if tflite_qnn_params is not None:
scale = float(tflite_qnn_params.ScaleAsNumpy())
zero_point = int(tflite_qnn_params.ZeroPointAsNumpy())
# Check that the scale and zero points are valid.
if scale != 0 or zero_point != 0:
qnn_params = dict()
qnn_params['scale'] = scale
qnn_params['zero_point'] = zero_point
return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params))
return return_list

def get_tensor_value(self, tensor_wrapper):
Expand Down Expand Up @@ -200,6 +214,10 @@ def get_tensor_type_str(self, tensor_type):
raise NotImplementedError("Tensor type {} is currently not supported"
.format(str(tensor_type)))

def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
return lhs_tensor.qnn_params['scale'] == rhs_tensor.qnn_params['scale'] and \
lhs_tensor.qnn_params['zero_point'] == rhs_tensor.qnn_params['zero_point']

def convert_conv2d(self, op):
"""Convert TFLite conv2d"""
return self.convert_conv(op, "conv2d")
Expand Down Expand Up @@ -238,8 +256,15 @@ def convert_reshape(self, op):
target_shape = reshape_options.NewShapeAsNumpy()

in_expr = self.get_expr(input_tensor_idx)
out = _op.reshape(in_expr, newshape=tuple(target_shape))

# If the tensors are quantized, ensure that input/output qnn params are same.
if input_tensor.qnn_params:
output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "There should be only 1 output tensor"
output_tensor = output_tensors[0]
assert self.has_same_qnn_params(input_tensor, output_tensor), \
"TFLite reshape requires input and output scale and zero points to be equal"
out = _op.reshape(in_expr, newshape=tuple(target_shape))
return out

def _convert_resize(self, method, op):
Expand Down Expand Up @@ -324,10 +349,33 @@ def convert_softmax(self, op):

input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]
output_tensor_type = output_tensor.tensor.Type()
output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)

params = {'axis': 1} # 1 is channel
in_expr = self.get_expr(input_tensor_idx)

# TODO - Naive softmax int8 implementation leads to bad accuracy. Currently, we can
# dequantize to FP32 and perform softmax on FP32. We can investigate an integer only softmax
# implementation in future.
if input_tensor.qnn_params:
in_expr = _qnn.op.dequantize(data=in_expr,
input_scale=input_tensor.qnn_params['scale'],
input_zero_point=input_tensor.qnn_params['zero_point'])

out = _op.nn.softmax(in_expr, **params)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems generates float output for softmax, would it be better if we quantize it back to uint8 since the semantic of a uint8 model requires an uint8 output.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would indeed be better. Doing that.


# Go back to integer dataype if the original operator was quantized.
if output_tensor.qnn_params:
out = _qnn.op.quantize(data=out,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)

return out

def convert_tanh(self, op):
Expand Down Expand Up @@ -380,7 +428,8 @@ def convert_concatenation(self, op):
in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors should be 1"
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]

assert op.BuiltinOptionsType() == BuiltinOptions.ConcatenationOptions
op_options = op.BuiltinOptions()
Expand All @@ -389,12 +438,27 @@ def convert_concatenation(self, op):
concatenation_axis = concatenation_options.Axis()
fused_activation_fn = concatenation_options.FusedActivationFunction()

# with axis in N H W C
out = _op.concatenate(in_exprs, axis=concatenation_axis)
if not input_tensors[0].qnn_params:
out = _op.concatenate(in_exprs, axis=concatenation_axis)
else:
input_scales = [input_tensor.qnn_params['scale'] for input_tensor in input_tensors]
input_zero_points = \
[input_tensor.qnn_params['zero_point'] for input_tensor in input_tensors]
out = _qnn.op.concatenate(in_exprs,
input_scales=input_scales,
input_zero_points=input_zero_points,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
axis=concatenation_axis)

# if we have activation fn
if fused_activation_fn != ActivationFunctionType.NONE:
out = self.convert_fused_activation_function(out, fused_activation_fn)
if not output_tensor.qnn_params:
out = self.convert_fused_activation_function(out, fused_activation_fn)
else:
raise tvm.error.OpNotImplemented(
'Operator {} with fused activation is not supported yet.'
.format('qnn.op.concatenate'))
Comment on lines +459 to +461
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if we can have an clamp here to support it...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need to handle different types of activations like Relu/Relu6. I have a draft implementation of activation handling. I will send a separate PR for that once I test it with few cases.

return out

def _convert_elemwise(self, relay_op, op):
Expand Down Expand Up @@ -557,6 +621,12 @@ def convert_fully_connected(self, op):
input_tensor_idx = input_tensor.tensor_idx
weight_tensor = input_tensors[1]

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]
output_tensor_type = output_tensor.tensor.Type()
output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)

input_tensor_shape = input_tensor.tensor.ShapeAsNumpy()
weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy()

Expand Down Expand Up @@ -584,7 +654,13 @@ def convert_fully_connected(self, op):
weight_value = self.get_tensor_value(weight_tensor)
weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)

out = _op.nn.dense(in_expr, weight_expr)
if input_tensor.qnn_params:
out = _qnn.op.dense(in_expr, weight_expr,
input_zero_point=input_tensor.qnn_params['zero_point'],
kernel_zero_point=weight_tensor.qnn_params['zero_point'],
out_dtype='int32')
else:
out = _op.nn.dense(in_expr, weight_expr)

# if we have bias
if len(input_tensors) == 3:
Expand All @@ -599,7 +675,23 @@ def convert_fully_connected(self, op):

# If we have fused activations
if fused_activation_fn != ActivationFunctionType.NONE:
out = self.convert_fused_activation_function(out, fused_activation_fn)
if not output_tensor.qnn_params:
out = self.convert_fused_activation_function(out, fused_activation_fn)
else:
raise tvm.error.OpNotImplemented(
'Operator {} with fused activation is not supported yet.'
.format('qnn.op.dense'))

# Finally if the dense is quantized. Add a requantize at the end.
if output_tensor.qnn_params:
input_scale = input_tensor.qnn_params['scale'] * weight_tensor.qnn_params['scale']
input_zero_point = 0
out = _qnn.op.requantize(out,
input_scale=input_scale,
input_zero_point=input_zero_point,
anijain2305 marked this conversation as resolved.
Show resolved Hide resolved
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)

return out

Expand Down Expand Up @@ -671,6 +763,12 @@ def convert_conv(self, op, conv_type):
input_tensor_idx = input_tensor.tensor_idx
weight_tensor = input_tensors[1]

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]
output_tensor_type = output_tensor.tensor.Type()
output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)

is_depthwise_conv = False
if conv_type == 'conv2d':
assert op.BuiltinOptionsType() == BuiltinOptions.Conv2DOptions
Expand Down Expand Up @@ -758,7 +856,14 @@ def convert_conv(self, op, conv_type):
raise tvm.error.OpAttributeUnImplemented(
'Padding format {} is not supported for operator Conv.'.format(padding))

out = _op.nn.conv2d(data=in_expr, weight=weight_expr, **params)
if input_tensor.qnn_params:
qnn_conv2d_params = dict(params)
qnn_conv2d_params['input_zero_point'] = input_tensor.qnn_params['zero_point']
qnn_conv2d_params['kernel_zero_point'] = weight_tensor.qnn_params['zero_point']
qnn_conv2d_params['out_dtype'] = 'int32'
out = _qnn.op.conv2d(in_expr, weight_expr, **qnn_conv2d_params)
else:
out = _op.nn.conv2d(in_expr, weight_expr, **params)

# if we have bias
if len(input_tensors) == 3:
Expand All @@ -774,7 +879,23 @@ def convert_conv(self, op, conv_type):

# If we have fused activations
if fused_activation_fn != ActivationFunctionType.NONE:
out = self.convert_fused_activation_function(out, fused_activation_fn)
if not output_tensor.qnn_params:
out = self.convert_fused_activation_function(out, fused_activation_fn)
else:
raise tvm.error.OpNotImplemented(
'Operator {} with fused activation is not supported yet.'
.format('qnn.op.conv2d'))
Comment on lines +885 to +887
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's interesting, AFAIK, TFLite model converter always fuses the ReLU family activations to Conv/FC operators. The activation of Inception model you are using is not fused?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, TFLite hosted models do not have any fused activation. There is a fused activation in Object Detection models, but we are far away from supporting object detection (basically dependent on control flow support).


# Finally if the conv is quantized. Add a requantize at the end.
if output_tensor.qnn_params:
input_scale = input_tensor.qnn_params['scale'] * weight_tensor.qnn_params['scale']
input_zero_point = 0
out = _qnn.op.requantize(out,
input_scale=input_scale,
input_zero_point=input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)
Comment on lines +890 to +898
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can have a small function, requantize_output(out, output_tensor, input_tensor) to wrapper this :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestion. But, I will skip this one for now. Currently, there are only 2 uses of requantize. Both of which are quite specific - where we have to multiply the input scales to get the new input scale. As we add more operators, and thus more requantize, I will revisit this.


return out

Expand Down Expand Up @@ -879,6 +1000,12 @@ def convert_pool2d(self, op, pool_type):
input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors should be 1"
output_tensor = output_tensors[0]
output_tensor_type = output_tensor.tensor.Type()
output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)

assert op.BuiltinOptionsType() == BuiltinOptions.Pool2DOptions
op_options = op.BuiltinOptions()
pool2d_options = Pool2DOptions()
Expand Down Expand Up @@ -909,17 +1036,32 @@ def convert_pool2d(self, op, pool_type):
'Padding format {} for operator Pool2D is not supported.'.format(padding))

if pool_type == "average":
out = _op.nn.avg_pool2d(in_expr, **params)
if input_tensor.qnn_params:
assert self.has_same_qnn_params(input_tensor, output_tensor), \
'TFLite avg_pool2dreshape requires input and output scale' \
'and zero points to be equal'
out = _op.cast(in_expr, dtype="int32")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style consistency of the "/' :)

out = _op.nn.avg_pool2d(out, **params)
out = _op.cast(out, dtype=output_tensor_type_str)
else:
out = _op.nn.avg_pool2d(in_expr, **params)
elif pool_type == "max":
if input_tensor.qnn_params:
assert self.has_same_qnn_params(input_tensor, output_tensor), \
"qnn.op.max_pool2d requires input and output qnn params to be same"
out = _op.nn.max_pool2d(in_expr, **params)
else:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool'))

# If we have fused activations
if fused_activation_fn != ActivationFunctionType.NONE:
out = self.convert_fused_activation_function(out, fused_activation_fn)

if input_tensor.qnn_params:
raise tvm.error.OpNotImplemented(
'Operator {} with fused activation is not supported yet.'
.format('qnn.op.pool2d'))
else:
out = self.convert_fused_activation_function(out, fused_activation_fn)
return out

def convert_pad(self, op):
Expand Down Expand Up @@ -962,7 +1104,7 @@ def convert_pack(self, op):
in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors should be 1"
assert len(output_tensors) == 1, "output tensors length should be 1"

assert op.BuiltinOptionsType() == BuiltinOptions.PackOptions
op_options = op.BuiltinOptions()
Expand Down Expand Up @@ -1210,4 +1352,5 @@ def from_tflite(model, shape_dict, dtype_dict):
outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _expr.Function(analysis.free_vars(outputs), outputs)
return _module.Module.from_expr(func), params
mod = _module.Module.from_expr(func)
return mod, params
45 changes: 45 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,46 @@ def test_forward_inception_v4_net():
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)

def test_forward_qnn_inception_v1_net():
"""Test the Quantized TFLite Inception model."""
# InceptionV1
tflite_model_file = tf_testing.get_workload_official(
"https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_224_quant_20181026.tgz",
"inception_v1_224_quant.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
# Checking the labels because the requantize implementation is different between TFLite and
# Relay. This cause final output numbers to mismatch. So, testing accuracy via labels.
np.random.seed(0)
data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8')
tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

def test_forward_qnn_mobilenet_v1_net():
"""Test the Quantized TFLite Mobilenet V1 model."""
# MobilenetV1
tflite_model_file = tf_testing.get_workload_official(
"https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
"mobilenet_v1_1.0_224_quant.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
# Checking the labels because the requantize implementation is different between TFLite and
# Relay. This cause final output numbers to mismatch. So, testing accuracy via labels.
np.random.seed(0)
data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8')
tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain more why we can not compare tflite_out with tvm_out directly like we do in FP32? I think we could get the same result if we have kept the compatibility with TFLite, if not, why we don’t keep the compatibility?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only difference between TFLite and QNN is rounding. TFLite reference implementation is reverse-engineered from ARM assembly instructions (which perform 2 roundings in general). QNN follows more formal rounding described here. If you are interested, I would encourage to read the inlined comments in the code - https://github.com/dmlc/tvm/blob/068c148f01fcc81c72004512de320ca1cee24dc6/src/relay/qnn/util.cc#L79-L133

There was also a fair amount of discussion which has been summarized here - #3591 (comment). This should also be helpful.

This rounding can make small variations in the output, because of which we can not do exact check.

I have measured accuracy on TFLite quantized Inception V1-V3 on 50K images using QNN, and observed accuracy parity with TFLite accuracy native execution.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the information! @anijain2305 However, if I remember correctly, when I review the code of requantize, I mention that we could use round_away_from_zero, which is used for TFLite. Do we change it later?

Alright, if we change the default value later, In my opinion, I think we should pass round_away_from_zero to requantize operator for TFLite model frontend parsing.The reasons are:

  • Keep the compatibility with TFLite should be the first priority. Because we are parsing TFLite model, the quantization accuracy is controlled by Tensorflow quantization-aware training and TFLite.
  • When we run new quantization model of TFLite, we can not get test suites all the time. For example, the customers / algo. team sends us the TFLite model and requires us boost the performance. We should make sure the result is the same as TFLite. The simplest way is to compare the result with TFLite, not do end-to-end accuracy, sometimes we even can not do it.
  • For our development, we could also do the simple compare with TFLite like FP32, not do end-to-end accuracy. If one op need to requantize internally, we could compare it with TFLite, not need to think about the result is right or not using math when we have different result.

Copy link
Contributor Author

@anijain2305 anijain2305 Oct 13, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TFlite rounding is not exactly formal round_away_from_zero. You might already know but these are the two functions that are involved in TFLite rounding

  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;

More at - https://github.com/tensorflow/tensorflow/blob/ff91cd691027076e6128afbd902d3d38e3672787/tensorflow/lite/kernels/internal/common.h#L105

The key idea is that the above two functions map to ARM assembly instructions. But, in this process, they perform approximation that make it little different from formal GNU C rounding. The reasoning behind doing that should be to get performance.

In QNN, we decided to follow GNU C rounding, as thats more formal and can work across more frameworks (like MxNet, PyTorch and not just TFLite).

However, if one wants, one can implement a TFLite rounding. The code is modular enough to add one more rounding option. I think you have good comfortability with TFLite, so I would encourage you to write TFLite rounding if you get time :)

Copy link
Member

@FrozenGene FrozenGene Oct 14, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fully understand your point across more frameworks. My previous point is to emphasize we could provide one option for our TFLite frontend parser so that we could keep the compatibility with TFLite, we don't need to change any default value in our existing code, which could be used for MXNet, PyTorch and so on. Otherwise, the users will confuse why our result is not the same as TFLite when they use TVM. Keep the compatibility with TFLite should be the first priority when we parse TFLite model. So I think we shouldn't leave this to implement by TVM users.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, as you said, it doesn't affect the code structure. If you agree the point of compatibility point, I think we could make this PR in firstly and open one issue reminding us we should do this thing. If you don't have spare time, I could help to manage it next.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we can open an issue to track the TFLite rounding. It would not affect code structure. We can set the value of the rounding in TFLite parser, hiding it from the TVM users.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Could you help to open this this to track it? And you could assign this task to me.


#######################################################################
# SSD Mobilenet
# -------------
Expand Down Expand Up @@ -1048,3 +1088,8 @@ def test_forward_ssd_mobilenet_v1():
test_forward_inception_v3_net()
test_forward_inception_v4_net()
test_forward_ssd_mobilenet_v1()

# End to End quantized
# TODO - MobilenetV2 fails for now. Remove when fixed.
test_forward_qnn_inception_v1_net()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let us add Mobilenet V2 quant model if this PR could run it. Because it is very popular and often compared across different framework.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Till now I have been focussing on the server side, and thus I used Inception. Can this PR go in first and I can send a PR for mobilenet V2 separately?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no problem. However, I have one question that do we have any TODO for Mobilenet V2 quant model? From my knowledge, we should work well on Mobilenet V2 quant model using this PR and just add test code of Mobilenet V2 quant model.

Copy link
Contributor Author

@anijain2305 anijain2305 Oct 12, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there is. I want to do a full 50k image test before adding MobileNetV2. I have done that for InceptionV1 and ensured that its accuracy is same as that provided in TFLite info. I will setup the pre-processing for mobilenet next week, and check end-to-end accuracy, and send a separate PR. What do you say?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need to do it you mention. Because we are parsing TFLite quant model and TFLite framework is one baseline. We could treat TFLite quant model like TFLite FP32 model and compare the result between us and TFLite framework. They are no difference. I think we only need to do end-to-end accuracy when we do quantization in TVM inside or we have strong reason to implement ops but don't keep compatibility with TFLite framework.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added MobileNet V1. MobilenetV2 fails with a segfault in LLVM. Will debug.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the detail information of LLVM report? Try the latest?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the whole stack trace. Happens while the graph runtime create. Relay build has already passed.

0x00007fff78dd1e64 in llvm::EVT::getExtendedVectorNumElements() const () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
(gdb) bt
#0  0x00007fff78dd1e64 in llvm::EVT::getExtendedVectorNumElements() const () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#1  0x00007fff78ac1090 in llvm::TargetLowering::SimplifyDemandedBits(llvm::SDValue, llvm::APInt const&, llvm::APInt const&, llvm::KnownBits&, llvm::TargetLowering::TargetLoweringOpt&, unsigned int, bool) const () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#2  0x00007fff78ac0318 in llvm::TargetLowering::SimplifyDemandedBits(llvm::SDValue, llvm::APInt const&, llvm::APInt const&, llvm::KnownBits&, llvm::TargetLowering::TargetLoweringOpt&, unsigned int, bool) const () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#3  0x00007fff78abe55f in llvm::TargetLowering::SimplifyDemandedBits(llvm::SDValue, llvm::APInt const&, llvm::KnownBits&, llvm::TargetLowering::TargetLoweringOpt&, unsigned int, bool) const ()
   from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#4  0x00007fff789c7b65 in (anonymous namespace)::DAGCombiner::SimplifyDemandedBits(llvm::SDValue, llvm::APInt const&) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#5  0x00007fff789a4339 in (anonymous namespace)::DAGCombiner::visit(llvm::SDNode*) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#6  0x00007fff7897670c in (anonymous namespace)::DAGCombiner::combine(llvm::SDNode*) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#7  0x00007fff78975cd3 in llvm::SelectionDAG::Combine(llvm::CombineLevel, llvm::AAResults*, llvm::CodeGenOpt::Level) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#8  0x00007fff78aaa7b2 in llvm::SelectionDAGISel::CodeGenAndEmitDAG() () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#9  0x00007fff78aa9939 in llvm::SelectionDAGISel::SelectAllBasicBlocks(llvm::Function const&) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#10 0x00007fff78aa6c06 in llvm::SelectionDAGISel::runOnMachineFunction(llvm::MachineFunction&) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#11 0x00007fff787c2ffe in (anonymous namespace)::X86DAGToDAGISel::runOnMachineFunction(llvm::MachineFunction&) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#12 0x00007fff78cb0ff4 in llvm::MachineFunctionPass::runOnFunction(llvm::Function&) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#13 0x00007fff7976c4ba in llvm::FPPassManager::runOnFunction(llvm::Function&) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#14 0x00007fff7976c843 in llvm::FPPassManager::runOnModule(llvm::Module&) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#15 0x00007fff7976cdbf in llvm::legacy::PassManagerImpl::run(llvm::Module&) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#16 0x00007fff78bdee38 in llvm::MCJIT::emitObject(llvm::Module*) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#17 0x00007fff78bdf094 in llvm::MCJIT::generateCodeForModule(llvm::Module*) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#18 0x00007fff78be038e in llvm::MCJIT::findSymbol(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, bool) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#19 0x00007fff78bdfe58 in llvm::MCJIT::getSymbolAddress(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, bool) ()
   from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#20 0x00007fff78be04ba in llvm::MCJIT::getGlobalValueAddress(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) ()
   from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#21 0x00007fff77dcb794 in tvm::codegen::LLVMModuleNode::LazyInitJIT() () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#22 0x00007fff77dca02e in tvm::codegen::LLVMModuleNode::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::shared_ptr<tvm::runtime::ModuleNode> const&) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#23 0x00007fff77e0db68 in tvm::runtime::GraphRuntime::CreateTVMOp(tvm::runtime::TVMOpParam const&, std::vector<DLTensor, std::allocator<DLTensor> > const&, unsigned long) ()
   from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#24 0x00007fff77e0afe4 in tvm::runtime::GraphRuntime::SetupOpExecs() () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#25 0x00007fff77e09aad in tvm::runtime::GraphRuntime::Init(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::Module, std::vector<DLContext, std::allocator<DLContext> > const&) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#26 0x00007fff77e0f10b in tvm::runtime::GraphRuntimeCreate(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::Module const&, std::vector<DLContext, std::allocator<DLContext> > const&) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#27 0x00007fff77e10f8b in std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::$_12>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#28 0x00007fff77de5577 in TVMFuncCall () from /home/ubuntu/workplace/tvm/t1/tvm/build/libtvm.so
#29 0x00007fff7e54ee20 in ffi_call_unix64 () from /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so
#30 0x00007fff7e54e88b in ffi_call () from /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so
#31 0x00007fff7e54901a in _ctypes_callproc () from /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so
#32 0x00007fff7e53cfcb in ?? () from /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so
#33 0x00000000005c3bd7 in PyObject_Call ()
#34 0x00000000005354af in PyEval_EvalFrameEx ()
#35 0x000000000053a81b in PyEval_EvalCodeEx ()
#36 0x00000000004e3423 in ?? ()
#37 0x00000000005c3bd7 in PyObject_Call ()
#38 0x00000000004f08be in ?? ()
#39 0x00000000005c3bd7 in PyObject_Call ()

test_forward_qnn_mobilenet_v1_net()