Skip to content

Commit

Permalink
Fix bug in unsqueeze_nested stride calculation (pytorch#88688)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#88688
Approved by: https://github.com/cpuhrsch
  • Loading branch information
mikaylagawarecki authored and pytorchmergebot committed Feb 10, 2023
1 parent 889a464 commit c7c7238
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 4 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/nested/NestedTensorMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ Tensor unsqueeze_nested(const Tensor& self, int64_t dim) {
if (wrapped_dim == ndim) {
new_stride = stridemat.new_ones({stridemat.size(0), 1});
} else {
new_stride = (stridemat.select(1, mat_dim - 1) * sizemat.select(1, mat_dim - 1)).unsqueeze(-1);
new_stride = (stridemat.select(1, mat_dim) * sizemat.select(1, mat_dim)).unsqueeze(-1);
}
Tensor stridemat_unsqueezed = at::cat({stridemat.slice(1, 0, mat_dim),
new_stride,
Expand Down
19 changes: 17 additions & 2 deletions test/test_nestedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,14 +1641,29 @@ def test_squeeze_unsqueeze(self, device, dtype):
self.assertEqual(nt, nt2)

# test cases that should work
for i in range(-2, 3):
nt_sizes = nt._nested_tensor_size()
nt_strides = nt._nested_tensor_strides()
for i in range(-2, 4):
if (i == 0):
# cannot unsqueeze batch dim
continue
nt_unsqueezed = nt.unsqueeze(i)
size_idx = i if i < 0 else i - 1
# negative dim will correspond to unsqueeze() applied at dim = dim + nt.dim() + 1
wrapped_i = i + nt.dim() + 1 if i < 0 else i
# col_index into nt size tensor requires subtraction of 1 to ignore batch dim
size_idx = wrapped_i - 1
self.assertEqual(nt_unsqueezed._nested_tensor_size()[:, size_idx], torch.ones(2, dtype=torch.long))
unsqueezed_stride = nt_unsqueezed._nested_tensor_strides()[:, size_idx]
if (i == nt.ndim or i == -1):
self.assertEqual(unsqueezed_stride, torch.ones(2, dtype=torch.long))
else:
stride_col_after = nt_strides[:, size_idx]
size_col_after = nt_sizes[:, size_idx]
self.assertEqual(unsqueezed_stride, stride_col_after * size_col_after)
nt_squeezed = nt_unsqueezed.squeeze(i)
self.assertEqual(nt_squeezed, nt)
self.assertEqual(nt_squeezed._nested_tensor_size(), nt_sizes)
self.assertEqual(nt_squeezed._nested_tensor_strides(), nt_strides)

@dtypes(torch.float, torch.float16, torch.double)
def test_transpose_inference_mode_interaction(self, device, dtype):
Expand Down
1 change: 0 additions & 1 deletion tools/autograd/gen_python_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@
"fill.Scalar", # only used by the functionalization pass
"lift.*",
"normal_functional", # only used by the functionalization pass
"_nested_tensor_strides", # don't want to expose this to python
"_nested_tensor_offsets", # don't want to expose this to python
"_nested_view_from_buffer", # View only version of _nested_from_buffer. This will force users to only use the "safe" version.
"_nested_view_from_buffer_copy",
Expand Down
1 change: 1 addition & 0 deletions torch/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.ndimension: lambda self: -1,
Tensor.nelement: lambda self: -1,
Tensor._nested_tensor_size: lambda self: -1,
Tensor._nested_tensor_strides: lambda self: -1,
Tensor.normal_: lambda self: -1,
Tensor.numpy: lambda self: -1,
Tensor.permute: lambda self, dim: -1,
Expand Down

0 comments on commit c7c7238

Please sign in to comment.