[TFLite] Enable int64 biases for int16 quantized operators (apache#12042)

This enables int64 biases for the quantized fully connected, requantize
and transpose convolution operators in TFLite networks. It builds on the
existing int16 support in the TFLite frontend.

Adds a test case using an int16-quantized DS_CNN keyword-spotting model.
leandron authored Nov 15, 2022
1 parent 647be2b commit 034dc67
Showing 7 changed files with 329 additions and 224 deletions.
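Background on the dtype choice: in TFLite's 16x8 quantization scheme, activations are int16, weights are int8, and the bias is stored in the accumulator dtype so it can be added without a cast. A rough back-of-the-envelope illustration (plain NumPy, not TVM code) of why an int32 accumulator is too narrow for that combination:

import numpy as np

# Worst-case magnitude of a single int16 activation * int8 weight product.
max_product = np.int64(np.iinfo(np.int16).max) * np.int64(np.iinfo(np.int8).max)

# How many such products an int32 accumulator can absorb before overflowing.
headroom = np.iinfo(np.int32).max // max_product

print(max_product, headroom)  # 4161409 516 -> only ~516 worst-case terms fit

Fully connected and convolution reductions can easily sum over more elements than that, so with int16 activations the accumulator, and with it the bias, is widened to int64. That is the combination the frontend and QNN type relations below are taught to accept.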
6 changes: 3 additions & 3 deletions python/tvm/relay/frontend/tflite.py
@@ -1966,7 +1966,7 @@ def convert_fully_connected(self, op):
input_scale=input_tensor.qnn_params["scale"],
kernel_scale=weight_tensor.qnn_params["scale"],
units=weight_shape[0],
out_dtype="int32",
out_dtype="int64" if output_tensor_type_str == "int16" else "int32",
)
else:
out = _op.nn.dense(in_expr, weight_expr, units=weight_shape[0])
@@ -1977,7 +1977,7 @@ def convert_fully_connected(self, op):
if bias_tensor.tensor_idx != -1:
bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32/INT64 (quantization) or FLOAT32
assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
if self.has_expr(bias_tensor.tensor_idx):
bias_expr = self.get_expr(bias_tensor.tensor_idx)
@@ -3175,7 +3175,7 @@ def convert_transpose_conv(self, op):
bias_tensor = input_tensors[3]
bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32/INT64 (quantization) or FLOAT32
assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
if self.has_expr(bias_tensor.tensor_idx):
bias_expr = self.get_expr(bias_tensor.tensor_idx)
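A minimal standalone sketch of the dtype pairing the tflite.py changes above encode. The helper below is hypothetical (it is not part of the frontend); it only mirrors the selection logic: float32 graphs keep float32 biases, int8/uint8 quantized graphs keep int32 biases, and int16 quantized graphs now carry int64 biases to match the int64 out_dtype requested from qnn.dense.

def expected_fc_dtypes(output_tensor_type_str):
    # Hypothetical helper mirroring convert_fully_connected: returns the
    # (accumulator out_dtype, bias dtype) pair for a given output tensor type.
    if output_tensor_type_str == "float32":
        return ("float32", "float32")
    if output_tensor_type_str == "int16":
        return ("int64", "int64")
    return ("int32", "int32")


assert expected_fc_dtypes("int8") == ("int32", "int32")
assert expected_fc_dtypes("int16") == ("int64", "int64")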
10 changes: 6 additions & 4 deletions src/relay/qnn/op/convolution_transpose.cc
@@ -93,12 +93,14 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
if (data == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<Conv2DTransposeAttrs>();
ICHECK(param != nullptr) << "Conv2DTransposeAttrs cannot be nullptr.";
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
<< "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16))
<< "Expected qnn conv2d type(int8, uint8, int16) for input but was " << data->dtype;
ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
<< "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32))
<< "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) ||
param->out_dtype == DataType::Int(64))
<< "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype;
ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

// Check the types of scale and zero points.
10 changes: 6 additions & 4 deletions src/relay/qnn/op/dense.cc
@@ -47,12 +47,14 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (data == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<DenseAttrs>();
ICHECK(param != nullptr) << "DenseAttrs cannot be nullptr.";
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
<< "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16))
<< "Expected quantized dense type(int8, uint8, int16, uint16) for input but was "
<< data->dtype;
ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
<< "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype;
ICHECK(param->out_dtype == DataType::Int(32))
<< "Expected quantized dense type(int32) for output but was " << param->out_dtype;
ICHECK(param->out_dtype == DataType::Int(32) || param->out_dtype == DataType::Int(64))
<< "Expected quantized dense type(int32, int64) for output but was " << param->out_dtype;

// Check the types of scale and zero points.
for (size_t i = 2; i < 5; ++i) {
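With the relaxed QnnDenseRel, the int16/int64 combination can be checked from Python by building the expression and running type inference. A minimal sketch, assuming the usual relay.qnn.op.dense keyword arguments and using illustrative scale and zero-point values:

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 490), dtype="int16")
weight = relay.var("weight", shape=(64, 490), dtype="int8")

# int16 activations accumulated into int64, as the TFLite frontend now emits.
dense = relay.qnn.op.dense(
    data,
    weight,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(0.05, "float32"),
    kernel_scale=relay.const(0.01, "float32"),
    units=64,
    out_dtype="int64",
)

func = relay.Function(relay.analysis.free_vars(dense), dense)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
print(mod)  # the inferred result type is Tensor[(1, 64), int64]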
5 changes: 3 additions & 2 deletions src/relay/qnn/op/requantize.cc
@@ -480,8 +480,9 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
const auto in_dtype = data->dtype;
ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
in_dtype == DataType::Int(32) || in_dtype == DataType::Int(64))
<< "Input type should be one of [int8, uint8, int32, int64] but was " << in_dtype;
in_dtype == DataType::Int(16) || in_dtype == DataType::Int(32) ||
in_dtype == DataType::Int(64))
<< "Input type should be one of [int8, uint8, int16, int32, int64] but was " << in_dtype;

const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
int axis = requantize_attrs->axis;
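The requantize relaxation admits int16 tensors as input, so an int16 activation can be rescaled between the quantization parameters of adjacent operators. A short sketch under the same assumptions as above, with illustrative scales and zero points:

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 64), dtype="int16")

# Rescale an int16 tensor from one (scale, zero point) pair to another.
out = relay.qnn.op.requantize(
    x,
    input_scale=relay.const(0.05, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.1, "float32"),
    output_zero_point=relay.const(0, "int32"),
    out_dtype="int16",
)

func = relay.Function(relay.analysis.free_vars(out), out)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
print(mod)  # the inferred result type is Tensor[(1, 64), int16]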
4 changes: 2 additions & 2 deletions tests/python/contrib/test_ethosn/test_convert_equivalents.py
@@ -227,7 +227,7 @@ def expected():
@requires_ethosn
@pytest.mark.parametrize(
"dtype,shape,constant_shape",
[("int16", (1, 16, 12, 4), None)],
[("float32", (1, 16, 12, 4), None)],
)
def test_unsupported_multiply_to_reinterpret_quantize(dtype, shape, constant_shape):
"""
@@ -445,7 +445,7 @@ def expected():
@pytest.mark.parametrize(
"dtype,shape,constant_shape",
[
("int16", (1, 16, 12, 4), None),
("float32", (1, 16, 12, 4), None),
],
)
def test_unsupported_add_to_reinterpret_quantize(dtype, shape, constant_shape):
23 changes: 23 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
@@ -4878,6 +4878,28 @@ def representative_dataset():
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)


def test_forward_ds_cnn_int16():
"""Test DS_CNN int16 quantized model"""
tflite_model_file = download_testdata(
"https://github.com/ARM-software/ML-zoo/blob/48f458af1e9065d9aad2ad94d24b58d6e7c00817/"
"models/keyword_spotting/ds_cnn_small/tflite_int16/ds_cnn_quantized.tflite?raw=true",
"ds_cnn_quantized_int16.tflite",
)

with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()

data = np.random.uniform(size=(1, 490)).astype("int16")

tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, data, "serving_default_input:0")
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)


#######################################################################
# Unidirectional Sequence LSTM
# ---------------------
@@ -5250,3 +5272,4 @@ def test_forward_nms_v5():
test_forward_tflite_float16()

test_forward_tflite_int16()
test_forward_ds_cnn_int16()