Skip to content

Commit 4dc7f79

Browse files
shoubhikwweic
authored andcommitted
Adding support for dequantizing from int32 to float32. (apache#4130)
1 parent 118acb2 commit 4dc7f79

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

src/relay/qnn/op/dequantize.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ bool DequantizeRel(const Array<Type>& types,
4343
CHECK_EQ(types.size(), 2);
4444
const auto* data = types[0].as<TensorTypeNode>();
4545
const auto input_dtype = data->dtype;
46-
CHECK(input_dtype == Int(8) || input_dtype == UInt(8))
47-
<< "Input type should be one of the quantized types [unit8, int8] but was " << input_dtype;
46+
CHECK(input_dtype == Int(8) || input_dtype == UInt(8) || input_dtype == Int(32))
47+
<< "Input type should be one of the quantized types [unit8, int8, int32] but was "
48+
<< input_dtype;
4849
const Array<tvm::Expr> oshape = data->shape;
4950
// assign output type, output will always be float 32.
5051
reporter->Assign(types[1], TensorTypeNode::make(oshape, Float(32)));

tests/python/relay/test_op_qnn_dequantize.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,27 +44,35 @@ def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
4444
def test_uint8_to_float32():
4545
data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
4646
.astype('uint8') \
47-
.reshape((2,5))
47+
.reshape((2, 5))
4848
output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
4949
.astype('float32') \
50-
.reshape((2,5))
50+
.reshape((2, 5))
5151
quant_args = {"in_zero_point":127, "in_scale":0.5}
5252
quantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data,
5353
verify_output_data=output)
5454

5555
def test_int8_to_float32():
5656
data = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
5757
.astype('int8') \
58-
.reshape((2,5))
58+
.reshape((2, 5))
5959
output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
6060
.astype('float32') \
61-
.reshape((2,5))
62-
quant_args = {"in_zero_point":-1, "in_scale":0.5}
61+
.reshape((2, 5))
62+
quant_args = {"in_zero_point": -1, "in_scale": 0.5}
6363
quantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data,
6464
verify_output_data=output)
6565

66+
def test_int32_to_float32():
67+
data = np.array([113, 29, -1052]).astype('int32')
68+
output = np.array([0.6550452, 0.16810896, -6.098297]).astype('float32')
69+
quant_args = {"in_zero_point": 0, "in_scale": 0.0057968604}
70+
quantize_test_driver(in_dtype='int32', quant_args=quant_args, in_data=data,
71+
verify_output_data=output)
72+
6673
test_uint8_to_float32()
6774
test_int8_to_float32()
75+
test_int32_to_float32()
6876

6977

7078
if __name__ == "__main__":

0 commit comments

Comments
 (0)