-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[QNN][TFLite] Parsing TFLite quantized models. #3900
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
from .. import expr as _expr | ||
from .. import module as _module | ||
from .. import op as _op | ||
from .. import qnn as _qnn | ||
from ... import nd as _nd | ||
from .common import ExprTable | ||
from .common import infer_shape as _infer_shape | ||
|
@@ -32,10 +33,11 @@ | |
|
||
class TensorWrapper(object): | ||
"""Tensor wrapper for TFLite Tensor""" | ||
def __init__(self, tensor_idx, tensor, buffer): | ||
def __init__(self, tensor_idx, tensor, buffer, qnn_params=None): | ||
self.tensor_idx = tensor_idx | ||
self.tensor = tensor | ||
self.buffer = buffer | ||
self.qnn_params = qnn_params | ||
|
||
class OperatorConverter(object): | ||
"""Operator Converted for converting TFLite ops to Relay ops""" | ||
|
@@ -160,7 +162,19 @@ def get_tensors(self, tensors_idx_list): | |
tensor = self.subgraph.Tensors(tensor_idx) | ||
buffer_idx = tensor.Buffer() | ||
buffer = self.model.Buffers(buffer_idx) | ||
return_list.append(TensorWrapper(tensor_idx, tensor, buffer)) | ||
|
||
# Check if the tensors are quantized. Parse if yes. | ||
qnn_params = None | ||
tflite_qnn_params = tensor.Quantization() | ||
if tflite_qnn_params is not None: | ||
scale = float(tflite_qnn_params.ScaleAsNumpy()) | ||
zero_point = int(tflite_qnn_params.ZeroPointAsNumpy()) | ||
# Check that the scale and zero points are valid. | ||
if scale != 0 or zero_point != 0: | ||
qnn_params = dict() | ||
qnn_params['scale'] = scale | ||
qnn_params['zero_point'] = zero_point | ||
return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params)) | ||
return return_list | ||
|
||
def get_tensor_value(self, tensor_wrapper): | ||
|
@@ -200,6 +214,10 @@ def get_tensor_type_str(self, tensor_type): | |
raise NotImplementedError("Tensor type {} is currently not supported" | ||
.format(str(tensor_type))) | ||
|
||
def has_same_qnn_params(self, lhs_tensor, rhs_tensor): | ||
return lhs_tensor.qnn_params['scale'] == rhs_tensor.qnn_params['scale'] and \ | ||
lhs_tensor.qnn_params['zero_point'] == rhs_tensor.qnn_params['zero_point'] | ||
|
||
def convert_conv2d(self, op): | ||
"""Convert TFLite conv2d""" | ||
return self.convert_conv(op, "conv2d") | ||
|
@@ -238,8 +256,15 @@ def convert_reshape(self, op): | |
target_shape = reshape_options.NewShapeAsNumpy() | ||
|
||
in_expr = self.get_expr(input_tensor_idx) | ||
out = _op.reshape(in_expr, newshape=tuple(target_shape)) | ||
|
||
# If the tensors are quantized, ensure that input/output qnn params are same. | ||
if input_tensor.qnn_params: | ||
output_tensors = self.get_output_tensors(op) | ||
assert len(output_tensors) == 1, "There should be only 1 output tensor" | ||
output_tensor = output_tensors[0] | ||
assert self.has_same_qnn_params(input_tensor, output_tensor), \ | ||
"TFLite reshape requires input and output scale and zero points to be equal" | ||
out = _op.reshape(in_expr, newshape=tuple(target_shape)) | ||
return out | ||
|
||
def _convert_resize(self, method, op): | ||
|
@@ -324,10 +349,33 @@ def convert_softmax(self, op): | |
|
||
input_tensor = input_tensors[0] | ||
input_tensor_idx = input_tensor.tensor_idx | ||
|
||
output_tensors = self.get_output_tensors(op) | ||
assert len(output_tensors) == 1, "output tensors length should be 1" | ||
output_tensor = output_tensors[0] | ||
output_tensor_type = output_tensor.tensor.Type() | ||
output_tensor_type_str = self.get_tensor_type_str(output_tensor_type) | ||
|
||
params = {'axis': 1} # 1 is channel | ||
in_expr = self.get_expr(input_tensor_idx) | ||
|
||
# TODO - Naive softmax int8 implementation leads to bad accuracy. Currently, we can | ||
# dequantize to FP32 and perform softmax on FP32. We can investigate an integer only softmax | ||
# implementation in future. | ||
if input_tensor.qnn_params: | ||
in_expr = _qnn.op.dequantize(data=in_expr, | ||
input_scale=input_tensor.qnn_params['scale'], | ||
input_zero_point=input_tensor.qnn_params['zero_point']) | ||
|
||
out = _op.nn.softmax(in_expr, **params) | ||
|
||
# Go back to integer dataype if the original operator was quantized. | ||
if output_tensor.qnn_params: | ||
out = _qnn.op.quantize(data=out, | ||
output_scale=output_tensor.qnn_params['scale'], | ||
output_zero_point=output_tensor.qnn_params['zero_point'], | ||
out_dtype=output_tensor_type_str) | ||
|
||
return out | ||
|
||
def convert_tanh(self, op): | ||
|
@@ -380,7 +428,8 @@ def convert_concatenation(self, op): | |
in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors] | ||
|
||
output_tensors = self.get_output_tensors(op) | ||
assert len(output_tensors) == 1, "output tensors should be 1" | ||
assert len(output_tensors) == 1, "output tensors length should be 1" | ||
output_tensor = output_tensors[0] | ||
|
||
assert op.BuiltinOptionsType() == BuiltinOptions.ConcatenationOptions | ||
op_options = op.BuiltinOptions() | ||
|
@@ -389,12 +438,27 @@ def convert_concatenation(self, op): | |
concatenation_axis = concatenation_options.Axis() | ||
fused_activation_fn = concatenation_options.FusedActivationFunction() | ||
|
||
# with axis in N H W C | ||
out = _op.concatenate(in_exprs, axis=concatenation_axis) | ||
if not input_tensors[0].qnn_params: | ||
out = _op.concatenate(in_exprs, axis=concatenation_axis) | ||
else: | ||
input_scales = [input_tensor.qnn_params['scale'] for input_tensor in input_tensors] | ||
input_zero_points = \ | ||
[input_tensor.qnn_params['zero_point'] for input_tensor in input_tensors] | ||
out = _qnn.op.concatenate(in_exprs, | ||
input_scales=input_scales, | ||
input_zero_points=input_zero_points, | ||
output_scale=output_tensor.qnn_params['scale'], | ||
output_zero_point=output_tensor.qnn_params['zero_point'], | ||
axis=concatenation_axis) | ||
|
||
# if we have activation fn | ||
if fused_activation_fn != ActivationFunctionType.NONE: | ||
out = self.convert_fused_activation_function(out, fused_activation_fn) | ||
if not output_tensor.qnn_params: | ||
out = self.convert_fused_activation_function(out, fused_activation_fn) | ||
else: | ||
raise tvm.error.OpNotImplemented( | ||
'Operator {} with fused activation is not supported yet.' | ||
.format('qnn.op.concatenate')) | ||
Comment on lines
+459
to
+461
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure if we can have an clamp here to support it... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We also need to handle different types of activations like Relu/Relu6. I have a draft implementation of activation handling. I will send a separate PR for that once I test it with few cases. |
||
return out | ||
|
||
def _convert_elemwise(self, relay_op, op): | ||
|
@@ -557,6 +621,12 @@ def convert_fully_connected(self, op): | |
input_tensor_idx = input_tensor.tensor_idx | ||
weight_tensor = input_tensors[1] | ||
|
||
output_tensors = self.get_output_tensors(op) | ||
assert len(output_tensors) == 1, "output tensors length should be 1" | ||
output_tensor = output_tensors[0] | ||
output_tensor_type = output_tensor.tensor.Type() | ||
output_tensor_type_str = self.get_tensor_type_str(output_tensor_type) | ||
|
||
input_tensor_shape = input_tensor.tensor.ShapeAsNumpy() | ||
weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy() | ||
|
||
|
@@ -584,7 +654,13 @@ def convert_fully_connected(self, op): | |
weight_value = self.get_tensor_value(weight_tensor) | ||
weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) | ||
|
||
out = _op.nn.dense(in_expr, weight_expr) | ||
if input_tensor.qnn_params: | ||
out = _qnn.op.dense(in_expr, weight_expr, | ||
input_zero_point=input_tensor.qnn_params['zero_point'], | ||
kernel_zero_point=weight_tensor.qnn_params['zero_point'], | ||
out_dtype='int32') | ||
else: | ||
out = _op.nn.dense(in_expr, weight_expr) | ||
|
||
# if we have bias | ||
if len(input_tensors) == 3: | ||
|
@@ -599,7 +675,23 @@ def convert_fully_connected(self, op): | |
|
||
# If we have fused activations | ||
if fused_activation_fn != ActivationFunctionType.NONE: | ||
out = self.convert_fused_activation_function(out, fused_activation_fn) | ||
if not output_tensor.qnn_params: | ||
out = self.convert_fused_activation_function(out, fused_activation_fn) | ||
else: | ||
raise tvm.error.OpNotImplemented( | ||
'Operator {} with fused activation is not supported yet.' | ||
.format('qnn.op.dense')) | ||
|
||
# Finally if the dense is quantized. Add a requantize at the end. | ||
if output_tensor.qnn_params: | ||
input_scale = input_tensor.qnn_params['scale'] * weight_tensor.qnn_params['scale'] | ||
input_zero_point = 0 | ||
out = _qnn.op.requantize(out, | ||
input_scale=input_scale, | ||
input_zero_point=input_zero_point, | ||
anijain2305 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
output_scale=output_tensor.qnn_params['scale'], | ||
output_zero_point=output_tensor.qnn_params['zero_point'], | ||
out_dtype=output_tensor_type_str) | ||
|
||
return out | ||
|
||
|
@@ -671,6 +763,12 @@ def convert_conv(self, op, conv_type): | |
input_tensor_idx = input_tensor.tensor_idx | ||
weight_tensor = input_tensors[1] | ||
|
||
output_tensors = self.get_output_tensors(op) | ||
assert len(output_tensors) == 1, "output tensors length should be 1" | ||
output_tensor = output_tensors[0] | ||
output_tensor_type = output_tensor.tensor.Type() | ||
output_tensor_type_str = self.get_tensor_type_str(output_tensor_type) | ||
|
||
is_depthwise_conv = False | ||
if conv_type == 'conv2d': | ||
assert op.BuiltinOptionsType() == BuiltinOptions.Conv2DOptions | ||
|
@@ -758,7 +856,14 @@ def convert_conv(self, op, conv_type): | |
raise tvm.error.OpAttributeUnImplemented( | ||
'Padding format {} is not supported for operator Conv.'.format(padding)) | ||
|
||
out = _op.nn.conv2d(data=in_expr, weight=weight_expr, **params) | ||
if input_tensor.qnn_params: | ||
qnn_conv2d_params = dict(params) | ||
qnn_conv2d_params['input_zero_point'] = input_tensor.qnn_params['zero_point'] | ||
qnn_conv2d_params['kernel_zero_point'] = weight_tensor.qnn_params['zero_point'] | ||
qnn_conv2d_params['out_dtype'] = 'int32' | ||
out = _qnn.op.conv2d(in_expr, weight_expr, **qnn_conv2d_params) | ||
else: | ||
out = _op.nn.conv2d(in_expr, weight_expr, **params) | ||
|
||
# if we have bias | ||
if len(input_tensors) == 3: | ||
|
@@ -774,7 +879,23 @@ def convert_conv(self, op, conv_type): | |
|
||
# If we have fused activations | ||
if fused_activation_fn != ActivationFunctionType.NONE: | ||
out = self.convert_fused_activation_function(out, fused_activation_fn) | ||
if not output_tensor.qnn_params: | ||
out = self.convert_fused_activation_function(out, fused_activation_fn) | ||
else: | ||
raise tvm.error.OpNotImplemented( | ||
'Operator {} with fused activation is not supported yet.' | ||
.format('qnn.op.conv2d')) | ||
Comment on lines
+885
to
+887
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's interesting, AFAIK, TFLite model converter always fuses the ReLU family activations to Conv/FC operators. The activation of Inception model you are using is not fused? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, TFLite hosted models do not have any fused activation. There is a fused activation in Object Detection models, but we are far away from supporting object detection (basically dependent on control flow support). |
||
|
||
# Finally if the conv is quantized. Add a requantize at the end. | ||
if output_tensor.qnn_params: | ||
input_scale = input_tensor.qnn_params['scale'] * weight_tensor.qnn_params['scale'] | ||
input_zero_point = 0 | ||
out = _qnn.op.requantize(out, | ||
input_scale=input_scale, | ||
input_zero_point=input_zero_point, | ||
output_scale=output_tensor.qnn_params['scale'], | ||
output_zero_point=output_tensor.qnn_params['zero_point'], | ||
out_dtype=output_tensor_type_str) | ||
Comment on lines
+890
to
+898
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can have a small function, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good suggestion. But, I will skip this one for now. Currently, there are only 2 uses of requantize. Both of which are quite specific - where we have to multiply the input scales to get the new input scale. As we add more operators, and thus more requantize, I will revisit this. |
||
|
||
return out | ||
|
||
|
@@ -879,6 +1000,12 @@ def convert_pool2d(self, op, pool_type): | |
input_tensor = input_tensors[0] | ||
input_tensor_idx = input_tensor.tensor_idx | ||
|
||
output_tensors = self.get_output_tensors(op) | ||
assert len(output_tensors) == 1, "output tensors should be 1" | ||
output_tensor = output_tensors[0] | ||
output_tensor_type = output_tensor.tensor.Type() | ||
output_tensor_type_str = self.get_tensor_type_str(output_tensor_type) | ||
|
||
assert op.BuiltinOptionsType() == BuiltinOptions.Pool2DOptions | ||
op_options = op.BuiltinOptions() | ||
pool2d_options = Pool2DOptions() | ||
|
@@ -909,17 +1036,32 @@ def convert_pool2d(self, op, pool_type): | |
'Padding format {} for operator Pool2D is not supported.'.format(padding)) | ||
|
||
if pool_type == "average": | ||
out = _op.nn.avg_pool2d(in_expr, **params) | ||
if input_tensor.qnn_params: | ||
assert self.has_same_qnn_params(input_tensor, output_tensor), \ | ||
'TFLite avg_pool2dreshape requires input and output scale' \ | ||
'and zero points to be equal' | ||
out = _op.cast(in_expr, dtype="int32") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style consistency of the |
||
out = _op.nn.avg_pool2d(out, **params) | ||
out = _op.cast(out, dtype=output_tensor_type_str) | ||
else: | ||
out = _op.nn.avg_pool2d(in_expr, **params) | ||
elif pool_type == "max": | ||
if input_tensor.qnn_params: | ||
assert self.has_same_qnn_params(input_tensor, output_tensor), \ | ||
"qnn.op.max_pool2d requires input and output qnn params to be same" | ||
out = _op.nn.max_pool2d(in_expr, **params) | ||
else: | ||
raise tvm.error.OpNotImplemented( | ||
'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool')) | ||
|
||
# If we have fused activations | ||
if fused_activation_fn != ActivationFunctionType.NONE: | ||
out = self.convert_fused_activation_function(out, fused_activation_fn) | ||
|
||
if input_tensor.qnn_params: | ||
raise tvm.error.OpNotImplemented( | ||
'Operator {} with fused activation is not supported yet.' | ||
.format('qnn.op.pool2d')) | ||
else: | ||
out = self.convert_fused_activation_function(out, fused_activation_fn) | ||
return out | ||
|
||
def convert_pad(self, op): | ||
|
@@ -962,7 +1104,7 @@ def convert_pack(self, op): | |
in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors] | ||
|
||
output_tensors = self.get_output_tensors(op) | ||
assert len(output_tensors) == 1, "output tensors should be 1" | ||
assert len(output_tensors) == 1, "output tensors length should be 1" | ||
|
||
assert op.BuiltinOptionsType() == BuiltinOptions.PackOptions | ||
op_options = op.BuiltinOptions() | ||
|
@@ -1210,4 +1352,5 @@ def from_tflite(model, shape_dict, dtype_dict): | |
outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs] | ||
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) | ||
func = _expr.Function(analysis.free_vars(outputs), outputs) | ||
return _module.Module.from_expr(func), params | ||
mod = _module.Module.from_expr(func) | ||
return mod, params |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -979,6 +979,46 @@ def test_forward_inception_v4_net(): | |
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), | ||
rtol=1e-5, atol=1e-5) | ||
|
||
def test_forward_qnn_inception_v1_net(): | ||
"""Test the Quantized TFLite Inception model.""" | ||
# InceptionV1 | ||
tflite_model_file = tf_testing.get_workload_official( | ||
"https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_224_quant_20181026.tgz", | ||
"inception_v1_224_quant.tflite") | ||
with open(tflite_model_file, "rb") as f: | ||
tflite_model_buf = f.read() | ||
# Checking the labels because the requantize implementation is different between TFLite and | ||
# Relay. This cause final output numbers to mismatch. So, testing accuracy via labels. | ||
np.random.seed(0) | ||
data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8') | ||
tflite_output = run_tflite_graph(tflite_model_buf, data) | ||
tflite_predictions = np.squeeze(tflite_output) | ||
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] | ||
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') | ||
tvm_predictions = np.squeeze(tvm_output) | ||
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] | ||
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) | ||
|
||
def test_forward_qnn_mobilenet_v1_net(): | ||
"""Test the Quantized TFLite Mobilenet V1 model.""" | ||
# MobilenetV1 | ||
tflite_model_file = tf_testing.get_workload_official( | ||
"https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz", | ||
"mobilenet_v1_1.0_224_quant.tflite") | ||
with open(tflite_model_file, "rb") as f: | ||
tflite_model_buf = f.read() | ||
# Checking the labels because the requantize implementation is different between TFLite and | ||
# Relay. This cause final output numbers to mismatch. So, testing accuracy via labels. | ||
np.random.seed(0) | ||
data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8') | ||
tflite_output = run_tflite_graph(tflite_model_buf, data) | ||
tflite_predictions = np.squeeze(tflite_output) | ||
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] | ||
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') | ||
tvm_predictions = np.squeeze(tvm_output) | ||
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] | ||
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you explain more why we can not compare tflite_out with tvm_out directly like we do in FP32? I think we could get the same result if we have kept the compatibility with TFLite, if not, why we don’t keep the compatibility? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The only difference between TFLite and QNN is rounding. TFLite reference implementation is reverse-engineered from ARM assembly instructions (which perform 2 roundings in general). QNN follows more formal rounding described here. If you are interested, I would encourage to read the inlined comments in the code - https://github.com/dmlc/tvm/blob/068c148f01fcc81c72004512de320ca1cee24dc6/src/relay/qnn/util.cc#L79-L133 There was also a fair amount of discussion which has been summarized here - #3591 (comment). This should also be helpful. This rounding can make small variations in the output, because of which we can not do exact check. I have measured accuracy on TFLite quantized Inception V1-V3 on 50K images using QNN, and observed accuracy parity with TFLite accuracy native execution. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the information! @anijain2305 However, if I remember correctly, when I review the code of Alright, if we change the default value later, In my opinion, I think we should pass
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TFlite rounding is not exactly formal
The key idea is that the above two functions map to ARM assembly instructions. But, in this process, they perform approximation that make it little different from formal GNU C rounding. The reasoning behind doing that should be to get performance. In QNN, we decided to follow GNU C rounding, as thats more formal and can work across more frameworks (like MxNet, PyTorch and not just TFLite). However, if one wants, one can implement a TFLite rounding. The code is modular enough to add one more rounding option. I think you have good comfortability with TFLite, so I would encourage you to write TFLite rounding if you get time :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I fully understand your point across more frameworks. My previous point is to emphasize we could provide one option for our TFLite frontend parser so that we could keep the compatibility with TFLite, we don't need to change any default value in our existing code, which could be used for MXNet, PyTorch and so on. Otherwise, the users will confuse why our result is not the same as TFLite when they use TVM. Keep the compatibility with TFLite should be the first priority when we parse TFLite model. So I think we shouldn't leave this to implement by TVM users. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. However, as you said, it doesn't affect the code structure. If you agree the point of compatibility point, I think we could make this PR in firstly and open one issue reminding us we should do this thing. If you don't have spare time, I could help to manage it next. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we can open an issue to track the TFLite rounding. It would not affect code structure. We can set the value of the rounding in TFLite parser, hiding it from the TVM users. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK. Could you help to open this this to track it? And you could assign this task to me. |
||
|
||
####################################################################### | ||
# SSD Mobilenet | ||
# ------------- | ||
|
@@ -1048,3 +1088,8 @@ def test_forward_ssd_mobilenet_v1(): | |
test_forward_inception_v3_net() | ||
test_forward_inception_v4_net() | ||
test_forward_ssd_mobilenet_v1() | ||
|
||
# End to End quantized | ||
# TODO - MobilenetV2 fails for now. Remove when fixed. | ||
test_forward_qnn_inception_v1_net() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let us add Mobilenet V2 quant model if this PR could run it. Because it is very popular and often compared across different framework. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense. Till now I have been focussing on the server side, and thus I used Inception. Can this PR go in first and I can send a PR for mobilenet V2 separately? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have no problem. However, I have one question that do we have any TODO for Mobilenet V2 quant model? From my knowledge, we should work well on Mobilenet V2 quant model using this PR and just add test code of Mobilenet V2 quant model. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think there is. I want to do a full 50k image test before adding MobileNetV2. I have done that for InceptionV1 and ensured that its accuracy is same as that provided in TFLite info. I will setup the pre-processing for mobilenet next week, and check end-to-end accuracy, and send a separate PR. What do you say? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we don't need to do it you mention. Because we are parsing TFLite quant model and TFLite framework is one baseline. We could treat TFLite quant model like TFLite FP32 model and compare the result between us and TFLite framework. They are no difference. I think we only need to do end-to-end accuracy when we do quantization in TVM inside or we have strong reason to implement ops but don't keep compatibility with TFLite framework. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added MobileNet V1. MobilenetV2 fails with a segfault in LLVM. Will debug. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the detail information of LLVM report? Try the latest? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the whole stack trace. Happens while the graph runtime create. Relay build has already passed.
|
||
test_forward_qnn_mobilenet_v1_net() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seems generates float output for softmax, would it be better if we quantize it back to uint8 since the semantic of a uint8 model requires an uint8 output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That would indeed be better. Doing that.