Skip to content

Commit

Permalink
Shape Inference Tests for QOps (onnx#1929)
Browse files Browse the repository at this point in the history
* fix shape inference and add tests for shape inference

* cosmetic fixes

* plus some formatting
  • Loading branch information
askhade authored and houseroad committed Apr 15, 2019
1 parent a80c337 commit 3717dc6
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 24 deletions.
27 changes: 16 additions & 11 deletions onnx/defs/math/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1294,25 +1294,30 @@ ONNX_OPERATOR_SET_SCHEMA(
ctx) {
auto a_type = ctx.getInputType(0);
auto b_type = ctx.getInputType(3);
auto y_type = ctx.getOutputType(0);
if (nullptr == a_type || nullptr == b_type || nullptr == y_type ||
if (nullptr == a_type || nullptr == b_type ||
a_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType ||
b_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType) {
fail_type_inference("inputs are expected to have tensor type.");
}

auto a_zero_point_type = ctx.getInputType(2);
if (nullptr == a_zero_point_type ||
a_zero_point_type->tensor_type().elem_type() !=
a_type->tensor_type().elem_type()) {
fail_type_inference(
"inputs are expected to have tensor type and output type should not be null.");
"input and zero_point pair is expected to have be same type.");
}

if (ONNX_NAMESPACE::TensorProto::UINT8 ==
a_type->tensor_type().elem_type() &&
ONNX_NAMESPACE::TensorProto::UINT8 ==
auto b_zero_point_type = ctx.getInputType(5);
if (nullptr == b_zero_point_type ||
b_zero_point_type->tensor_type().elem_type() !=
b_type->tensor_type().elem_type()) {
y_type->mutable_tensor_type()->set_elem_type(
ONNX_NAMESPACE::TensorProto::UINT8);
} else {
y_type->mutable_tensor_type()->set_elem_type(
ONNX_NAMESPACE::TensorProto::INT8);
fail_type_inference(
"input and zero_point pair is expected to have same type.");
}

propagateElemTypeFromInputToOutput(ctx, 7, 0);

matmulShapeInference(ctx, 0, 3);
}));

