From c7c723897658eda6298bb74d92e4bb18ab4a5fe3 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Fri, 10 Feb 2023 08:02:32 +0000 Subject: [PATCH] Fix bug in unsqueeze_nested stride calculation (#88688) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88688 Approved by: https://github.com/cpuhrsch --- .../ATen/native/nested/NestedTensorMath.cpp | 2 +- test/test_nestedtensor.py | 19 +++++++++++++++++-- tools/autograd/gen_python_functions.py | 1 - torch/overrides.py | 1 + 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp index 71082f66d71b71..afa00a8e363a19 100644 --- a/aten/src/ATen/native/nested/NestedTensorMath.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp @@ -748,7 +748,7 @@ Tensor unsqueeze_nested(const Tensor& self, int64_t dim) { if (wrapped_dim == ndim) { new_stride = stridemat.new_ones({stridemat.size(0), 1}); } else { - new_stride = (stridemat.select(1, mat_dim - 1) * sizemat.select(1, mat_dim - 1)).unsqueeze(-1); + new_stride = (stridemat.select(1, mat_dim) * sizemat.select(1, mat_dim)).unsqueeze(-1); } Tensor stridemat_unsqueezed = at::cat({stridemat.slice(1, 0, mat_dim), new_stride, diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 72a3f4448b8db8..9ef4d0d4cef5e5 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -1641,14 +1641,29 @@ def test_squeeze_unsqueeze(self, device, dtype): self.assertEqual(nt, nt2) # test cases that should work - for i in range(-2, 3): + nt_sizes = nt._nested_tensor_size() + nt_strides = nt._nested_tensor_strides() + for i in range(-2, 4): if (i == 0): + # cannot unsqueeze batch dim continue nt_unsqueezed = nt.unsqueeze(i) - size_idx = i if i < 0 else i - 1 + # negative dim will correspond to unsqueeze() applied at dim = dim + nt.dim() + 1 + wrapped_i = i + nt.dim() + 1 if i < 0 else i + # col_index into nt size tensor is requires subtraction of 1 to ignore batch dim + size_idx = wrapped_i - 1 self.assertEqual(nt_unsqueezed._nested_tensor_size()[:, size_idx], torch.ones(2, dtype=torch.long)) + unsqueezed_stride = nt_unsqueezed._nested_tensor_strides()[:, size_idx] + if (i == nt.ndim or i == -1): + self.assertEqual(unsqueezed_stride, torch.ones(2, dtype=torch.long)) + else: + stride_col_after = nt_strides[:, size_idx] + size_col_after = nt_sizes[:, size_idx] + self.assertEqual(unsqueezed_stride, stride_col_after * size_col_after) nt_squeezed = nt_unsqueezed.squeeze(i) self.assertEqual(nt_squeezed, nt) + self.assertEqual(nt_squeezed._nested_tensor_size(), nt_sizes) + self.assertEqual(nt_squeezed._nested_tensor_strides(), nt_strides) @dtypes(torch.float, torch.float16, torch.double) def test_transpose_inference_mode_interaction(self, device, dtype): diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 06cb7f0d2d50b5..bb3d397402d9ea 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -154,7 +154,6 @@ "fill.Scalar", # only used by the functionalization pass "lift.*", "normal_functional", # only used by the functionalization pas - "_nested_tensor_strides", # don't want to expose this to python "_nested_tensor_offsets", # don't want to expose this to python "_nested_view_from_buffer", # View only version of _nested_from_buffer. This will force users to only use the "safe" version. "_nested_view_from_buffer_copy", diff --git a/torch/overrides.py b/torch/overrides.py index d39fd9ec9b3f48..f84d89e662d181 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -1296,6 +1296,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: Tensor.ndimension: lambda self: -1, Tensor.nelement: lambda self: -1, Tensor._nested_tensor_size: lambda self: -1, + Tensor._nested_tensor_strides: lambda self: -1, Tensor.normal_: lambda self: -1, Tensor.numpy: lambda self: -1, Tensor.permute: lambda self, dim: -1,