diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 1915eb9322ff..3d2f4a2f25e6 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -1966,7 +1966,7 @@ def convert_fully_connected(self, op):
                 input_scale=input_tensor.qnn_params["scale"],
                 kernel_scale=weight_tensor.qnn_params["scale"],
                 units=weight_shape[0],
-                out_dtype="int32",
+                out_dtype="int64" if output_tensor_type_str == "int16" else "int32",
             )
         else:
             out = _op.nn.dense(in_expr, weight_expr, units=weight_shape[0])
@@ -1977,7 +1977,7 @@ def convert_fully_connected(self, op):
        if bias_tensor.tensor_idx != -1:
            bias_tensor_type = bias_tensor.tensor.Type()
-            # bias tensor type should be INT32 (quantization) or FLOAT32
-            assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
+            # bias tensor type should be INT32/INT64 (quantization) or FLOAT32
+            assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
            bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
            if self.has_expr(bias_tensor.tensor_idx):
                bias_expr = self.get_expr(bias_tensor.tensor_idx)
@@ -3175,7 +3175,7 @@ def convert_transpose_conv(self, op):
            bias_tensor = input_tensors[3]
            bias_tensor_type = bias_tensor.tensor.Type()
-            # bias tensor type should be INT32 (quantization) or FLOAT32
-            assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
+            # bias tensor type should be INT32/INT64 (quantization) or FLOAT32
+            assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
            bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
            if self.has_expr(bias_tensor.tensor_idx):
                bias_expr = self.get_expr(bias_tensor.tensor_idx)
diff --git a/src/relay/qnn/op/convolution_transpose.cc b/src/relay/qnn/op/convolution_transpose.cc
index 6163e1c20429..951c1bdfb051 100644
--- a/src/relay/qnn/op/convolution_transpose.cc
+++ b/src/relay/qnn/op/convolution_transpose.cc
@@ -93,12 +93,14 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
   if (data == nullptr || weight == nullptr) return false;
   const auto* param = attrs.as<Conv2DTransposeAttrs>();
   ICHECK(param != nullptr) << "Conv2DTransposeAttrs cannot be nullptr.";
-  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
-      << "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
+  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
+         data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16))
+      << "Expected qnn conv2d type(int8, uint8, int16, uint16) for input but was " << data->dtype;
   ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
       << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
-  ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32))
-      << "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
+  ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) ||
+         param->out_dtype == DataType::Int(64))
+      << "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype;
   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

   // Check the types of scale and zero points.
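A note on why the frontend switches to an int64 accumulator for int16 activations (back-of-the-envelope arithmetic, not part of the patch): an int16 x int8 product needs up to 22 bits, so summing a few hundred of them along the reduction axis can already overflow a signed 32-bit accumulator, and TFLite's 16x8 quantization scheme likewise stores biases as int64, which is what the INT64 bias tensors accepted above correspond to. A minimal sketch of the bound; the accum_bits helper is hypothetical, written here only to illustrate:

    import math

    def accum_bits(input_bits, weight_bits, reduce_len):
        # Worst-case magnitude of one product, summed over `reduce_len`
        # terms of the dot product.
        worst = (2 ** (input_bits - 1)) * (2 ** (weight_bits - 1)) * reduce_len
        return 1 + math.ceil(math.log2(worst))  # +1 for the sign bit

    assert accum_bits(8, 8, 4096) <= 32   # int8 x int8 fits in int32
    assert accum_bits(16, 8, 1024) > 32   # int16 x int8 needs int64 here
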
diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index adaf509e7daf..09d51e3c9ce7 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -47,12 +47,14 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (data == nullptr || weight == nullptr) return false;
   const auto* param = attrs.as<DenseAttrs>();
   ICHECK(param != nullptr) << "DenseAttrs cannot be nullptr.";
-  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
-      << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
+  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
+         data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16))
+      << "Expected quantized dense type(int8, uint8, int16, uint16) for input but was "
+      << data->dtype;
   ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
       << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype;
-  ICHECK(param->out_dtype == DataType::Int(32))
-      << "Expected quantized dense type(int32) for output but was " << param->out_dtype;
+  ICHECK(param->out_dtype == DataType::Int(32) || param->out_dtype == DataType::Int(64))
+      << "Expected quantized dense type(int32, int64) for output but was " << param->out_dtype;

   // Check the types of scale and zero points.
   for (size_t i = 2; i < 5; ++i) {
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index 1614652719c6..e199ea27f1e4 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -480,8 +480,9 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }
   const auto in_dtype = data->dtype;
   ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
-         in_dtype == DataType::Int(32) || in_dtype == DataType::Int(64))
-      << "Input type should be one of [int8, uint8, int32, int64] but was " << in_dtype;
+         in_dtype == DataType::Int(16) || in_dtype == DataType::Int(32) ||
+         in_dtype == DataType::Int(64))
+      << "Input type should be one of [int8, uint8, int16, int32, int64] but was " << in_dtype;

   const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
   int axis = requantize_attrs->axis;
diff --git a/tests/python/contrib/test_ethosn/test_convert_equivalents.py b/tests/python/contrib/test_ethosn/test_convert_equivalents.py
index 77777293729c..a3e48f4424ad 100644
--- a/tests/python/contrib/test_ethosn/test_convert_equivalents.py
+++ b/tests/python/contrib/test_ethosn/test_convert_equivalents.py
@@ -227,7 +227,7 @@ def expected():
 @requires_ethosn
 @pytest.mark.parametrize(
     "dtype,shape,constant_shape",
-    [("int16", (1, 16, 12, 4), None)],
+    [("float32", (1, 16, 12, 4), None)],
 )
 def test_unsupported_multiply_to_reinterpret_quantize(dtype, shape, constant_shape):
     """
@@ -445,7 +445,7 @@ def expected():
 @pytest.mark.parametrize(
     "dtype,shape,constant_shape",
     [
-        ("int16", (1, 16, 12, 4), None),
+        ("float32", (1, 16, 12, 4), None),
     ],
 )
 def test_unsupported_add_to_reinterpret_quantize(dtype, shape, constant_shape):
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 7b2bd60d8a20..877406ae2a64 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -4878,6 +4878,28 @@ def representative_dataset():
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)


+def test_forward_ds_cnn_int16():
+    """Test DS_CNN int16 quantized model"""
+    tflite_model_file = download_testdata(
+        "https://github.com/ARM-software/ML-zoo/blob/48f458af1e9065d9aad2ad94d24b58d6e7c00817/"
+        "models/keyword_spotting/ds_cnn_small/tflite_int16/ds_cnn_quantized.tflite?raw=true",
+        "ds_cnn_quantized_int16.tflite",
+    )
+
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+
+    data = np.random.uniform(size=(1, 490)).astype("int16")
+
+    tflite_output = run_tflite_graph(tflite_model_buf, data)
+    tflite_predictions = np.squeeze(tflite_output)
+    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "serving_default_input:0")
+    tvm_predictions = np.squeeze(tvm_output)
+    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
+
+
 #######################################################################
 # Unidirectional Sequence LSTM
 # ---------------------
@@ -5250,3 +5272,4 @@ def test_forward_nms_v5():

     test_forward_tflite_float16()
     test_forward_tflite_int16()
+    test_forward_ds_cnn_int16()
diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py
index 64306476dfe9..1dee1f5b619c 100644
--- a/tests/python/relay/test_op_qnn_requantize.py
+++ b/tests/python/relay/test_op_qnn_requantize.py
@@ -23,6 +23,7 @@

 roundings = ["UPWARD", "TONEAREST"]
 compute_dtypes = ["float32", "float64", "int64"]
+out_dtypes = ["int8", "int16"]


 def verify(mod, goldens, target="llvm"):
@@ -83,17 +84,18 @@ def test_same_scale():
     golden_output = golden_data
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(200,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=0.5,
-                output_scale=0.5,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
-            assert "right_shift" not in mod.astext()
-            verify(mod, (golden_data, golden_output))
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(200,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=0.5,
+                    output_scale=0.5,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )
+                assert "right_shift" not in mod.astext()
+                verify(mod, (golden_data, golden_output))


 def test_scalar_same_scale():
@@ -102,75 +104,77 @@ def test_scalar_same_scale():
     golden_output = golden_data
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=0.5,
-                output_scale=0.5,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
-            assert "right_shift" not in mod.astext()
-            verify(mod, (golden_data, golden_output))
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=0.5,
+                    output_scale=0.5,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )
+                assert "right_shift" not in mod.astext()
+                verify(mod, (golden_data, golden_output))


 def test_downscale():
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=1,
-                output_scale=16,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=1,
+                    output_scale=16,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )

-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype("int32")
-            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
+                # Try positive values
+                # 8 corresponds to 0.5, resulting in 1
+                golden_data = np.arange(0, 32, 1).astype("int32")
+                golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+                verify(mod, (golden_data, golden_output))

-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(0, -32, -1).astype("int32")
-            if rounding == "UPWARD":
-                golden_output = np.repeat([0, -1, -2], [9, 16, 7])
-            else:
-                golden_output = np.repeat([0, -1, -2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
+                # Try negative values
+                # -8 corresponds to -0.5. For UPWARD, this is 0
+                golden_data = np.arange(0, -32, -1).astype("int32")
+                if rounding == "UPWARD":
+                    golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+                else:
+                    golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+                verify(mod, (golden_data, golden_output))

-            # Try a different scale
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=1,
-                output_scale=4,
-                rounding=rounding,
-            )
+                # Try a different scale
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=1,
+                    output_scale=4,
+                    rounding=rounding,
+                )

-            # Try positive values
-            # 2I corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype("int32")
-            golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 4, 4, 4, 4, 4, 4, 2])
-            verify(mod, (golden_data, golden_output))
+                # Try positive values
+                # 2 corresponds to 0.5, resulting in 1
+                golden_data = np.arange(0, 32, 1).astype("int32")
+                golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 4, 4, 4, 4, 4, 4, 2])
+                verify(mod, (golden_data, golden_output))

-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(0, -32, -1).astype("int32")
-            if rounding == "UPWARD":
-                golden_output = np.repeat(
-                    [0, -1, -2, -3, -4, -5, -6, -7, -8], [3, 4, 4, 4, 4, 4, 4, 4, 1]
-                )
-            else:
-                golden_output = np.repeat(
-                    [0, -1, -2, -3, -4, -5, -6, -7, -8], [2, 4, 4, 4, 4, 4, 4, 4, 2]
-                )
-            verify(mod, (golden_data, golden_output))
+                # Try negative values
+                # -2 corresponds to -0.5. For UPWARD, this is 0
+                golden_data = np.arange(0, -32, -1).astype("int32")
+                if rounding == "UPWARD":
+                    golden_output = np.repeat(
+                        [0, -1, -2, -3, -4, -5, -6, -7, -8], [3, 4, 4, 4, 4, 4, 4, 4, 1]
+                    )
+                else:
+                    golden_output = np.repeat(
+                        [0, -1, -2, -3, -4, -5, -6, -7, -8], [2, 4, 4, 4, 4, 4, 4, 4, 2]
+                    )
+                verify(mod, (golden_data, golden_output))

             # Try uint8 out_dtype
             mod = get_mod(
@@ -208,74 +212,76 @@ def test_upscale():
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=2,
-                output_scale=1,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=2,
+                    output_scale=1,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )

-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype("int32")
-            golden_output = np.multiply(2, golden_data)
-            verify(mod, (golden_data, golden_output))
+                # Try positive values
+                # input_scale=2, output_scale=1 doubles each value
+                golden_data = np.arange(0, 32, 1).astype("int32")
+                golden_output = np.multiply(2, golden_data)
+                verify(mod, (golden_data, golden_output))

-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(0, -32, -1).astype("int32")
-            golden_output = np.multiply(2, golden_data)
-            verify(mod, (golden_data, golden_output))
+                # Try negative values
+                # the same doubling applies to negative values
+                golden_data = np.arange(0, -32, -1).astype("int32")
+                golden_output = np.multiply(2, golden_data)
+                verify(mod, (golden_data, golden_output))


 def test_non_power_of_two():
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=1,
-                output_scale=3,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=1,
+                    output_scale=3,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )

-            # Try positive values
-            golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3)
-            golden_output = np.arange(0, 32, 1)
-            verify(mod, (golden_data, golden_output))
+                # Try positive values
+                golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3)
+                golden_output = np.arange(0, 32, 1)
+                verify(mod, (golden_data, golden_output))

-            # Try negative values
-            golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3)
-            golden_output = np.arange(0, -32, -1)
-            verify(mod, (golden_data, golden_output))
+                # Try negative values
+                golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3)
+                golden_output = np.arange(0, -32, -1)
+                verify(mod, (golden_data, golden_output))

-            # Try a different scale
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=3,
-                output_scale=1,
-                rounding=rounding,
-            )
+                # Try a different scale
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=3,
+                    output_scale=1,
+                    rounding=rounding,
+                )

-            # Try positive values
-            golden_data = np.arange(0, 32, 1).astype("int32")
-            golden_output = np.multiply(golden_data, 3)
-            verify(mod, (golden_data, golden_output))
+                # Try positive values
+                golden_data = np.arange(0, 32, 1).astype("int32")
+                golden_output = np.multiply(golden_data, 3)
+                verify(mod, (golden_data, golden_output))

-            # Try negative values
-            golden_data = np.arange(0, -32, -1).astype("int32")
-            golden_output = np.multiply(golden_data, 3)
-            verify(mod, (golden_data, golden_output))
+                # Try negative values
+                golden_data = np.arange(0, -32, -1).astype("int32")
+                golden_output = np.multiply(golden_data, 3)
+                verify(mod, (golden_data, golden_output))


-def test_saturation():
+def test_saturation_int8():
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
             mod = get_mod(
@@ -322,6 +328,70 @@
             verify(mod, (golden_data, golden_output))


+def test_saturation_int16():
+    for compute_dtype in compute_dtypes:
+        for rounding in roundings:
+            mod = get_mod(
+                data_shape=(16,),
+                data_dtype="int32",
+                out_dtype="int16",
+                input_scale=0.5,
+                output_scale=0.5,
+                rounding=rounding,
+                compute_dtype=compute_dtype,
+            )
+            golden_data = np.arange(0, 16, 1).astype("int32")
+            golden_data = np.add(32760, golden_data)
+            output = np.array(
+                [
+                    32760,
+                    32761,
+                    32762,
+                    32763,
+                    32764,
+                    32765,
+                    32766,
+                    32767,
+                    32767,
+                    32767,
+                    32767,
+                    32767,
+                    32767,
+                    32767,
+                    32767,
+                    32767,
+                ]
+            )
+            golden_output = output
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative numbers
+            golden_data = np.arange(0, -16, -1).astype("int32")
+            golden_data = np.add(-32760, golden_data)
+            output = np.array(
+                [
+                    -32760,
+                    -32761,
+                    -32762,
+                    -32763,
+                    -32764,
+                    -32765,
+                    -32766,
+                    -32767,
+                    -32768,
+                    -32768,
+                    -32768,
+                    -32768,
+                    -32768,
+                    -32768,
+                    -32768,
+                    -32768,
+                ]
+            )
+            golden_output = output
+            verify(mod, (golden_data, golden_output))
+
+
 def test_zero_point():
     # Output zero point
     for compute_dtype in compute_dtypes:
@@ -357,31 +427,32 @@
     # Input zero point
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=1,
-                output_scale=16,
-                input_zero_point=16,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=1,
+                    output_scale=16,
+                    input_zero_point=16,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )

-            # Try positive values
-            golden_data = np.arange(32, 64, 1).astype("int32")
-            golden_output = np.repeat([2, 3, 4], [8, 16, 8])
-            golden_output = np.subtract(golden_output, 1)
-            verify(mod, (golden_data, golden_output))
+                # Try positive values
+                golden_data = np.arange(32, 64, 1).astype("int32")
+                golden_output = np.repeat([2, 3, 4], [8, 16, 8])
+                golden_output = np.subtract(golden_output, 1)
+                verify(mod, (golden_data, golden_output))

-            # Try negative values
-            golden_data = np.arange(-32, -64, -1).astype("int32")
-            if rounding == "UPWARD":
-                golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
-            else:
-                golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
-            golden_output = np.subtract(golden_output, 1)
-            verify(mod, (golden_data, golden_output))
+                # Try negative values
+                golden_data = np.arange(-32, -64, -1).astype("int32")
+                if rounding == "UPWARD":
+                    golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+                else:
+                    golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+                golden_output = np.subtract(golden_output, 1)
+                verify(mod, (golden_data, golden_output))


 def test_per_channel_same_scale():
@@ -390,17 +461,18 @@ def test_per_channel_same_scale():
     golden_output = golden_data
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(5, 2),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=[0.5, 0.5],
-                output_scale=0.5,
-                axis=1,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
-            verify(mod, (golden_data, golden_output))
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(5, 2),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=[0.5, 0.5],
+                    output_scale=0.5,
+                    axis=1,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )
+                verify(mod, (golden_data, golden_output))

     # Change axis
     golden_data = np.arange(-10, 10, 1).astype("int32").reshape((2, 2, 5))
@@ -480,88 +552,93 @@ def test_per_channel_different_scale():
 def test_default_cfg_and_no_args():
-    mod = get_mod(
-        data_shape=(32,),
-        data_dtype="int32",
-        out_dtype="int8",
-        input_scale=1,
-        output_scale=16,
-    )
-    golden_data = np.arange(0, -32, -1).astype("int32")
-    golden_output = np.repeat([0, -1, -2], [9, 16, 7])
-    verify(mod, (golden_data, golden_output))
+    for qnn_out_dtype in out_dtypes:
+        mod = get_mod(
+            data_shape=(32,),
+            data_dtype="int32",
+            out_dtype=qnn_out_dtype,
+            input_scale=1,
+            output_scale=16,
+        )
+        golden_data = np.arange(0, -32, -1).astype("int32")
+        golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+        verify(mod, (golden_data, golden_output))


 def test_non_default_cfg_and_no_args():
     for rounding_cfg in roundings:
-        with relay.qnn.op.requantize_config(rounding=rounding_cfg):
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=1,
-                output_scale=16,
-            )
+        for qnn_out_dtype in out_dtypes:
+            with relay.qnn.op.requantize_config(rounding=rounding_cfg):
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=1,
+                    output_scale=16,
+                )

-            golden_data = np.arange(0, -32, -1).astype("int32")
+                golden_data = np.arange(0, -32, -1).astype("int32")

-            if rounding_cfg == "UPWARD":
-                golden_output = np.repeat([0, -1, -2], [9, 16, 7])
-            else:
-                golden_output = np.repeat([0, -1, -2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
+                if rounding_cfg == "UPWARD":
+                    golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+                else:
+                    golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+                verify(mod, (golden_data, golden_output))


 def test_default_cfg_and_args():
     for rounding in roundings:
-        with relay.qnn.op.requantize_config(rounding="UPWARD"):
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=1,
-                output_scale=16,
-                rounding=rounding,
-            )
-
-        golden_data = np.arange(0, -32, -1).astype("int32")
-
-        if rounding == "UPWARD":
-            golden_output = np.repeat([0, -1, -2], [9, 16, 7])
-        else:
-            golden_output = np.repeat([0, -1, -2], [8, 16, 8])
-        verify(mod, (golden_data, golden_output))
-
-
-def test_non_default_cfg_and_args():
-    for rounding_arg in roundings:
-        for rounding_cfg in roundings:
-            with relay.qnn.op.requantize_config(rounding=rounding_cfg):
+        for qnn_out_dtype in out_dtypes:
+            with relay.qnn.op.requantize_config(rounding="UPWARD"):
                 mod = get_mod(
                     data_shape=(32,),
                     data_dtype="int32",
-                    out_dtype="int8",
+                    out_dtype=qnn_out_dtype,
                     input_scale=1,
                     output_scale=16,
-                    rounding=rounding_arg,
+                    rounding=rounding,
                 )

                 golden_data = np.arange(0, -32, -1).astype("int32")

-                if rounding_arg == "UPWARD":
+                if rounding == "UPWARD":
                     golden_output = np.repeat([0, -1, -2], [9, 16, 7])
                 else:
                     golden_output = np.repeat([0, -1, -2], [8, 16, 8])
                 verify(mod, (golden_data, golden_output))


+def test_non_default_cfg_and_args():
+    for rounding_arg in roundings:
+        for rounding_cfg in roundings:
+            for qnn_out_dtype in out_dtypes:
+                with relay.qnn.op.requantize_config(rounding=rounding_cfg):
+                    mod = get_mod(
+                        data_shape=(32,),
+                        data_dtype="int32",
+                        out_dtype=qnn_out_dtype,
+                        input_scale=1,
+                        output_scale=16,
+                        rounding=rounding_arg,
+                    )
+
+                    golden_data = np.arange(0, -32, -1).astype("int32")
+
+                    if rounding_arg == "UPWARD":
+                        golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+                    else:
+                        golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+                    verify(mod, (golden_data, golden_output))
+
+
 if __name__ == "__main__":
     test_same_scale()
     test_scalar_same_scale()
     test_downscale()
     test_upscale()
     test_non_power_of_two()
-    test_saturation()
+    test_saturation_int8()
+    test_saturation_int16()
     test_zero_point()
     test_per_channel_same_scale()
     test_per_channel_different_scale()
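For reference, the int16 saturation goldens in test_saturation_int16 can be sanity-checked with plain numpy. A minimal sketch, not a TVM API: requantize_ref is a hypothetical reference helper, and it assumes equal input and output scales, in which case requantization reduces to a rounding pass plus a saturating cast:

    import numpy as np

    def requantize_ref(data, input_scale, output_scale, out_dtype):
        # Rescale, round, then clamp to the output type's representable range.
        scaled = np.round(data.astype("float64") * input_scale / output_scale)
        info = np.iinfo(out_dtype)
        return np.clip(scaled, info.min, info.max).astype(out_dtype)

    golden_data = np.add(32760, np.arange(0, 16, 1).astype("int32"))
    # Everything above 32767 clamps to the int16 maximum, matching the
    # golden output used in the test above.
    assert requantize_ref(golden_data, 0.5, 0.5, "int16").max() == 32767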