Expand Down
30 changes: 17 additions & 13 deletions onnx/defs/nn/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ void convPoolShapeInference(
bool require_kernel_shape,
int input1Idx,
int input2Idx) {

// we need the first input shape for this inference.
if (!hasInputShape(ctx, input1Idx)) {
return;
Expand Down Expand Up @@ -115,7 +114,7 @@ void convPoolShapeInference(
*output_shape->add_dim() = input_shape.dim(1);
} else {
*output_shape->add_dim() = input_shape.dim(0);
auto& second_input_shape = getInputShape(ctx, 1);
auto& second_input_shape = getInputShape(ctx, input2Idx);
if (second_input_shape.dim_size() < 1) {
fail_shape_inference("Second input tensor has wrong dimension");
}
Expand Down Expand Up @@ -1026,25 +1025,30 @@ ONNX_OPERATOR_SET_SCHEMA(
ctx) {
auto x_type = ctx.getInputType(0);
auto w_type = ctx.getInputType(3);
auto y_type = ctx.getOutputType(0);
if (nullptr == x_type || nullptr == w_type || nullptr == y_type ||
if (nullptr == x_type || nullptr == w_type ||
x_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType ||
w_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType) {
fail_type_inference("inputs are expected to have tensor type.");
}

auto x_zero_point_type = ctx.getInputType(2);
if (nullptr == x_zero_point_type ||
x_zero_point_type->tensor_type().elem_type() !=
x_type->tensor_type().elem_type()) {
fail_type_inference(
"inputs are expected to have tensor type and output type should not be null.");
"input and zero_point pair is expected to have be same type.");
}

if (ONNX_NAMESPACE::TensorProto::UINT8 ==
x_type->tensor_type().elem_type() &&
ONNX_NAMESPACE::TensorProto::UINT8 ==
auto w_zero_point_type = ctx.getInputType(5);
if (nullptr == w_zero_point_type ||
w_zero_point_type->tensor_type().elem_type() !=
w_type->tensor_type().elem_type()) {
y_type->mutable_tensor_type()->set_elem_type(
ONNX_NAMESPACE::TensorProto::UINT8);
} else {
y_type->mutable_tensor_type()->set_elem_type(
ONNX_NAMESPACE::TensorProto::INT8);
fail_type_inference(
"weight and zero_point pair is expected to have same type.");
}

propagateElemTypeFromInputToOutput(ctx, 7, 0);

convPoolShapeInference(ctx, true, false, 0, 3);
}));

Expand Down
246 changes: 246 additions & 0 deletions onnx/test/shape_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,252 @@ def test_constantofshape_without_input_shape(self): # type: () -> None
self._assert_inferred(graph,
[make_tensor_value_info('y', TensorProto.UINT8, (None, None, None))]) # type: ignore

def test_convinteger(self):  # type: () -> None
    # ConvInteger with pads, dilations and strides; inferred output is INT32.
    conv = make_node('ConvInteger', ['x', 'y'], 'z',
                     pads=[0, 1, 1, 0, 0, 1], dilations=[1, 2, 2], strides=[1, 1, 2])
    graph = self._make_graph(
        [('x', TensorProto.UINT8, (3, 4, 5, 6, 7)),
         ('y', TensorProto.UINT8, (5, 4, 2, 4, 3))],
        [conv],
        [])
    expected = make_tensor_value_info('z', TensorProto.INT32, (3, 5, 4, 1, 3))
    self._assert_inferred(graph, [expected])

def test_convinetger_dilations(self):  # type: () -> None
    # NOTE(review): method name misspells "integer"; kept as-is since the
    # name is the public test ID.
    # ConvInteger with per-axis dilations and explicit zero points.
    inputs = [('x', TensorProto.UINT8, (30, 4, 8, 8, 8)),
              ('y', TensorProto.INT8, (50, 4, 3, 3, 3)),
              ('x_zero_point', TensorProto.UINT8, ()),
              ('y_zero_point', TensorProto.UINT8, ())]
    node = make_node('ConvInteger', ['x', 'y', 'x_zero_point', 'y_zero_point'], 'z', dilations=[1, 2, 3])
    graph = self._make_graph(inputs, [node], [])
    self._assert_inferred(
        graph, [make_tensor_value_info('z', TensorProto.INT32, (30, 50, 6, 4, 2))])

def test_convinteger_strides(self):  # type: () -> None
    # ConvInteger with per-axis strides; INT8 data, UINT8 zero points.
    node = make_node('ConvInteger', ['x', 'y', 'x_zero_point', 'y_zero_point'], 'z', strides=[1, 2, 3])
    graph = self._make_graph(
        [('x', TensorProto.INT8, (30, 4, 8, 8, 8)),
         ('y', TensorProto.INT8, (50, 4, 3, 3, 3)),
         ('x_zero_point', TensorProto.UINT8, ()),
         ('y_zero_point', TensorProto.UINT8, ())],
        [node],
        [])
    expected = make_tensor_value_info('z', TensorProto.INT32, (30, 50, 6, 3, 2))
    self._assert_inferred(graph, [expected])

def test_convineteger_pads(self):  # type: () -> None
    # NOTE(review): method name misspells "integer"; kept for test-ID stability.
    # ConvInteger with asymmetric padding and no explicit zero points.
    node = make_node('ConvInteger', ['x', 'y'], 'z', pads=[1, 1, 2, 0, 1, 2])
    graph = self._make_graph(
        [('x', TensorProto.UINT8, (30, 4, 7, 6, 4)),
         ('y', TensorProto.INT8, (50, 4, 3, 3, 3))],
        [node],
        [])
    self._assert_inferred(
        graph, [make_tensor_value_info('z', TensorProto.INT32, (30, 50, 6, 6, 6))])

def test_convineteger_group(self):  # type: () -> None
    # NOTE(review): method name misspells "integer"; kept for test-ID stability.
    # Grouped ConvInteger: 4 groups over 4 input channels.
    grouped = make_node('ConvInteger', ['x', 'y'], 'z', group=4)
    graph = self._make_graph(
        [('x', TensorProto.INT8, (30, 4, 8, 8, 8)),
         ('y', TensorProto.INT8, (4, 1, 8, 8, 8))],
        [grouped],
        [])
    expected = make_tensor_value_info('z', TensorProto.INT32, (30, 4, 1, 1, 1))
    self._assert_inferred(graph, [expected])

def test_convineteger_partial_missing_shape(self):  # type: () -> None
    # NOTE(review): method name misspells "integer"; kept for test-ID stability.
    # An unknown spatial dim on the data input stays unknown in the output.
    inputs = [('x', TensorProto.UINT8, (30, 4, None, 6, 4)),
              ('y', TensorProto.UINT8, (50, 4, 3, 3, 3)),
              ('x_zero_point', TensorProto.UINT8, ()),
              ('y_zero_point', TensorProto.UINT8, ())]
    node = make_node('ConvInteger', ['x', 'y', 'x_zero_point', 'y_zero_point'], 'z', pads=[1, 1, 2, 0, 1, 2])
    graph = self._make_graph(inputs, [node], [])
    expected = make_tensor_value_info('z', TensorProto.INT32, (30, 50, None, 6, 6))  # type: ignore
    self._assert_inferred(graph, [expected])

def test_convineteger_partial_missing_weight_shape(self):  # type: () -> None
    # NOTE(review): method name misspells "integer"; kept for test-ID stability.
    # An unknown kernel dim means no output shape can be inferred at all.
    node = make_node('ConvInteger', ['x', 'y'], 'z', pads=[1, 1, 2, 0, 1, 2])
    graph = self._make_graph(
        [('x', TensorProto.UINT8, (30, 4, 7, 6, 4)),
         ('y', TensorProto.UINT8, (50, 4, None, 3, 3))],
        [node],
        [])
    self._assert_inferred(graph, [make_tensor_value_info('z', TensorProto.INT32, None)])

def test_qlinearconv(self):  # type: () -> None
    # QLinearConv with pads/dilations/strides; output type follows y_zero_point.
    qconv_inputs = ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point']
    node = make_node('QLinearConv', qconv_inputs, 'y',
                     pads=[0, 1, 1, 0, 0, 1], dilations=[1, 2, 2], strides=[1, 1, 2])
    graph = self._make_graph(
        [('x', TensorProto.UINT8, (3, 4, 5, 6, 7)),
         ('x_scale', TensorProto.FLOAT, ()),
         ('x_zero_point', TensorProto.UINT8, ()),
         ('w', TensorProto.UINT8, (5, 4, 2, 4, 3)),
         ('w_scale', TensorProto.FLOAT, ()),
         ('w_zero_point', TensorProto.UINT8, ()),
         ('y_scale', TensorProto.FLOAT, ()),
         ('y_zero_point', TensorProto.UINT8, ())],
        [node],
        [])
    expected = make_tensor_value_info('y', TensorProto.UINT8, (3, 5, 4, 1, 3))
    self._assert_inferred(graph, [expected])

def test_qlinearconv_dilations(self):  # type: () -> None
    # QLinearConv with per-axis dilations, all-UINT8 quantized tensors.
    qconv_inputs = ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point']
    node = make_node('QLinearConv', qconv_inputs, 'y', dilations=[1, 2, 3])
    graph = self._make_graph(
        [('x', TensorProto.UINT8, (30, 4, 8, 8, 8)),
         ('x_scale', TensorProto.FLOAT, ()),
         ('x_zero_point', TensorProto.UINT8, ()),
         ('w', TensorProto.UINT8, (50, 4, 3, 3, 3)),
         ('w_scale', TensorProto.FLOAT, ()),
         ('w_zero_point', TensorProto.UINT8, ()),
         ('y_scale', TensorProto.FLOAT, ()),
         ('y_zero_point', TensorProto.UINT8, ())],
        [node],
        [])
    self._assert_inferred(
        graph, [make_tensor_value_info('y', TensorProto.UINT8, (30, 50, 6, 4, 2))])

def test_qlinearconv_strides(self):  # type: () -> None
    # QLinearConv with per-axis strides, all-INT8 quantized tensors.
    qconv_inputs = ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point']
    node = make_node('QLinearConv', qconv_inputs, 'y', strides=[1, 2, 3])
    graph = self._make_graph(
        [('x', TensorProto.INT8, (30, 4, 8, 8, 8)),
         ('x_scale', TensorProto.FLOAT, ()),
         ('x_zero_point', TensorProto.INT8, ()),
         ('w', TensorProto.INT8, (50, 4, 3, 3, 3)),
         ('w_scale', TensorProto.FLOAT, ()),
         ('w_zero_point', TensorProto.INT8, ()),
         ('y_scale', TensorProto.FLOAT, ()),
         ('y_zero_point', TensorProto.INT8, ())],
        [node],
        [])
    self._assert_inferred(
        graph, [make_tensor_value_info('y', TensorProto.INT8, (30, 50, 6, 3, 2))])

def test_qlinearconv_pads(self):  # type: () -> None
    # QLinearConv with asymmetric padding; mixed UINT8 data / INT8 weights.
    qconv_inputs = ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point']
    node = make_node('QLinearConv', qconv_inputs, 'y', pads=[1, 1, 2, 0, 1, 2])
    graph = self._make_graph(
        [('x', TensorProto.UINT8, (30, 4, 7, 6, 4)),
         ('x_scale', TensorProto.FLOAT, ()),
         ('x_zero_point', TensorProto.UINT8, ()),
         ('w', TensorProto.INT8, (50, 4, 3, 3, 3)),
         ('w_scale', TensorProto.FLOAT, ()),
         ('w_zero_point', TensorProto.INT8, ()),
         ('y_scale', TensorProto.FLOAT, ()),
         ('y_zero_point', TensorProto.UINT8, ())],
        [node],
        [])
    expected = make_tensor_value_info('y', TensorProto.UINT8, (30, 50, 6, 6, 6))
    self._assert_inferred(graph, [expected])

def test_qlinearconv_group(self):  # type: () -> None
    # Grouped QLinearConv: 4 groups over 4 input channels, all INT8.
    qconv_inputs = ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point']
    node = make_node('QLinearConv', qconv_inputs, 'y', group=4)
    graph = self._make_graph(
        [('x', TensorProto.INT8, (30, 4, 8, 8, 8)),
         ('x_scale', TensorProto.FLOAT, ()),
         ('x_zero_point', TensorProto.INT8, ()),
         ('w', TensorProto.INT8, (4, 1, 8, 8, 8)),
         ('w_scale', TensorProto.FLOAT, ()),
         ('w_zero_point', TensorProto.INT8, ()),
         ('y_scale', TensorProto.FLOAT, ()),
         ('y_zero_point', TensorProto.INT8, ())],
        [node],
        [])
    self._assert_inferred(
        graph, [make_tensor_value_info('y', TensorProto.INT8, (30, 4, 1, 1, 1))])

def test_qlinearconv_partial_missing_shape(self):  # type: () -> None
    # An unknown spatial dim on the data input stays unknown in the output.
    qconv_inputs = ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point']
    node = make_node('QLinearConv', qconv_inputs, 'y', pads=[1, 1, 2, 0, 1, 2])
    graph = self._make_graph(
        [('x', TensorProto.UINT8, (30, 4, None, 6, 4)),
         ('x_scale', TensorProto.FLOAT, ()),
         ('x_zero_point', TensorProto.UINT8, ()),
         ('w', TensorProto.UINT8, (50, 4, 3, 3, 3)),
         ('w_scale', TensorProto.FLOAT, ()),
         ('w_zero_point', TensorProto.UINT8, ()),
         ('y_scale', TensorProto.FLOAT, ()),
         ('y_zero_point', TensorProto.UINT8, ())],
        [node],
        [])
    expected = make_tensor_value_info('y', TensorProto.UINT8, (30, 50, None, 6, 6))  # type: ignore
    self._assert_inferred(graph, [expected])

def test_qlinearconv_partial_missing_weight_shape(self):  # type: () -> None
    # An unknown kernel dim means no output shape can be inferred at all.
    qconv_inputs = ['x', 'x_scale', 'x_zero_point', 'w', 'w_scale', 'w_zero_point', 'y_scale', 'y_zero_point']
    node = make_node('QLinearConv', qconv_inputs, 'y', pads=[1, 1, 2, 0, 1, 2])
    graph = self._make_graph(
        [('x', TensorProto.UINT8, (30, 4, 7, 6, 4)),
         ('x_scale', TensorProto.FLOAT, ()),
         ('x_zero_point', TensorProto.UINT8, ()),
         ('w', TensorProto.UINT8, (50, 4, None, 3, 3)),
         ('w_scale', TensorProto.FLOAT, ()),
         ('w_zero_point', TensorProto.UINT8, ()),
         ('y_scale', TensorProto.FLOAT, ()),
         ('y_zero_point', TensorProto.UINT8, ())],
        [node],
        [])
    self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.UINT8, None)])

