diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc index f7ec8bbc6dc..93b13e432f6 100644 --- a/onnx/defs/math/defs.cc +++ b/onnx/defs/math/defs.cc @@ -1294,25 +1294,30 @@ ONNX_OPERATOR_SET_SCHEMA( ctx) { auto a_type = ctx.getInputType(0); auto b_type = ctx.getInputType(3); - auto y_type = ctx.getOutputType(0); - if (nullptr == a_type || nullptr == b_type || nullptr == y_type || + if (nullptr == a_type || nullptr == b_type || a_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType || b_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType) { + fail_type_inference("inputs are expected to have tensor type."); + } + + auto a_zero_point_type = ctx.getInputType(2); + if (nullptr == a_zero_point_type || + a_zero_point_type->tensor_type().elem_type() != + a_type->tensor_type().elem_type()) { fail_type_inference( - "inputs are expected to have tensor type and output type should not be null."); + "input and zero_point pair is expected to have be same type."); } - if (ONNX_NAMESPACE::TensorProto::UINT8 == - a_type->tensor_type().elem_type() && - ONNX_NAMESPACE::TensorProto::UINT8 == + auto b_zero_point_type = ctx.getInputType(5); + if (nullptr == b_zero_point_type || + b_zero_point_type->tensor_type().elem_type() != b_type->tensor_type().elem_type()) { - y_type->mutable_tensor_type()->set_elem_type( - ONNX_NAMESPACE::TensorProto::UINT8); - } else { - y_type->mutable_tensor_type()->set_elem_type( - ONNX_NAMESPACE::TensorProto::INT8); + fail_type_inference( + "input and zero_point pair is expected to have same type."); } + propagateElemTypeFromInputToOutput(ctx, 7, 0); + matmulShapeInference(ctx, 0, 3); })); diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc index fc6290d4121..eb6bb184c13 100644 --- a/onnx/defs/nn/defs.cc +++ b/onnx/defs/nn/defs.cc @@ -32,7 +32,6 @@ void convPoolShapeInference( bool require_kernel_shape, int input1Idx, int input2Idx) { - // we need the first input shape for this inference. if (!hasInputShape(ctx, input1Idx)) { return; @@ -115,7 +114,7 @@ void convPoolShapeInference( *output_shape->add_dim() = input_shape.dim(1); } else { *output_shape->add_dim() = input_shape.dim(0); - auto& second_input_shape = getInputShape(ctx, 1); + auto& second_input_shape = getInputShape(ctx, input2Idx); if (second_input_shape.dim_size() < 1) { fail_shape_inference("Second input tensor has wrong dimension"); } @@ -1026,25 +1025,30 @@ ONNX_OPERATOR_SET_SCHEMA( ctx) { auto x_type = ctx.getInputType(0); auto w_type = ctx.getInputType(3); - auto y_type = ctx.getOutputType(0); - if (nullptr == x_type || nullptr == w_type || nullptr == y_type || + if (nullptr == x_type || nullptr == w_type || x_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType || w_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType) { + fail_type_inference("inputs are expected to have tensor type."); + } + + auto x_zero_point_type = ctx.getInputType(2); + if (nullptr == x_zero_point_type || + x_zero_point_type->tensor_type().elem_type() != + x_type->tensor_type().elem_type()) { fail_type_inference( - "inputs are expected to have tensor type and output type should not be null."); + "input and zero_point pair is expected to have be same type."); } - if (ONNX_NAMESPACE::TensorProto::UINT8 == - x_type->tensor_type().elem_type() && - ONNX_NAMESPACE::TensorProto::UINT8 == + auto w_zero_point_type = ctx.getInputType(5); + if (nullptr == w_zero_point_type || + w_zero_point_type->tensor_type().elem_type() != w_type->tensor_type().elem_type()) { - y_type->mutable_tensor_type()->set_elem_type( - ONNX_NAMESPACE::TensorProto::UINT8); - } else { - y_type->mutable_tensor_type()->set_elem_type( - ONNX_NAMESPACE::TensorProto::INT8); + fail_type_inference( + "weight and zero_point pair is expected to have same type."); } + propagateElemTypeFromInputToOutput(ctx, 7, 0); + convPoolShapeInference(ctx, true, false, 0, 3); })); diff --git a/onnx/test/shape_inference_test.py b/onnx/test/shape_inference_test.py index 91f9371b4ea..fb42edd51ec 100644 --- a/onnx/test/shape_inference_test.py +++ b/onnx/test/shape_inference_test.py @@ -1421,6 +1421,252 @@ def test_constantofshape_without_input_shape(self): # type: () -> None self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.UINT8, (None, None, None))]) # type: ignore + def test_convinteger(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.UINT8, (3, 4, 5, 6, 7)), + ('y', TensorProto.UINT8, (5, 4, 2, 4, 3))], + [make_node('ConvInteger', ['x', 'y'], 'z', pads=[0, 1, 1, 0, 0, 1], dilations=[1, 2, 2], strides=[1, 1, 2])], + []) + self._assert_inferred(graph, [make_tensor_value_info('z', TensorProto.INT32, (3, 5, 4, 1, 3))]) + + def test_convinetger_dilations(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.UINT8, (30, 4, 8, 8, 8)), + ('y', TensorProto.INT8, (50, 4, 3, 3, 3)), + ('x_zero_point', TensorProto.UINT8, ()), + ('y_zero_point', TensorProto.UINT8, ())], + [make_node('ConvInteger', ['x', 'y', 'x_zero_point', 'y_zero_point'], 'z', dilations=[1, 2, 3])], + []) + self._assert_inferred(graph, [make_tensor_value_info('z', TensorProto.INT32, (30, 50, 6, 4, 2))]) + + def test_convinteger_strides(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.INT8, (30, 4, 8, 8, 8)), + ('y', TensorProto.INT8, (50, 4, 3, 3, 3)), + ('x_zero_point', TensorProto.UINT8, ()), + ('y_zero_point', TensorProto.UINT8, ())], + [make_node('ConvInteger', ['x', 'y', 'x_zero_point', 'y_zero_point'], 'z', strides=[1, 2, 3])], + []) + self._assert_inferred(graph, [make_tensor_value_info('z', TensorProto.INT32, (30, 50, 6, 3, 2))]) + + def test_convineteger_pads(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.UINT8, (30, 4, 7, 6, 4)), + ('y', TensorProto.INT8, (50, 4, 3, 3, 3))], + [make_node('ConvInteger', ['x', 'y'], 'z', pads=[1, 1, 2, 0, 1, 2])], + []) + self._assert_inferred(graph, [make_tensor_value_info('z', TensorProto.INT32, (30, 50, 6, 6, 6))]) + + def test_convineteger_group(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.INT8, (30, 4, 8, 8, 8)), + ('y', TensorProto.INT8, (4, 1, 8, 8, 8))], + [make_node('ConvInteger', ['x', 'y'], 'z', group=4)], + []) + self._assert_inferred(graph, [make_tensor_value_info('z', TensorProto.INT32, (30, 4, 1, 1, 1))]) + + def test_convineteger_partial_missing_shape(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.UINT8, (30, 4, None, 6, 4)), + ('y', TensorProto.UINT8, (50, 4, 3, 3, 3)), + ('x_zero_point', TensorProto.UINT8, ()), + ('y_zero_point', TensorProto.UINT8, ())], + [make_node('ConvInteger', ['x', 'y', 'x_zero_point', 'y_zero_point'], 'z', pads=[1, 1, 2, 0, 1, 2])], + []) + self._assert_inferred(graph, [make_tensor_value_info('z', TensorProto.INT32, (30, 50, None, 6, 6))]) # type: ignore + + def test_convineteger_partial_missing_weight_shape(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.UINT8, (30, 4, 7, 6, 4)), + ('y', TensorProto.UINT8, (50, 4, None, 3, 3))], + [make_node('ConvInteger', ['x', 'y'], 'z', pads=[1, 1, 2, 0, 1, 2])], + []) + self._assert_inferred(graph, [make_tensor_value_info('z', TensorProto.INT32, None)]) + + def test_qlinearconv(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.UINT8, (3, 4, 5, 6, 7)), + ('x_scale', TensorProto.FLOAT, ()), + ('x_zero_point', TensorProto.UINT8, ()), + ('w', TensorProto.UINT8, (5, 4, 2, 4, 3)), + ('w_scale', TensorProto.FLOAT, ()), + ('w_zero_point', TensorProto.UINT8, ()), + ('y_scale', TensorProto.FLOAT, ()), + ('y_zero_point', TensorProto.UINT8, ())], + [make_node('QLinearConv', ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point'], 'y', pads=[0, 1, 1, 0, 0, 1], dilations=[1, 2, 2], strides=[1, 1, 2])], + []) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.UINT8, (3, 5, 4, 1, 3))]) + + def test_qlinearconv_dilations(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.UINT8, (30, 4, 8, 8, 8)), + ('x_scale', TensorProto.FLOAT, ()), + ('x_zero_point', TensorProto.UINT8, ()), + ('w', TensorProto.UINT8, (50, 4, 3, 3, 3)), + ('w_scale', TensorProto.FLOAT, ()), + ('w_zero_point', TensorProto.UINT8, ()), + ('y_scale', TensorProto.FLOAT, ()), + ('y_zero_point', TensorProto.UINT8, ())], + [make_node('QLinearConv', ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point'], 'y', dilations=[1, 2, 3])], + []) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.UINT8, (30, 50, 6, 4, 2))]) + + def test_qlinearconv_strides(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.INT8, (30, 4, 8, 8, 8)), + ('x_scale', TensorProto.FLOAT, ()), + ('x_zero_point', TensorProto.INT8, ()), + ('w', TensorProto.INT8, (50, 4, 3, 3, 3)), + ('w_scale', TensorProto.FLOAT, ()), + ('w_zero_point', TensorProto.INT8, ()), + ('y_scale', TensorProto.FLOAT, ()), + ('y_zero_point', TensorProto.INT8, ())], + [make_node('QLinearConv', ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point'], 'y', strides=[1, 2, 3])], + []) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.INT8, (30, 50, 6, 3, 2))]) + + def test_qlinearconv_pads(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.UINT8, (30, 4, 7, 6, 4)), + ('x_scale', TensorProto.FLOAT, ()), + ('x_zero_point', TensorProto.UINT8, ()), + ('w', TensorProto.INT8, (50, 4, 3, 3, 3)), + ('w_scale', TensorProto.FLOAT, ()), + ('w_zero_point', TensorProto.INT8, ()), + ('y_scale', TensorProto.FLOAT, ()), + ('y_zero_point', TensorProto.UINT8, ())], + [make_node('QLinearConv', ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point'], 'y', pads=[1, 1, 2, 0, 1, 2])], + []) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.UINT8, (30, 50, 6, 6, 6))]) + + def test_qlinearconv_group(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.INT8, (30, 4, 8, 8, 8)), + ('x_scale', TensorProto.FLOAT, ()), + ('x_zero_point', TensorProto.INT8, ()), + ('w', TensorProto.INT8, (4, 1, 8, 8, 8)), + ('w_scale', TensorProto.FLOAT, ()), + ('w_zero_point', TensorProto.INT8, ()), + ('y_scale', TensorProto.FLOAT, ()), + ('y_zero_point', TensorProto.INT8, ())], + [make_node('QLinearConv', ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point'], 'y', group=4)], + []) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.INT8, (30, 4, 1, 1, 1))]) + + def test_qlinearconv_partial_missing_shape(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.UINT8, (30, 4, None, 6, 4)), + ('x_scale', TensorProto.FLOAT, ()), + ('x_zero_point', TensorProto.UINT8, ()), + ('w', TensorProto.UINT8, (50, 4, 3, 3, 3)), + ('w_scale', TensorProto.FLOAT, ()), + ('w_zero_point', TensorProto.UINT8, ()), + ('y_scale', TensorProto.FLOAT, ()), + ('y_zero_point', TensorProto.UINT8, ())], + [make_node('QLinearConv', ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point'], 'y', pads=[1, 1, 2, 0, 1, 2])], + []) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.UINT8, (30, 50, None, 6, 6))]) # type: ignore + + def test_qlinearconv_partial_missing_weight_shape(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.UINT8, (30, 4, 7, 6, 4)), + ('x_scale', TensorProto.FLOAT, ()), + ('x_zero_point', TensorProto.UINT8, ()), + ('w', TensorProto.UINT8, (50, 4, None, 3, 3)), + ('w_scale', TensorProto.FLOAT, ()), + ('w_zero_point', TensorProto.UINT8, ()), + ('y_scale', TensorProto.FLOAT, ()), + ('y_zero_point', TensorProto.UINT8, ())], + [make_node('QLinearConv', ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point'], 'y', pads=[1, 1, 2, 0, 1, 2])], + []) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.UINT8, None)]) + + def _make_qlinearmatmul_test(self, shape1, shape2): # type: (Sequence[int], Sequence[int]) -> None + expected_out_shape = np.matmul(np.arange(np.product(shape1)).reshape(shape1), + np.arange(np.product(shape2)).reshape(shape2)).shape + graph = self._make_graph( + [('a', TensorProto.UINT8, shape1), + ('a_scale', TensorProto.FLOAT, ()), + ('a_zero_point', TensorProto.UINT8, ()), + ('b', TensorProto.UINT8, shape2), + ('b_scale', TensorProto.FLOAT, ()), + ('b_zero_point', TensorProto.UINT8, ()), + ('y_scale', TensorProto.FLOAT, ()), + ('y_zero_point', TensorProto.UINT8, ())], + [make_node('QLinearMatMul', ['a', 'a_scale', 'a_zero_point', 'b', 'b_scale', 'b_zero_point', 'y_scale', 'y_zero_point'], ['y'])], + []) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.UINT8, expected_out_shape)]) + + def test_qlinearmatmul(self): # type: () -> None + self._make_qlinearmatmul_test((3,), (3,)) + self._make_qlinearmatmul_test((4, 2), (2, 4)) + self._make_qlinearmatmul_test((2,), (2, 3)) + self._make_qlinearmatmul_test((4, 2), (2,)) + self._make_qlinearmatmul_test((5, 1, 4, 2), (1, 3, 2, 3)) + self._make_qlinearmatmul_test((4, 2), (3, 2, 3)) + + def _make_qlinearmatmul_test_allow_unknown(self, shape1, shape2, expected_out_shape): # type: (Any, Any, Any) -> None + graph = self._make_graph( + [('a', TensorProto.UINT8, shape1), + ('a_scale', TensorProto.FLOAT, ()), + ('a_zero_point', TensorProto.UINT8, ()), + ('b', TensorProto.UINT8, shape2), + ('b_scale', TensorProto.FLOAT, ()), + ('b_zero_point', TensorProto.UINT8, ()), + ('y_scale', TensorProto.FLOAT, ()), + ('y_zero_point', TensorProto.UINT8, ())], + [make_node('QLinearMatMul', ['a', 'a_scale', 'a_zero_point', 'b', 'b_scale', 'b_zero_point', 'y_scale', 'y_zero_point'], ['y'])], + []) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.UINT8, expected_out_shape)]) + + def test_qlinearmatmul_allow_unknown(self): # type: () -> None + self._make_qlinearmatmul_test_allow_unknown((None,), (None,), ()) + self._make_qlinearmatmul_test_allow_unknown((3,), (None,), ()) + self._make_qlinearmatmul_test_allow_unknown((2,), (2, "a"), ("a",)) + self._make_qlinearmatmul_test_allow_unknown((4, 2), (2, "a"), (4, "a")) + self._make_qlinearmatmul_test_allow_unknown((4, None), (2, "a"), (4, "a")) + self._make_qlinearmatmul_test_allow_unknown((4, None), (None, "a"), (4, "a")) + self._make_qlinearmatmul_test_allow_unknown((1, 4, 2), ("a", 2, 5), ("a", 4, 5)) + self._make_qlinearmatmul_test_allow_unknown((1, 3, 4, 2), ("a", 2, 5), (1, 3, 4, 5)) + + def _make_matmulinteger_test(self, shape1, shape2): # type: (Sequence[int], Sequence[int]) -> None + expected_out_shape = np.matmul(np.arange(np.product(shape1)).reshape(shape1), + np.arange(np.product(shape2)).reshape(shape2)).shape + graph = self._make_graph( + [('A', TensorProto.UINT8, shape1), + ('B', TensorProto.UINT8, shape2), + ('a_zero_point', TensorProto.UINT8, ()), + ('b_zero_point', TensorProto.UINT8, ())], + [make_node('MatMulInteger', ['A', 'B', 'a_zero_point', 'b_zero_point'], ['Y'])], + []) + self._assert_inferred(graph, [make_tensor_value_info('Y', TensorProto.INT32, expected_out_shape)]) + + def test_matmulinteger(self): # type: () -> None + self._make_matmulinteger_test((2,), (2,)) + self._make_matmulinteger_test((1, 2), (2, 3)) + self._make_matmulinteger_test((2,), (2, 3)) + self._make_matmulinteger_test((4, 2), (2,)) + self._make_matmulinteger_test((5, 1, 4, 2), (1, 3, 2, 3)) + self._make_matmulinteger_test((4, 2), (3, 2, 3)) + + def test_quantizelinear(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.FLOAT, (30, 4, 5)), + ('y_scale', TensorProto.FLOAT, ()), + ('y_zero_point', TensorProto.UINT8, ())], + [make_node('QuantizeLinear', ['x', 'y_scale', 'y_zero_point'], ['y'])], + []) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.UINT8, (30, 4, 5))]) + + def test_dequantizelinear(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.UINT8, (30, 4, 5)), + ('x_scale', TensorProto.FLOAT, ()), + ('x_zero_point', TensorProto.UINT8, ())], + [make_node('DequantizeLinear', ['x', 'x_scale', 'x_zero_point'], ['y'])], + []) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (30, 4, 5))]) + if __name__ == '__main__': unittest.main()