
Commit e518df2

anijain2305 authored and alexwong committed

[TFLite] Using real image for QNN testing. (apache#4816)
* [TFLite] Using real image for QNN testing.
* Setting seed for SSD mobilenet for fixed input.
* Support quantized Pad op.
* Remove unnecessary line.
* Incorporate comments.
1 parent eb07246 commit e518df2

2 files changed: +72, -19 lines

python/tvm/relay/frontend/tflite.py

Lines changed: 18 additions & 3 deletions
@@ -1179,10 +1179,14 @@ def convert_conv(self, op, conv_type):
             pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
             do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0)
             if do_pad:
+                pad_value = 0
+                if input_tensor.qnn_params:
+                    pad_value = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])
                 in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
                                                               (pad_top, pad_bottom),
                                                               (pad_left, pad_right),
-                                                              (0, 0)))
+                                                              (0, 0)), pad_value=float(pad_value))
+
         else:
             raise tvm.error.OpAttributeUnImplemented(
                 'Padding format {} is not supported for operator Conv.'.format(padding))
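The convolution path pads with the zero point rather than 0 because, in the quantized domain, the zero point is the encoding of real 0.0. A minimal numpy sketch of that equivalence, using made-up scale/zero-point values that are not taken from this commit:

import numpy as np

# Hypothetical affine quantization parameters, for illustration only.
scale, zero_point = 0.5, 128

# A tiny quantized (uint8) feature map.
q = np.array([[130, 134],
              [126, 122]], dtype=np.uint8)

# Padding with the zero point in the quantized domain ...
q_padded = np.pad(q, pad_width=1, mode='constant', constant_values=zero_point)

# ... dequantizes to a border of exactly 0.0, matching float convolution padding.
deq = scale * (q_padded.astype(np.int32) - zero_point)
print(deq)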
@@ -1476,8 +1480,19 @@ def convert_pad(self, op):
         # convert list of lists to tuple of tuples
         paddings = tuple(tuple(l) for l in pad_list)

-        # Use default pad_value 0 because TFLite PAD does not support constant_values parameter
-        out = _op.nn.pad(in_expr, paddings)
+        # Set the pad value
+        pad_value = 0
+        if input_tensor.qnn_params:
+            # Check that input and output tensor have same qnn params.
+            output_tensors = self.get_output_tensors(op)
+            output_tensor = output_tensors[0]
+            assert self.has_same_qnn_params(input_tensor, output_tensor), \
+                    "TFLite PAD requires input and output scale and zero points to be equal"
+
+            # The pad value for quantized pad is the input zero point.
+            pad_value = float(input_tensor.qnn_params['zero_point'].data.asnumpy())
+
+        out = _op.nn.pad(in_expr, pad_width=paddings, pad_value=pad_value)
         return out

     def convert_mirror_pad(self, op):
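For context, a minimal standalone sketch of the kind of Relay expression a quantized PAD lowers to, built with the public relay.nn.pad API and a scalar pad_value. This is not the frontend code itself; the zero point here is a made-up value, whereas the frontend reads it from the tensor's qnn_params:

from tvm import relay

# Hypothetical input zero point; in the frontend it comes from qnn_params['zero_point'].
zero_point = 128

data = relay.var("data", shape=(1, 2, 2, 3), dtype="uint8")
out = relay.nn.pad(data,
                   pad_width=((0, 0), (1, 1), (1, 1), (0, 0)),
                   pad_value=float(zero_point))

# Print the resulting Relay function for inspection.
func = relay.Function([data], out)
print(func)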

tests/python/frontend/tflite/test_forward.py

Lines changed: 54 additions & 16 deletions
@@ -42,6 +42,9 @@
 import tvm.relay.testing.tf as tf_testing
 from packaging import version as package_version

+from PIL import Image
+import os
+
 #######################################################################
 # Generic run functions for TVM & TFLite
 # --------------------------------------
@@ -50,6 +53,20 @@ def convert_to_list(x):
         x = [x]
     return x

+
+#######################################################################
+# Get a real image for e2e testing.
+# --------------------------------------
+def get_real_image(im_height, im_width):
+    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
+    img_name = 'elephant-299.jpg'
+    image_url = os.path.join(repo_base, img_name)
+    img_path = download_testdata(image_url, img_name, module='data')
+    image = Image.open(img_path).resize((im_height, im_width))
+    x = np.array(image).astype('uint8')
+    data = np.reshape(x, (1, im_height, im_width, 3))
+    return data
+
 def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
                   out_names=None):
     """ Generic function to compile on relay and execute on tvm """
@@ -1139,16 +1156,28 @@ def test_forward_squeeze():
 # Pad
 # ---

-def _test_pad(data, mode="CONSTANT"):
+def _test_pad(data, mode="CONSTANT", quantized=False):
     """ One iteration of PAD """

     assert len(data) == 2

     # Test with tensor and constant
     with tf.Graph().as_default():
-        in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
-        out = array_ops.pad(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
-        compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
+        in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in')]
+
+        if quantized:
+            # fake_quant will keep the tensors in float32 until the conversion in the session
+            input_range = {'inq_0': (-100, 100)}
+            inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0],
+                                                                     min=-100,
+                                                                     max=100,
+                                                                     name="inq_0")]
+            out = array_ops.pad(inq_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
+            compare_tflite_with_tvm([data[0]], ['inq_0:0'], inq_data, [out], quantized=True,
+                                    input_range=input_range)
+        else:
+            out = array_ops.pad(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
+            compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])


 def test_forward_pad():
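As a rough sketch of what the (-100, 100) fake-quant range above implies for the converted uint8 model, assuming standard asymmetric quantization (TFLite may nudge the exact scale and zero point slightly). The resulting zero point is what the frontend's quantized PAD then uses as its fill value:

import numpy as np

qmin, qmax = 0, 255
fmin, fmax = -100.0, 100.0

scale = (fmax - fmin) / (qmax - qmin)         # ~0.784
zero_point = int(round(qmin - fmin / scale))  # ~128

# Quantize a few sample float values into the uint8 domain.
x = np.array([-100.0, 0.0, 99.2])
q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.uint8)
print(scale, zero_point, q)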
@@ -1165,6 +1194,8 @@ def test_forward_pad():
                np.array([[1, 1], [2, 2]], dtype=np.int32)], mode="REFLECT")
     _test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
                np.array([[1, 1], [2, 2]], dtype=np.int32)], mode="SYMMETRIC")
+    _test_pad([np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
+               np.array([[1, 1], [2, 2]], dtype=np.int32)], quantized=True)


 #######################################################################
@@ -1425,10 +1456,12 @@ def test_forward_qnn_inception_v1_net():
                                         "inception_v1_224_quant.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    # Checking the labels because the requantize implementation is different between TFLite and
-    # Relay. This cause final output numbers to mismatch. So, testing accuracy via labels.
-    np.random.seed(0)
-    data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8')
+
+    # Test image. Checking the labels because the requantize implementation is different between
+    # TFLite and Relay. This causes the final output numbers to mismatch. So, testing accuracy via
+    # labels. Also, giving a real image instead of random inputs.
+    data = get_real_image(224, 224)
+
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
     tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
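A small sketch of the label-based check these QNN tests perform, with made-up prediction vectors standing in for the real model outputs: the top-3 class indices are compared instead of raw scores, which tolerates the requantize differences mentioned in the comment above.

import numpy as np

# Hypothetical TFLite and TVM outputs whose raw values differ slightly.
tflite_predictions = np.array([0.02, 0.71, 0.05, 0.15, 0.07])
tvm_predictions = np.array([0.03, 0.69, 0.04, 0.16, 0.08])

tflite_top3 = tflite_predictions.argsort()[-3:][::-1]
tvm_top3 = tvm_predictions.argsort()[-3:][::-1]

# The top-3 labels (indices) agree even though the raw numbers do not match exactly.
assert list(tflite_top3) == list(tvm_top3)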
@@ -1445,10 +1478,12 @@ def test_forward_qnn_mobilenet_v1_net():
                                         "mobilenet_v1_1.0_224_quant.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    # Checking the labels because the requantize implementation is different between TFLite and
-    # Relay. This cause final output numbers to mismatch. So, testing accuracy via labels.
-    np.random.seed(0)
-    data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8')
+
+    # Test image. Checking the labels because the requantize implementation is different between
+    # TFLite and Relay. This causes the final output numbers to mismatch. So, testing accuracy via
+    # labels. Also, giving a real image instead of random inputs.
+    data = get_real_image(224, 224)
+
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
     tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
@@ -1465,10 +1500,12 @@ def test_forward_qnn_mobilenet_v2_net():
                                         "mobilenet_v2_1.0_224_quant.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    # Checking the labels because the requantize implementation is different between TFLite and
-    # Relay. This cause final output numbers to mismatch. So, testing accuracy via labels.
-    np.random.seed(0)
-    data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8')
+
+    # Test image. Checking the labels because the requantize implementation is different between
+    # TFLite and Relay. This causes the final output numbers to mismatch. So, testing accuracy via
+    # labels. Also, giving a real image instead of random inputs.
+    data = get_real_image(224, 224)
+
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
     tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
@@ -1489,6 +1526,7 @@ def test_forward_ssd_mobilenet_v1():
                                         "ssd_mobilenet_v1_coco_2018_01_28_nopp.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
+    np.random.seed(0)
     data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32')
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=2)
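A quick sketch of why the added np.random.seed(0) matters here: with a fixed seed the random SSD input is bit-identical across runs, so TFLite and TVM are always compared on the same data.

import numpy as np

np.random.seed(0)
a = np.random.uniform(size=(1, 300, 300, 3)).astype('float32')

np.random.seed(0)
b = np.random.uniform(size=(1, 300, 300, 3)).astype('float32')

# Re-seeding reproduces exactly the same tensor.
assert np.array_equal(a, b)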