def _make_qlinearmatmul_test(self, shape1, shape2):  # type: (Sequence[int], Sequence[int]) -> None
    # Helper: build a QLinearMatMul graph for UINT8 inputs of the given
    # shapes and assert that shape inference reproduces numpy's matmul
    # result shape (including 1-D and broadcasting cases).
    # np.prod replaces the deprecated alias np.product (removed in NumPy
    # 2.0); both yield 1 for an empty shape, so behavior is unchanged.
    expected_out_shape = np.matmul(np.arange(np.prod(shape1)).reshape(shape1),
                                   np.arange(np.prod(shape2)).reshape(shape2)).shape
    graph = self._make_graph(
        [('a', TensorProto.UINT8, shape1),
         ('a_scale', TensorProto.FLOAT, ()),
         ('a_zero_point', TensorProto.UINT8, ()),
         ('b', TensorProto.UINT8, shape2),
         ('b_scale', TensorProto.FLOAT, ()),
         ('b_zero_point', TensorProto.UINT8, ()),
         ('y_scale', TensorProto.FLOAT, ()),
         ('y_zero_point', TensorProto.UINT8, ())],
        [make_node('QLinearMatMul', ['a', 'a_scale', 'a_zero_point', 'b', 'b_scale', 'b_zero_point', 'y_scale', 'y_zero_point'], ['y'])],
        [])
    self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.UINT8, expected_out_shape)])

