Skip to content

Commit

Permalink
[ONNX] Fix onnx gather shape inference
Browse files Browse the repository at this point in the history
Previous code sets `only_rank_available=true` for Gather, resulting in overriding actual inferred shape values with symbols.

Fixes pytorch#68003

Pull Request resolved: pytorch#73607
Approved by: https://github.com/fatcat-z, https://github.com/garymm
  • Loading branch information
BowenBao authored and pytorchmergebot committed Mar 8, 2022
1 parent 1fbc08c commit a482fd7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
17 changes: 17 additions & 0 deletions test/onnx/test_pytorch_onnx_shape_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,23 @@ def test_constant_of_shape_dynamic(self):
constant_of_shape = g.op("ConstantOfShape", shape, value_t=torch.tensor([2.0]))
self.run_test(g, constant_of_shape.node(), expect_tensor("Float", shape=(None, None, None, None)))

def test_gather_dynamic_index(self):
g = self.create_empty_graph()
input = g.addInput()
input.setType(input.type().with_dtype(torch.float).with_sizes([None, 3, 16, 16]))
indices = g.addInput()
indices.setType(indices.type().with_dtype(torch.int64).with_sizes([None]))
output = g.op("Gather", input, indices, axis_i=1)
self.run_test(g, output.node(), expect_tensor("Float", shape=([None, None, 16, 16])))

def test_gather_scalar_index(self):
g = self.create_empty_graph()
input = g.addInput()
input.setType(input.type().with_dtype(torch.float).with_sizes([None, 3, 16, 16]))
indices = self.insert_tensor_constant(g, torch.tensor(1))
output = g.op("Gather", input, indices, axis_i=1)
self.run_test(g, output.node(), expect_tensor("Float", shape=([None, 16, 16])))

def test_reshape(self):
g = self.create_empty_graph()
constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 5))
Expand Down
12 changes: 3 additions & 9 deletions torch/csrc/jit/passes/onnx/shape_type_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1275,17 +1275,11 @@ void ComputeConstant(Node* n, int opset_version) {
break;
}
case ::c10::onnx::Gather: {
if (ConstantValueMap::HasRank(n->input(0)->debugName()) &&
ConstantValueMap::HasRank(n->input(1)->debugName())) {
auto rank_0 =
ConstantValueMap::GetRank(n->input(0)->debugName()).value();
auto rank_1 =
ConstantValueMap::GetRank(n->input(1)->debugName()).value();
only_rank_available = true;
rank = rank_0 + rank_1 - 1;
}
if (ConstantValueMap::HasShapeValue(n->input(0)->debugName()) &&
ConstantValueMap::HasValue(n->input(1)->debugName())) {
// Special case for pattern Shape -> Gather, to propagate shape value.
// Gather input 0 is 1d tensor, Gather input 1 is scalar.
// Gather output will be scalar.
auto shape_value =
ConstantValueMap::GetShapeValue(n->input(0)->debugName()).value();
auto idx_value =
Expand Down

0 comments on commit a482fd7

Please sign in to comment.