Skip to content

Commit

Permalink
[ONNX][apache#8838] QLinearSigmoid contrib op and Bug Fix for Dequant…
Browse files Browse the repository at this point in the history
…izeLinear (apache#9028)

* [ONNX][apache#8838] QLinearSigmoid contrib op and Bug Fix for DequantizeLinear

* [ONNX][apache#8838] QLinearSigmoid contrib op and Bug Fix for DequantizeLinear

* [ONNX][apache#8838] QLinearSigmoid contrib op and Bug Fix for DequantizeLinear

* [ONNX][apache#8838] QLinearSigmoid contrib op and Bug Fix for DequantizeLinear
  • Loading branch information
arangasa authored and ylc committed Jan 13, 2022
1 parent 0b5fece commit 2c69cae
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
25 changes: 25 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3254,6 +3254,8 @@ def _impl_v10(cls, inputs, attr, params):
def _impl_v13(cls, inputs, attr, params):
data, scale, zp = inputs
axis = attr.get("axis", 1)
if len(infer_shape(data)) <= 1:
axis = 0
return _qnn.op.dequantize(data, scale, _op.cast(zp, "int32"), axis)


Expand Down Expand Up @@ -3428,6 +3430,28 @@ def _impl_v10(cls, inputs, attr, params):
return _qnn.op.quantize(out, y_scale, y_zero_point, out_dtype=dtype)


class QLinearSigmoid(OnnxOpConverter):
"""Operator converter for QLinearSigmoid from Microsoft onnxruntime contrib opset."""

@classmethod
def _impl_v10(cls, inputs, attr, params):
x = inputs[0]
x_scale = get_scalar(inputs[1], params)
x_zero_point = get_scalar(inputs[2], params, "int32")
y_scale = fold_constant(get_scalar(inputs[3], params))
y_zero_point = get_scalar(inputs[4], params, "int32")

dtype = infer_type(x).checked_type.dtype

## Apparently, onnxruntime doesn't do this op in integer, they dequantize to fp32
## and then requantize after:
## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/
## providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp#L245
x = _qnn.op.dequantize(x, x_scale, x_zero_point)
out = _op.sigmoid(x)
return _qnn.op.quantize(out, y_scale, y_zero_point, out_dtype=dtype)


class QLinearConcat(OnnxOpConverter):
"""Operator converter for QLinearConcat from Microsoft onnxruntime contrib opset."""

Expand Down Expand Up @@ -4084,6 +4108,7 @@ def _get_convert_map(opset):
"QLinearConcat": QLinearConcat.get_converter(opset),
"QLinearAdd": QLinearAdd.get_converter(opset),
"QLinearMul": QLinearMul.get_converter(opset),
"QLinearSigmoid": QLinearSigmoid.get_converter(opset),
"ConvInteger": ConvInteger.get_converter(opset),
"QLinearAveragePool": QLinearAveragePool.get_converter(opset),
"QLinearGlobalAveragePool": QLinearGlobalAveragePool.get_converter(opset),
Expand Down
27 changes: 27 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5542,11 +5542,38 @@ def verify_qlinearmul(a_shape, b_shape, c_shape):
model = helper.make_model(graph, producer_name="qlinearmul_test")
quantize_and_verify_with_ort(model, input_names, [a_shape, b_shape], target, dev)

verify_qlinearmul([7], [7], [7])
verify_qlinearmul([4, 2], [4, 2], [4, 2])
verify_qlinearmul([4, 2], [2], [4, 2])
verify_qlinearmul([5, 1, 7], [2, 7], [5, 2, 7])


@tvm.testing.parametrize_targets
def test_qlinearsigmoid(target, dev):
def verify_qlinearsigmoid(a_shape):

a_array = np.random.random(a_shape).astype("float32")

input_nodes = [helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape))]

input_values = [a_array]

node = helper.make_node("Sigmoid", ["a"], ["B"])
graph = helper.make_graph(
[node],
"qlinearsigmoid_test",
inputs=input_nodes,
outputs=[helper.make_tensor_value_info("B", TensorProto.FLOAT, list(a_shape))],
)
model = helper.make_model(graph, producer_name="qlinearsigmoid_test")
quantize_and_verify_with_ort(model, ["a"], [a_shape], target, dev)

verify_qlinearsigmoid([4, 2])
verify_qlinearsigmoid([5])
verify_qlinearsigmoid([3, 4, 5])
verify_qlinearsigmoid([])


@tvm.testing.parametrize_targets
def test_random_uniform(target, dev):
def get_random_uniform(shape, dtype="float32", high=1.0, low=0.0, seed=None):
Expand Down

0 comments on commit 2c69cae

Please sign in to comment.