def test_qlinearmatmul(self):  # type: () -> None
    # Cover vector/vector, 2-D, mixed-rank and broadcast matmul cases.
    cases = [((3,), (3,)),
             ((4, 2), (2, 4)),
             ((2,), (2, 3)),
             ((4, 2), (2,)),
             ((5, 1, 4, 2), (1, 3, 2, 3)),
             ((4, 2), (3, 2, 3))]
    for lhs, rhs in cases:
        self._make_qlinearmatmul_test(lhs, rhs)

def _make_qlinearmatmul_test_allow_unknown(self, shape1, shape2, expected_out_shape):  # type: (Any, Any, Any) -> None
    # Helper: like _make_qlinearmatmul_test, but the caller supplies the
    # expected output shape so inputs may carry unknown/symbolic dims.
    node = make_node('QLinearMatMul', ['a', 'a_scale', 'a_zero_point', 'b', 'b_scale', 'b_zero_point', 'y_scale', 'y_zero_point'], ['y'])
    graph = self._make_graph(
        [('a', TensorProto.UINT8, shape1),
         ('a_scale', TensorProto.FLOAT, ()),
         ('a_zero_point', TensorProto.UINT8, ()),
         ('b', TensorProto.UINT8, shape2),
         ('b_scale', TensorProto.FLOAT, ()),
         ('b_zero_point', TensorProto.UINT8, ()),
         ('y_scale', TensorProto.FLOAT, ()),
         ('y_zero_point', TensorProto.UINT8, ())],
        [node],
        [])
    expected = make_tensor_value_info('y', TensorProto.UINT8, expected_out_shape)
    self._assert_inferred(graph, [expected])

