From 828e36c75d8e3f782587575e3dc47611faa6432b Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Wed, 9 Feb 2022 19:04:45 -0800
Subject: [PATCH] [JIT] Cat shape analysis fix for -1 dim (#72616)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72616

This follows the logic [here](https://codebrowser.bddppq.com/pytorch/pytorch/aten/src/ATen/WrapDimUtils.h.html#_ZN2atL19legacy_cat_wrap_dimElN3c108ArrayRefINS_6TensorEEE).

The prior version was checking that `out_dim` was not None before wrapping the dim; it should wrap only while `out_dim` is still None. Strangely, shape analysis still worked, because the unwrapped negative index simply wrapped around in list indexing; however, it would lead to errors when actually executing the shape functions. In a follow-up I will extend shape-function testing to invoke the shape functions as well, to catch this type of bug.

This wasn't caught in the NNC opinfo tests because NNC was already failing for a single-node cat :'(

Test Plan: Imported from OSS

Reviewed By: Krovatkin

Differential Revision: D34117930

Pulled By: eellison

fbshipit-source-id: 2c60430d7144dc828a6a4789e0015b83153f7a32
(cherry picked from commit 3ee820753fca0b82c4647aa4fadff5e1e62f2d48)
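For illustration, a minimal standalone sketch of the before/after behavior of the shape function. Note the `maybe_wrap_dim` below is a simplified stand-in for the helper defined alongside these shape functions, and the `_before`/`_after` names exist only for this sketch:

```python
from typing import List, Optional

def maybe_wrap_dim(dim: int, dim_post_expr: int) -> int:
    # Simplified stand-in for the real helper: wraps a negative dim
    # into the range [0, dim_post_expr).
    assert -dim_post_expr <= dim < dim_post_expr
    return dim + dim_post_expr if dim < 0 else dim

def legacy_cat_wrap_dim_before(dim: int, tensor_sizes: List[List[int]]) -> int:
    out_dim: Optional[int] = None
    for size in tensor_sizes:
        # Bug: out_dim starts as None, so `out_dim is not None` is never
        # true and the dim is never wrapped.
        if len(size) != 0 and size != [0] and out_dim is not None:
            out_dim = maybe_wrap_dim(dim, len(size))
    if out_dim is None:
        out_dim = dim
    return out_dim

def legacy_cat_wrap_dim_after(dim: int, tensor_sizes: List[List[int]]) -> int:
    out_dim: Optional[int] = None
    for size in tensor_sizes:
        # Skip 1-d empty tensors ([0]); wrap against the first real input.
        if not (len(size) == 1 and size[0] == 0):
            if out_dim is None:
                out_dim = maybe_wrap_dim(dim, len(size))
    if out_dim is None:
        out_dim = dim
    return out_dim

sizes = [[15, 1], [5, 1]]
print(legacy_cat_wrap_dim_before(-2, sizes))  # -2: never wrapped
print(legacy_cat_wrap_dim_after(-2, sizes))   # 0: wrapped against rank 2
```

Shape analysis happened to tolerate the unwrapped `-2` because negative list indexing wraps around in Python; invoking the shape function directly is where it breaks.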
---
 test/jit/test_symbolic_shape_analysis.py | 24 +++++++++++++++---------
 torch/csrc/jit/runtime/shape_functions.h |  6 +++---
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py
index 4aae489938960..cd25caa92b2bb 100644
--- a/test/jit/test_symbolic_shape_analysis.py
+++ b/test/jit/test_symbolic_shape_analysis.py
@@ -460,21 +460,27 @@ def test_sym_ir_parsing(self):
         self.assertEqual(out, [-2, -3])
 
     def test_stitching_concat(self):
+
         @torch.jit.script
-        def foo(a, b, x, y):
+        def foo1(a, b, x, y):
             return (a / b) + torch.cat([x, y])
 
-        g = foo.graph
-        for inp in foo.graph.inputs():
-            inp.setType(inp.type().with_sizes([None, None]))
+        @torch.jit.script
+        def foo2(a, b, x, y):
+            return (a / b) + torch.cat([x, y], dim=-2)
+
+        for foo in [foo1, foo2]:
+            g = foo.graph
+            for inp in foo.graph.inputs():
+                inp.setType(inp.type().with_sizes([None, None]))
 
-        shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(foo.graph)
-        nodes = [g.findNode("aten::div")] + [g.findNode("aten::add")] + [g.findNode("aten::cat")]
+            shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(foo.graph)
+            nodes = [g.findNode("aten::div")] + [g.findNode("aten::add")] + [g.findNode("aten::cat")]
 
-        inps = [1, 10], [20, 10], [15, 1], [5, 1]
-        output_shapes = [[20, 10], [20, 10], [20, 1]]
+            inps = [1, 10], [20, 10], [15, 1], [5, 1]
+            output_shapes = [[20, 10], [20, 10], [20, 1]]
 
-        self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps)
+            self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps)
 
     @unittest.skipIf(not hasattr(torch.jit, "_shapes"), "shape functions not loaded in python")
     def test_shape_function_includes(self):
diff --git a/torch/csrc/jit/runtime/shape_functions.h b/torch/csrc/jit/runtime/shape_functions.h
index 9cd2e8f5e5507..7de7dadb5f662 100644
--- a/torch/csrc/jit/runtime/shape_functions.h
+++ b/torch/csrc/jit/runtime/shape_functions.h
@@ -93,12 +93,12 @@ def check_cat_no_zero_dim(tensors: List[List[int]]):
   for tensor in tensors:
     assert len(tensor) > 0
 
-
 def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
   out_dim: Optional[int] = None
   for size in tensor_sizes:
-    if len(size) != 0 and size != [0] and out_dim is not None:
-      out_dim = maybe_wrap_dim(dim, len(size))
+    if not (len(size) == 1 and size[0] == 0):
+      if out_dim is None:
+        out_dim = maybe_wrap_dim(dim, len(size))
   if out_dim is None:
     out_dim = dim
   return out_dim
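As a quick eager-mode sanity check of the new `foo2` test case (assuming any recent PyTorch build), concatenating the test's input shapes along `dim=-2` yields the `[20, 1]` output shape the test expects:

```python
import torch

# Shapes taken from the test: x and y are stitched along dim=-2,
# which is dim 0 for these 2-D tensors.
x, y = torch.rand(15, 1), torch.rand(5, 1)
print(torch.cat([x, y], dim=-2).shape)  # torch.Size([20, 1])
```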