Skip to content

Commit

Permalink
[FRONTEND][TFLITE][BugFix] Fix int16 transpose conv loading (apache#1…
Browse files Browse the repository at this point in the history
…5173)

Loading an int16 transpose convolution op from a TFLite model currently
fails because the output type is not int64.

This patch adjusts output type to int64 for int16 quantized
transpose convolution operation. In addition, one typo in
QnnConv2DTransposeRel is fixed.

A test is also included to exercise loading
of the int16 quantized transpose convolution op.

Co-authored-by: Wooseok <skyeyews@gmail.com>
  • Loading branch information
wooseok-cadence and Wooseok authored Jun 29, 2023
1 parent 99d72fd commit 22e592b
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 2 deletions.
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3300,6 +3300,7 @@ def convert_transpose_conv(self, op):
kernel_zero_point = weights_tensor.qnn_params["zero_point"]
input_scale = input_tensor.qnn_params["scale"]
kernel_scale = weights_tensor.qnn_params["scale"]
out_dtype = "int64" if output_tensor_type_str == "int16" else "int32"
out = _qnn.op.conv2d_transpose(
in_expr,
weight_expr_iohw,
Expand All @@ -3313,7 +3314,7 @@ def convert_transpose_conv(self, op):
kernel_size=(int(kernel_h), int(kernel_w)),
data_layout="NHWC",
kernel_layout="IOHW",
out_dtype="int32",
out_dtype=out_dtype,
)
else:
out = _op.nn.conv2d_transpose(
Expand Down
2 changes: 1 addition & 1 deletion src/relay/qnn/op/convolution_transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
<< "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) ||
data->dtype == DataType::Int(64))
param->out_dtype == DataType::Int(64))
<< "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype;
ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

Expand Down
80 changes: 80 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,86 @@ def test_forward_transpose_conv():
)


def _test_tflite2_quantized_transpose_conv(
    input_shape,
    kernel_shape,
    filters,
    padding="valid",
    strides=(1, 1),
    data_format=None,
    int_quant_dtype=tf.int8,
):
    """One iteration of TFLite2 quantized transpose conv with given shapes and attributes.

    Builds a single-layer Keras ``Conv2DTranspose`` model, quantizes it through
    ``_quantize_keras_model`` (float input and output) using the random input
    itself as the representative dataset, and asserts that the TVM result
    matches the TFLite result within ``rtol=1e-2`` / ``atol=1e-2``.

    Parameters
    ----------
    input_shape : tuple of int
        Shape of the random input, including the leading batch dimension.
        The Keras input layer is always created with ``batch_size=1``.
    kernel_shape : tuple of int
        Only the first two entries are used, as the layer's ``kernel_size``.
    filters : int
        Number of output filters of the transpose convolution.
    padding : str
        Forwarded to ``Conv2DTranspose``.
    strides : tuple of int
        Forwarded to ``Conv2DTranspose``.
    data_format : str, optional
        Accepted so existing call sites keep working, but it is not forwarded
        to the layer; the layer is built with Keras' default data layout.
    int_quant_dtype : tf.DType
        Forwarded to ``_quantize_keras_model``.
    """
    del data_format  # Intentionally unused: never forwarded to the layer (see docstring).

    data = np.random.uniform(0, 1, input_shape).astype("float32")

    data_in = tf.keras.layers.Input(shape=data.shape[1:], batch_size=1)
    transpose_conv = tf.keras.layers.Conv2DTranspose(
        filters=filters,
        kernel_size=(kernel_shape[0], kernel_shape[1]),
        padding=padding,
        strides=strides,
        use_bias=True,
    )(data_in)
    keras_model = tf.keras.models.Model(data_in, transpose_conv)

    # To create quantized values with dynamic range of activations, needs representative dataset
    def representative_data_gen():
        yield [data]

    tflite_model_quant = _quantize_keras_model(
        keras_model,
        representative_data_gen,
        is_float_input=True,
        is_float_output=True,
        int_quant_dtype=int_quant_dtype,
    )

    tflite_output = run_tflite_graph(tflite_model_quant, data)

    # The name under which the converted model exposes its input depends on the
    # TensorFlow version, so derive it from the Keras input tensor accordingly.
    if tf.__version__ < LooseVersion("2.9"):
        input_node = data_in.name.replace(":0", "")
    else:
        input_node = "serving_default_" + data_in.name + ":0"

    tvm_output = run_tvm_graph(tflite_model_quant, data, input_node)
    tvm.testing.assert_allclose(
        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2
    )


def test_forward_quantized_transpose_conv():
    """Run the quantized transpose convolution check for int8 and int16 quantization."""
    input_shape = (1, 1, 5, 64)
    kernel_shape = (3, 3)
    num_filters = 64

    for quant_dtype in [tf.int8, tf.int16]:
        _test_tflite2_quantized_transpose_conv(
            input_shape,
            kernel_shape,
            num_filters,
            padding="same",
            strides=(1, 2),
            data_format="NHWC",
            int_quant_dtype=quant_dtype,
        )


#######################################################################
# Reshape
# -------
Expand Down

0 comments on commit 22e592b

Please sign in to comment.