def test_qlinearmatmul_allow_unknown(self):  # type: () -> None
    # Unknown (None) and symbolic ("a") dims in various positions.
    cases = [((None,), (None,), ()),
             ((3,), (None,), ()),
             ((2,), (2, "a"), ("a",)),
             ((4, 2), (2, "a"), (4, "a")),
             ((4, None), (2, "a"), (4, "a")),
             ((4, None), (None, "a"), (4, "a")),
             ((1, 4, 2), ("a", 2, 5), ("a", 4, 5)),
             ((1, 3, 4, 2), ("a", 2, 5), (1, 3, 4, 5))]
    for lhs, rhs, out in cases:
        self._make_qlinearmatmul_test_allow_unknown(lhs, rhs, out)

def _make_matmulinteger_test(self, shape1, shape2):  # type: (Sequence[int], Sequence[int]) -> None
    # Helper: build a MatMulInteger graph for UINT8 inputs of the given
    # shapes and assert the inferred output is INT32 with numpy's matmul
    # result shape.
    # np.prod replaces the deprecated alias np.product (removed in NumPy
    # 2.0); both yield 1 for an empty shape, so behavior is unchanged.
    expected_out_shape = np.matmul(np.arange(np.prod(shape1)).reshape(shape1),
                                   np.arange(np.prod(shape2)).reshape(shape2)).shape
    graph = self._make_graph(
        [('A', TensorProto.UINT8, shape1),
         ('B', TensorProto.UINT8, shape2),
         ('a_zero_point', TensorProto.UINT8, ()),
         ('b_zero_point', TensorProto.UINT8, ())],
        [make_node('MatMulInteger', ['A', 'B', 'a_zero_point', 'b_zero_point'], ['Y'])],
        [])
    self._assert_inferred(graph, [make_tensor_value_info('Y', TensorProto.INT32, expected_out_shape)])

