[TFLite] Enable int64 biases for int16 quantized operators
This enables int64 biases for the quantized fully connected, requantize,
and transpose convolution operators in TFLite networks. It builds on the
existing int16 support in the TFLite frontend.
leandron committed Jul 8, 2022
1 parent c412450 commit 7eb64a3
Showing 4 changed files with 18 additions and 13 deletions.
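For context (not part of the commit itself): after these changes, an int16 fully connected layer from a TFLite 16x8-quantized model (int16 activations, int8 weights) lowers to qnn.dense with an int64 accumulator and bias, requantized back to int16. A minimal Relay sketch, with made-up shapes, scales, and zero points:

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 16), dtype="int16")
weight = relay.const(np.zeros((8, 16), dtype="int8"))
bias = relay.const(np.zeros((8,), dtype="int64"))  # int64 bias, newly accepted

dense = relay.qnn.op.dense(
    data,
    weight,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(0.005, "float32"),
    kernel_scale=relay.const(0.02, "float32"),
    units=8,
    out_dtype="int64",  # previously hard-wired to "int32"
)
biased = relay.nn.bias_add(dense, bias)
out = relay.qnn.op.requantize(
    biased,
    input_scale=relay.const(0.0001, "float32"),  # input_scale * kernel_scale
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.01, "float32"),
    output_zero_point=relay.const(0, "int32"),
    out_dtype="int16",
)
mod = tvm.IRModule.from_expr(relay.Function([data], out))
mod = relay.transform.InferType()(mod)  # exercises the relaxed QnnDenseRel/RequantizeRel checks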
6 changes: 3 additions & 3 deletions python/tvm/relay/frontend/tflite.py
@@ -1939,7 +1939,7 @@ def convert_fully_connected(self, op):
                 input_scale=input_tensor.qnn_params["scale"],
                 kernel_scale=weight_tensor.qnn_params["scale"],
                 units=weight_shape[0],
-                out_dtype="int32",
+                out_dtype="int64" if output_tensor_type_str == "int16" else "int32",
             )
         else:
             out = _op.nn.dense(in_expr, weight_expr, units=weight_shape[0])
@@ -1950,7 +1950,7 @@ def convert_fully_connected(self, op):
         if bias_tensor.tensor_idx != -1:
             bias_tensor_type = bias_tensor.tensor.Type()
             # bias tensor type should be INT32 (quantization) or FLOAT32
-            assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
+            assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
             bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
             if self.has_expr(bias_tensor.tensor_idx):
                 bias_expr = self.get_expr(bias_tensor.tensor_idx)
@@ -3145,7 +3145,7 @@ def convert_transpose_conv(self, op):
             bias_tensor = input_tensors[3]
             bias_tensor_type = bias_tensor.tensor.Type()
             # bias tensor type should be INT32 (quantization) or FLOAT32
-            assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
+            assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
             bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
             if self.has_expr(bias_tensor.tensor_idx):
                 bias_expr = self.get_expr(bias_tensor.tensor_idx)
10 changes: 6 additions & 4 deletions src/relay/qnn/op/convolution_transpose.cc
@@ -93,12 +93,14 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
   if (data == nullptr || weight == nullptr) return false;
   const auto* param = attrs.as<Conv2DTransposeAttrs>();
   ICHECK(param != nullptr) << "Conv2DTransposeAttrs cannot be nullptr.";
-  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
-      << "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
+  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
+         data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16))
+      << "Expected qnn conv2d type(int8, uint8, int16, uint16) for input but was " << data->dtype;
   ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
       << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
-  ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32))
-      << "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
+  ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) ||
+         param->out_dtype == DataType::Int(64))
+      << "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype;
   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

   // Check the types of scale and zero points.
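A similar sketch (illustrative, not from the diff) for the transpose convolution path, with int16 activations, int8 weights, and an int64 accumulator, assuming the default NCHW data layout and IOHW kernel layout of qnn.conv2d_transpose:

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 4, 8, 8), dtype="int16")  # NCHW
weight = relay.const(np.zeros((4, 8, 3, 3), dtype="int8"))   # IOHW

deconv = relay.qnn.op.conv2d_transpose(
    data,
    weight,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(0.005, "float32"),
    kernel_scale=relay.const(0.02, "float32"),
    kernel_size=(3, 3),
    channels=8,
    out_dtype="int64",  # rejected by QnnConv2DTransposeRel before this change
)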
10 changes: 6 additions & 4 deletions src/relay/qnn/op/dense.cc
@@ -47,12 +47,14 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (data == nullptr || weight == nullptr) return false;
   const auto* param = attrs.as<DenseAttrs>();
   ICHECK(param != nullptr) << "DenseAttrs cannot be nullptr.";
-  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
-      << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
+  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
+         data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16))
+      << "Expected quantized dense type(int8, uint8, int16, uint16) for input but was "
+      << data->dtype;
   ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
       << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype;
-  ICHECK(param->out_dtype == DataType::Int(32))
-      << "Expected quantized dense type(int32) for output but was " << param->out_dtype;
+  ICHECK(param->out_dtype == DataType::Int(32) || param->out_dtype == DataType::Int(64))
+      << "Expected quantized dense type(int32, int64) for output but was " << param->out_dtype;

   // Check the types of scale and zero points.
   for (size_t i = 2; i < 5; ++i) {
5 changes: 3 additions & 2 deletions src/relay/qnn/op/requantize.cc
@@ -480,8 +480,9 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }
   const auto in_dtype = data->dtype;
-  ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
-         in_dtype == DataType::Int(32) || in_dtype == DataType::Int(64))
-      << "Input type should be one of [int8, uint8, int32, int64] but was " << in_dtype;
+  ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
+         in_dtype == DataType::Int(16) || in_dtype == DataType::Int(32) ||
+         in_dtype == DataType::Int(64))
+      << "Input type should be one of [int8, uint8, int16, int32, int64] but was " << in_dtype;

   const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
   int axis = requantize_attrs->axis;
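And the requantize change in isolation: int16 is now a legal input type, e.g. when rescaling int16 activations between operators. A hedged sketch with made-up scales:

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 64), dtype="int16")
y = relay.qnn.op.requantize(
    x,
    input_scale=relay.const(0.005, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.01, "float32"),
    output_zero_point=relay.const(0, "int32"),
    out_dtype="int16",  # int16 -> int16 rescale, rejected before this change
)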
