Skip to content

Commit

Permalink
[ONNX] Relax sequence tensor dim_param serialization
Browse files Browse the repository at this point in the history
Do not assign dim_param for sequence tensor type.
Sequence of tensors could differ in dimension size.
Use a dimension with neither dim_value nor dim_param set
to denote an unknown dimension.
Create and assign dim_param for normal tensor type.
Pull Request resolved: pytorch#70651
  • Loading branch information
BowenBao authored and pytorchmergebot committed Feb 23, 2022
1 parent 50efa3a commit 4612323
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
37 changes: 36 additions & 1 deletion test/onnx/test_utility_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import torch.utils.cpp_extension
from test_pytorch_common import (skipIfUnsupportedMinOpsetVersion,
skipIfUnsupportedMaxOpsetVersion)
import caffe2.python.onnx.backend as backend
from verify import verify

import torchvision
Expand Down Expand Up @@ -601,6 +600,39 @@ def test_error_on_data_parallel(self):
"unwrap model from torch.nn.DataParallel. Try "):
torch.onnx.export(model, x, f, opset_version=self.opset_version)

@skipIfUnsupportedMinOpsetVersion(11)
def test_sequence_dim(self):
class Module(torch.nn.Module):
def forward(self, x, y):
return [x, y]

model = Module()
# Export with scripting to keep output as Sequence type.
# Tracing unpacks the list.
script_model = torch.jit.script(model)
x = torch.randn(2, 3)

# Case 1: dynamic axis
f = io.BytesIO()
y = torch.randn(2, 3)
torch.onnx.export(script_model, (x, y), f, opset_version=self.opset_version,
input_names=['x', 'y'], dynamic_axes={'y': [1]})
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
loop_output_value_info_proto = onnx_model.graph.output[0]
ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(loop_output_value_info_proto.name,
1, [2, None])
self.assertEqual(loop_output_value_info_proto, ref_value_info_proto)

# Case 2: no dynamic axes.
f = io.BytesIO()
y = torch.randn(2, 3)
torch.onnx.export(script_model, (x, y), f, opset_version=self.opset_version)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
loop_output_value_info_proto = onnx_model.graph.output[0]
ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(loop_output_value_info_proto.name,
1, [2, 3])
self.assertEqual(loop_output_value_info_proto, ref_value_info_proto)

def test_export_mode(self):
class MyModule(torch.nn.Module):
def forward(self, x):
Expand Down Expand Up @@ -1022,6 +1054,9 @@ def forward(self, x):
return y

x = torch.tensor([1, 2])
# Move import to local as caffe2 backend requires additional build flag,
# and is only used in this test case.
import caffe2.python.onnx.backend as backend
verify(MyModel(), x, backend, do_constant_folding=False)

def test_fuse_conv_bn(self):
Expand Down
13 changes: 10 additions & 3 deletions torch/csrc/jit/serialization/export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,8 @@ void GraphEncoder::EncodeValueInfoType(
std::unordered_map<int64_t, std::string>>& dynamic_axes) {
auto tensorTypeToONNXType = [&dynamic_axes, n, this](
const TensorTypePtr& t,
onnx::TypeProto_Tensor* onnx_tensor_type) {
onnx::TypeProto_Tensor* onnx_tensor_type,
bool assign_dim_param) {
std::string name = n->debugName();
if (t->dim()) {
onnx::TensorShapeProto* shape = onnx_tensor_type->mutable_shape();
Expand All @@ -559,7 +560,7 @@ void GraphEncoder::EncodeValueInfoType(
}
} else if (sizes[i].is_static()) {
shape->mutable_dim(i)->set_dim_value(sizes[i].static_size());
} else {
} else if (assign_dim_param) {
if (symbol_dim_map_.find(sizes[i]) == symbol_dim_map_.end()) {
if (n->node()->kind() == prim::Param) {
symbol_dim_map_[sizes[i]] = name + "_dim_" + std::to_string(i);
Expand All @@ -584,7 +585,13 @@ void GraphEncoder::EncodeValueInfoType(
// Encode type if either shape or dtype exists.
onnx::TypeProto_Tensor* onnx_tensor_type =
onnx_type->mutable_tensor_type();
tensorTypeToONNXType(tensor_type, onnx_tensor_type);
// Do not assign dim_param for sequence tensor type.
// Sequence of tensors could differ in dimension size.
// Use a dimension with neither dim_value nor dim_param set
// to denote an unknown dimension.
// Create and assign dim_param for normal tensor type.
auto is_sequence_tensor = static_cast<bool>(n->type()->cast<ListType>());
tensorTypeToONNXType(tensor_type, onnx_tensor_type, !is_sequence_tensor);
}
} else if (BoolTypePtr bool_type = node_type->cast<BoolType>()) {
onnx::TypeProto_Tensor* onnx_tensor_type = onnx_type->mutable_tensor_type();
Expand Down

0 comments on commit 4612323

Please sign in to comment.