def test_matmulinteger(self):  # type: () -> None
    # Cover vector/vector, 2-D, mixed-rank and broadcast matmul cases.
    cases = [((2,), (2,)),
             ((1, 2), (2, 3)),
             ((2,), (2, 3)),
             ((4, 2), (2,)),
             ((5, 1, 4, 2), (1, 3, 2, 3)),
             ((4, 2), (3, 2, 3))]
    for lhs, rhs in cases:
        self._make_matmulinteger_test(lhs, rhs)

def test_quantizelinear(self):  # type: () -> None
    # QuantizeLinear: output takes the zero point's type (UINT8) and the
    # data input's shape.
    node = make_node('QuantizeLinear', ['x', 'y_scale', 'y_zero_point'], ['y'])
    graph = self._make_graph(
        [('x', TensorProto.FLOAT, (30, 4, 5)),
         ('y_scale', TensorProto.FLOAT, ()),
         ('y_zero_point', TensorProto.UINT8, ())],
        [node],
        [])
    expected = make_tensor_value_info('y', TensorProto.UINT8, (30, 4, 5))
    self._assert_inferred(graph, [expected])

def test_dequantizelinear(self):  # type: () -> None
    # DequantizeLinear: output is FLOAT with the data input's shape.
    node = make_node('DequantizeLinear', ['x', 'x_scale', 'x_zero_point'], ['y'])
    graph = self._make_graph(
        [('x', TensorProto.UINT8, (30, 4, 5)),
         ('x_scale', TensorProto.FLOAT, ()),
         ('x_zero_point', TensorProto.UINT8, ())],
        [node],
        [])
    expected = make_tensor_value_info('y', TensorProto.FLOAT, (30, 4, 5))
    self._assert_inferred(graph, [expected])


# Run the full shape-inference test suite when this module is executed directly.
if __name__ == '__main__':
    unittest.main()

0 comments on commit 3717dc6

Please sign in to comment.