[JIT] Cat shape analysis fix for -1 dim (pytorch#72616)
Summary:
Pull Request resolved: pytorch#72616

Following the logic [here](https://codebrowser.bddppq.com/pytorch/pytorch/aten/src/ATen/WrapDimUtils.h.html#_ZN2atL19legacy_cat_wrap_dimElN3c108ArrayRefINS_6TensorEEE).

The prior version checked whether out_dim was not None; we should be checking whether it is None. Strangely, shape analysis still worked, because the unwrapped negative dim just wrapped around under Python's negative indexing; however, it would lead to errors when the shape functions themselves were executed. In a follow-up I will extend the shape-function tests to actually invoke the shape functions as well, to catch this type of bug.
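
To make the failure concrete, here is a minimal standalone sketch (buggy_legacy_cat_wrap_dim and this maybe_wrap_dim are paraphrases for illustration, not the exact sources):

    from typing import List, Optional

    def maybe_wrap_dim(dim: int, rank: int) -> int:
        # Wrap a negative dim into [0, rank): e.g. dim=-2 with rank=2 -> 0.
        if rank <= 0:
            rank = 1
        assert -rank <= dim < rank
        return dim + rank if dim < 0 else dim

    def buggy_legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
        out_dim: Optional[int] = None
        for size in tensor_sizes:
            # BUG: out_dim starts as None, so this branch never runs and
            # the dim is never wrapped.
            if len(size) != 0 and size != [0] and out_dim is not None:
                out_dim = maybe_wrap_dim(dim, len(size))
        if out_dim is None:
            out_dim = dim  # the raw, possibly negative, dim falls through
        return out_dim

    print(buggy_legacy_cat_wrap_dim(-2, [[15, 1], [5, 1]]))  # -2, never wrapped

Shape analysis happened to tolerate the leftover -2 because Python list indexing wraps negatives, but anything that expects a wrapped (non-negative) dim does not.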

This wasn't caught in the NNC OpInfo tests because NNC was already failing for single-node cat :'(

Test Plan: Imported from OSS

Reviewed By: Krovatkin

Differential Revision: D34117930

Pulled By: eellison

fbshipit-source-id: 2c60430d7144dc828a6a4789e0015b83153f7a32
Elias Ellison authored and facebook-github-bot committed Feb 10, 2022
1 parent 629ac57 · commit 3ee8207
Showing 2 changed files with 18 additions and 12 deletions.
24 changes: 15 additions & 9 deletions test/jit/test_symbolic_shape_analysis.py
@@ -460,21 +460,27 @@ def test_sym_ir_parsing(self):
         self.assertEqual(out, [-2, -3])
 
     def test_stitching_concat(self):
 
         @torch.jit.script
-        def foo(a, b, x, y):
+        def foo1(a, b, x, y):
             return (a / b) + torch.cat([x, y])
 
-        g = foo.graph
-        for inp in foo.graph.inputs():
-            inp.setType(inp.type().with_sizes([None, None]))
+        @torch.jit.script
+        def foo2(a, b, x, y):
+            return (a / b) + torch.cat([x, y], dim=-2)
+
+        for foo in [foo1, foo2]:
+            g = foo.graph
+            for inp in foo.graph.inputs():
+                inp.setType(inp.type().with_sizes([None, None]))
 
-        shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(foo.graph)
-        nodes = [g.findNode("aten::div")] + [g.findNode("aten::add")] + [g.findNode("aten::cat")]
+            shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(foo.graph)
+            nodes = [g.findNode("aten::div")] + [g.findNode("aten::add")] + [g.findNode("aten::cat")]
 
-        inps = [1, 10], [20, 10], [15, 1], [5, 1]
-        output_shapes = [[20, 10], [20, 10], [20, 1]]
+            inps = [1, 10], [20, 10], [15, 1], [5, 1]
+            output_shapes = [[20, 10], [20, 10], [20, 1]]
 
-        self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps)
+            self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps)
 
     @unittest.skipIf(not hasattr(torch.jit, "_shapes"), "shape functions not loaded in python")
     def test_shape_function_includes(self):
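
As an eager-mode sanity check of what the new foo2 case exercises (shapes taken from the test's inps; this snippet is illustrative and not part of the diff):

    import torch

    x, y = torch.zeros(15, 1), torch.zeros(5, 1)
    # On 2-D tensors, dim=-2 wraps to dim=0, so cat stacks along the first axis.
    assert torch.cat([x, y], dim=-2).shape == (20, 1)
    assert torch.equal(torch.cat([x, y], dim=-2), torch.cat([x, y], dim=0))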
6 changes: 3 additions & 3 deletions torch/csrc/jit/runtime/shape_functions.h
@@ -93,12 +93,12 @@ def check_cat_no_zero_dim(tensors: List[List[int]]):
     for tensor in tensors:
         assert len(tensor) > 0
 def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
     out_dim: Optional[int] = None
     for size in tensor_sizes:
-        if len(size) != 0 and size != [0] and out_dim is not None:
-            out_dim = maybe_wrap_dim(dim, len(size))
+        if not (len(size) == 1 and size[0] == 0):
+            if out_dim is None:
+                out_dim = maybe_wrap_dim(dim, len(size))
     if out_dim is None:
         out_dim = dim
     return out_dim
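
Continuing the standalone sketch from the summary above (the paraphrased maybe_wrap_dim; legacy_cat_wrap_dim_fixed is a hypothetical name), the corrected check now wraps the dim against the first non-empty tensor's rank:

    def legacy_cat_wrap_dim_fixed(dim: int, tensor_sizes: List[List[int]]):
        out_dim: Optional[int] = None
        for size in tensor_sizes:
            # Skip legacy empty tensors (shape [0]); wrap against the first
            # real tensor's rank, and only once.
            if not (len(size) == 1 and size[0] == 0):
                if out_dim is None:
                    out_dim = maybe_wrap_dim(dim, len(size))
        if out_dim is None:
            out_dim = dim
        return out_dim

    assert legacy_cat_wrap_dim_fixed(-2, [[15, 1], [5, 1]]) == 0  # -2 wraps to 0
    assert legacy_cat_wrap_dim_fixed(1, [[15, 1], [5, 1]]) == 1   # non-negative dims pass through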
