
Commit 5b69b87

Revert "Symintify getitem and add the required helper functions (pytorch#86207)"
This reverts commit fd5085c. Reverted pytorch#86207 on behalf of https://github.com/seemethere because it fails internal tests; see https://www.internalfb.com/intern/sandcastle/job/22517998926071860/insights
1 parent 75df4b5 commit 5b69b87

9 files changed: 83 additions, 157 deletions

aten/src/ATen/TensorIndexing.h

Lines changed: 15 additions & 16 deletions
@@ -4,7 +4,6 @@
 #include <ATen/Functions.h>
 #include <ATen/ScalarOps.h>
 #include <ATen/core/TensorBody.h>
-#include <c10/core/SymInt.h>
 #include <c10/util/Optional.h>
 #include <c10/util/irange.h>

@@ -212,7 +211,7 @@ static inline Tensor applySlice(
     int64_t step,
     bool disable_slice_optimization,
     const at::Device& self_device,
-    const c10::optional<SymIntArrayRef>& self_sizes) {
+    const c10::optional<IntArrayRef>& self_sizes) {
   // TODO: implement negative step
   TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");

@@ -221,10 +220,10 @@ static inline Tensor applySlice(
   // Skip this optimization if we are tracing, as the trace may be polymorphic
   // over the shape of the `self` tensor, and we still want to record
   // the slice.
-  SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
+  int64_t length = (self_device == at::kCPU || self_device == at::kCUDA)
       ? (*self_sizes)[dim]
       : self.size(dim);
-  if (!disable_slice_optimization && start == 0 && length == stop &&
+  if (!disable_slice_optimization && start == 0 && stop == length &&
       step == 1) {
     return self;
   }
@@ -238,17 +237,17 @@ static inline Tensor applySelect(
     int64_t index,
     int64_t real_dim,
     const at::Device& /*self_device*/,
-    const c10::optional<SymIntArrayRef>& self_sizes) {
+    const c10::optional<IntArrayRef>& self_sizes) {
   // See NOTE [nested tensor size for indexing]
   if (self_sizes.has_value()) {
     TORCH_CHECK_INDEX(
         !(index == 0 && dim == 0 && self_sizes->size() == 0),
         "invalid index of a 0-dim tensor. ",
         "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");

-    auto size = (*self_sizes)[dim];
+    int64_t size = (*self_sizes)[dim];
     TORCH_CHECK_INDEX(
-        size >= -index && size > index,
+        index >= -size && index < size,
         "index ",
         index,
         " is out of bounds for dimension ",
@@ -425,7 +424,7 @@ static inline Tensor handleDimInMultiDimIndexing(
     std::vector<Tensor>& outIndices,
     bool disable_slice_optimization,
     const at::Device& original_tensor_device,
-    const c10::optional<SymIntArrayRef>& prev_dim_result_sizes) {
+    const c10::optional<IntArrayRef>& prev_dim_result_sizes) {
   if (index.is_integer()) {
     return impl::applySelect(
         prev_dim_result,
@@ -509,7 +508,7 @@ static inline Tensor applySlicing(
     std::vector<Tensor>& outIndices,
     bool disable_slice_optimization,
     const at::Device& self_device,
-    const c10::optional<SymIntArrayRef>& self_sizes) {
+    const c10::optional<IntArrayRef>& self_sizes) {
   int64_t dim = 0;
   int64_t specified_dims = impl::count_specified_dimensions(indices);

@@ -525,9 +524,9 @@ static inline Tensor applySlicing(
   for (const auto i : c10::irange(indices.size())) {
     auto& obj = indices[i];
     // See NOTE [nested tensor size for indexing]
-    c10::optional<SymIntArrayRef> result_sizes = result.is_nested()
-        ? c10::optional<SymIntArrayRef>(c10::nullopt)
-        : c10::optional<SymIntArrayRef>(result.sym_sizes());
+    c10::optional<IntArrayRef> result_sizes = result.is_nested()
+        ? c10::optional<IntArrayRef>(c10::nullopt)
+        : c10::optional<IntArrayRef>(result.sizes());
     result = handleDimInMultiDimIndexing(
         /*prev_dim_result=*/result,
         /*original_tensor=*/self,
@@ -601,9 +600,9 @@ static inline Tensor get_item(
   // nested tensor does not have a size (yet) so for now we represent its size
   // as null may need to be changed after we reach a better solution for nested
   // tensor size
-  c10::optional<SymIntArrayRef> self_sizes = self.is_nested()
-      ? c10::optional<SymIntArrayRef>(c10::nullopt)
-      : c10::optional<SymIntArrayRef>(self.sym_sizes());
+  c10::optional<IntArrayRef> self_sizes = self.is_nested()
+      ? c10::optional<IntArrayRef>(c10::nullopt)
+      : c10::optional<IntArrayRef>(self.sizes());

   // handle simple types: integers, slices, none, ellipsis, bool
   if (indices.size() == 1) {
@@ -664,7 +663,7 @@ static inline void set_item(
     const Tensor& value,
     bool disable_slice_optimization = false) {
   at::Device self_device = self.device();
-  SymIntArrayRef self_sizes = self.sym_sizes();
+  IntArrayRef self_sizes = self.sizes();

   // handle simple types: integers, slices, ellipsis, bool
   if (indices.size() == 1) {
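
Beyond swapping SymIntArrayRef back to IntArrayRef, the bounds check in applySelect is now written directly in terms of the index: an index is valid iff index >= -size && index < size, with negative indices wrapping around. A minimal Python sketch of that rule follows; check_select_index is a hypothetical helper name used only for illustration.

import torch

def check_select_index(t, dim, index):
    # Valid iff -size <= index < size; negative indices wrap, as in the C++ check.
    size = t.size(dim)
    if not (index >= -size and index < size):
        raise IndexError(
            f"index {index} is out of bounds for dimension {dim} with size {size}")
    return index + size if index < 0 else index

x = torch.zeros(3, 4)
assert check_select_index(x, 0, -1) == 2   # wraps to the last row
try:
    check_select_index(x, 0, 3)            # size is 3, so 3 is out of bounds
except IndexError as e:
    print(e)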

aten/src/ATen/native/TensorShape.cpp

Lines changed: 13 additions & 23 deletions
@@ -1512,49 +1512,39 @@ QuantizerPtr create_subtensor_quantizer(const Tensor& self, bool is_select, int6
   return quantizer;
 }

-Tensor select(const Tensor& self, int64_t dim, int64_t index_) {
+Tensor select(const Tensor& self, int64_t dim, int64_t index) {
   int64_t ndim = self.dim();
   if (ndim == 0) {
     TORCH_CHECK_INDEX(false, "select() cannot be applied to a 0-dim tensor.");
   }
   dim = maybe_wrap_dim(dim, ndim);
-  auto size = self.sym_sizes()[dim];
-  if (size < -index_ || size <= index_) {
+  auto size = self.size(dim);
+  if (index < -size || index >= size) {
     if (self.has_names() && self.names()[dim] != Dimname::wildcard()) {
-      TORCH_CHECK_INDEX(false, "select(): index ", index_, " out of range for tensor of size ",
+      TORCH_CHECK_INDEX(false, "select(): index ", index, " out of range for tensor of size ",
                         self.sizes(), " at dimension ", self.names()[dim]);
     }
-    TORCH_CHECK_INDEX(false, "select(): index ", index_, " out of range for tensor of size ",
+    TORCH_CHECK_INDEX(false, "select(): index ", index, " out of range for tensor of size ",
                       self.sizes(), " at dimension ", dim);
   }
-  SymInt index = index_;
   if (index < 0) {
     index += size;
   }
   if (self.is_sparse()) {
-    return select_sparse(self, dim, index.guard_int(__FILE__, __LINE__));
+    return select_sparse(self, dim, index);
   }
+  DimVector sizes(self.sizes().begin(), self.sizes().end());
+  DimVector strides(self.strides().begin(), self.strides().end());
+  auto storage_offset = self.storage_offset() + index * strides[dim];
+  sizes.erase(sizes.begin() + dim);
+  strides.erase(strides.begin() + dim);

   Tensor result;
   if (self.is_quantized()) {
-    auto local_index = index.guard_int(__FILE__, __LINE__);
-
-    DimVector sizes(self.sizes().begin(), self.sizes().end());
-    DimVector strides(self.strides().begin(), self.strides().end());
-    auto storage_offset = self.storage_offset() + local_index * strides[dim];
-    sizes.erase(sizes.begin() + dim);
-    strides.erase(strides.begin() + dim);
-
-    auto quantizer = create_subtensor_quantizer(self, true, local_index, local_index + 1, dim, 1);
+    auto quantizer = create_subtensor_quantizer(self, true, index, index + 1, dim, 1);
     result = as_strided_qtensorimpl(self, sizes, strides, storage_offset, quantizer);
   } else {
-    std::vector<c10::SymInt> sizes(self.sym_sizes().begin(), self.sym_sizes().end());
-    std::vector<c10::SymInt> strides(self.sym_strides().begin(), self.sym_strides().end());
-    auto storage_offset = self.sym_storage_offset() + index * strides[dim];
-    sizes.erase(sizes.begin() + dim);
-    strides.erase(strides.begin() + dim);
-
-    result = self.as_strided_symint(sizes, strides, storage_offset);
+    result = self.as_strided(sizes, strides, storage_offset);
   }
   namedinference::propagate_names_except(result, self, {dim});
   return result;
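
After the revert, the dense select path computes the view's sizes, strides, and storage offset with plain int64_t values and hands them to as_strided. A rough Python equivalent of that dense, non-quantized path is sketched below; select_via_as_strided is a made-up name for illustration, not part of the PyTorch API.

import torch

def select_via_as_strided(t, dim, index):
    # Mirror of the eager dense path: drop one dimension and offset the storage.
    size = t.size(dim)
    if index < -size or index >= size:
        raise IndexError(f"select(): index {index} out of range for size {size}")
    if index < 0:
        index += size
    sizes = list(t.size())
    strides = list(t.stride())
    storage_offset = t.storage_offset() + index * strides[dim]
    del sizes[dim]
    del strides[dim]
    return torch.as_strided(t, sizes, strides, storage_offset)

x = torch.arange(12.).reshape(3, 4)
assert torch.equal(select_via_as_strided(x, 0, 1), x.select(0, 1))
assert torch.equal(select_via_as_strided(x, 1, -1), x.select(1, -1))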

functorch/test/test_aotdispatch.py

Lines changed: 31 additions & 28 deletions
@@ -37,7 +37,6 @@
     skip,
     skipOps,
 )
-from torch._subclasses.fake_tensor import DynamicOutputShapeException

 USE_TORCHVISION = False
 try:
@@ -725,6 +724,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
 }

 symbolic_aot_autograd_failures = {
+    xfail('__getitem__', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('__rmatmul__', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('addbmm', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('addcdiv', ''),  # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition
@@ -790,6 +790,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('hsplit', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
     xfail('index_copy', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail('index_fill', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('index_put', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('index_select', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('inner', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -986,7 +987,11 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('scatter_reduce', 'sum'),  # aten.scatter_reduce.two - couldn't find symbolic meta function/decomp...
     xfail('segment_reduce', 'lengths'),  # aten.segment_reduce.default - couldn't find symbolic meta functio...
     xfail('segment_reduce', 'offsets'),  # aten.segment_reduce.default - couldn't find symbolic meta functio...
+    xfail('select', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail('select_scatter', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('sgn', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail('slice', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail('slice_scatter', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('sort', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('special.entr', ''),  # aten.special_entr.default - couldn't find symbolic meta function/decomposition
     xfail('special.erfcx', ''),  # aten.special_erfcx.default - couldn't find symbolic meta function/decompos...
@@ -999,6 +1004,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('split', 'list_args'),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('split_with_sizes', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('squeeze', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail('stack', ''),  # aten.select.int - couldn't find symbolic meta function/decomposition
     xfail('std', ''),  # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('std_mean', ''),  # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('stft', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1067,33 +1073,30 @@ def get_grads(args):

     compiled_f = compiled_function(f, nop, nop)

-    try:
-        reset_grads()
-        call_forwards_backwards(compiled_f)
-        compiled_grad = get_grads(args)
-
-        reset_grads()
-        call_forwards_backwards(f)
-        orig_grad = get_grads(args)
-        self.assertEqual(orig_grad, compiled_grad)
-
-        def create_new_arg(x):
-            if isinstance(x, torch.Tensor) and x.dtype == torch.float32:
-                return x.detach().uniform_(0, 1).requires_grad_(x.requires_grad)
-            return x
-
-        args = pytree.tree_map(create_new_arg, args)
-
-        reset_grads()
-        call_forwards_backwards(compiled_f)
-        compiled_grad = get_grads(args)
-
-        reset_grads()
-        call_forwards_backwards(f)
-        orig_grad = get_grads(args)
-        self.assertEqual(orig_grad, compiled_grad)
-    except DynamicOutputShapeException:
-        self.skipTest("Dynamic output shape operation in trace")
+    reset_grads()
+    call_forwards_backwards(compiled_f)
+    compiled_grad = get_grads(args)
+
+    reset_grads()
+    call_forwards_backwards(f)
+    orig_grad = get_grads(args)
+    self.assertEqual(orig_grad, compiled_grad)
+
+    def create_new_arg(x):
+        if isinstance(x, torch.Tensor) and x.dtype == torch.float32:
+            return x.detach().uniform_(0, 1).requires_grad_(x.requires_grad)
+        return x
+
+    args = pytree.tree_map(create_new_arg, args)
+
+    reset_grads()
+    call_forwards_backwards(compiled_f)
+    compiled_grad = get_grads(args)
+
+    reset_grads()
+    call_forwards_backwards(f)
+    orig_grad = get_grads(args)
+    self.assertEqual(orig_grad, compiled_grad)

 class TestEagerFusionOpInfo(AOTTestCase):
     @ops(op_db, allowed_dtypes=(torch.float,))

test/test_autograd.py

Lines changed: 2 additions & 23 deletions
@@ -5269,16 +5269,13 @@ def test_grad_fn_attr_bindings(self):
         self.assertEqual(out.grad_fn._saved_indices, (None, indices))  # c10::List<c10::optional<Tensor>> -> Tuple[Tensor?]
         self.assertIsInstance(out.grad_fn._saved_indices[1], torch.Tensor)
         self.assertIsInstance(out.grad_fn._raw_saved_indices[1], torch._C._autograd.SavedTensor)
-        self.assertEqual(out.grad_fn._saved_self_sym_sizes, a.shape)  # SymIntArrayRef -> Tuple[SymInt]
-        self.assertIsInstance(out.grad_fn._saved_self_sym_sizes[0], int)
+        self.assertEqual(out.grad_fn._saved_self_sizes, a.shape)  # IntArrayRef -> Tuple[int]
+        self.assertIsInstance(out.grad_fn._saved_self_sizes[0], int)

         out.grad_fn._raw_saved_indices[1].register_hooks(lambda x: x, lambda x: x)
         with self.assertRaisesRegex(RuntimeError, "None is forbidden"):
             out.grad_fn._raw_saved_indices[0].register_hooks(lambda x: x, lambda x: x)

-        out = a.mean()
-        self.assertEqual(out.grad_fn._saved_self_sizes, a.shape)  # IntArrayRef -> Tuple[int]
-
         a = torch.ones(2, 2, requires_grad=True)
         out = a * a
         out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
@@ -5297,24 +5294,6 @@ def test_grad_fn_attr_bindings(self):
         else:
             self.assertIsNone(out.grad_fn._saved_scales)  # c10::optional<ArrayRef<double>> -> float[]?

-        a = torch.ones(1, 1, 3, 3, requires_grad=True)
-        out = nn.Conv2d(1, 1, 3)(a)
-        self.assertEqual(out.grad_fn._saved_bias_sym_sizes_opt, (1,))  # c10::optional<SymIntArrayRef> -> SymInt[]?
-        out = nn.Conv2d(1, 1, 3, bias=False)(a)
-        # TODO: This is BAD! we converted a c10::nullopt into a (0,)
-        self.assertEqual(out.grad_fn._saved_bias_sym_sizes_opt, (0,))
-
-        a = torch.ones(1, 3, 3, requires_grad=True)
-        out = torch.addbmm(a.squeeze(0), a, a)
-        self.assertEqual(out.grad_fn._saved_batch1_argsize_0, 1)  # int64_t
-        self.assertEqual(out.grad_fn._saved_batch1_argsize_1, 3)  # int64_t
-
-        a = torch.ones(1, 1, 3, 3, requires_grad=True)
-        out = torch.nn.functional.unfold(a, 3)
-        self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_2, 3)  # SymInt
-        self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_1, 3)  # SymInt
-
         a = torch.ones(1, 1, 2, requires_grad=True)
         out = torch.nn.functional.interpolate(a, scale_factor=0.5, mode="linear")
         self.assertIsNone(out.grad_fn._saved_output_size)
         self.assertEqual(out.grad_fn._saved_scale_factors, (0.5,))
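
For context, the attributes exercised here are the autogenerated _saved_* bindings on grad_fn nodes. A short sketch of how they are inspected in practice; the exact attribute names are generated per operator, so the ones below are only assumed to exist for this multiplication example.

import torch

a = torch.ones(2, 2, requires_grad=True)
out = a * a
print(type(out.grad_fn).__name__)       # MulBackward0
# Tensors and sizes stashed for backward are exposed as _saved_* attributes.
print(out.grad_fn._saved_self.shape)    # torch.Size([2, 2])
print(out.grad_fn._saved_other.shape)   # torch.Size([2, 2])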

test/test_proxy_tensor.py

Lines changed: 5 additions & 0 deletions
@@ -1032,6 +1032,7 @@ def f(a, b, c, d, e):
     xfail('linalg.eig'),
     xfail('linalg.eigvals'),
     skip('masked.logsumexp', ''),  # Tensors of type TensorImpl do not have numel
+    xfail('__getitem__', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('masked.amax', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('masked.amin', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('masked.argmax', ''),  # aten.argmax.default - couldn't find symbolic meta function/decomposition
@@ -1108,6 +1109,7 @@ def f(a, b, c, d, e):
     xfail('hsplit', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
     xfail('index_copy', ''),  # Expected a long tensor for index, but got Float
+    xfail('index_fill', ''),  # aten.index_fill.int_Scalar - couldn't find symbolic meta function/decomposition
     xfail('index_reduce', ''),  # Float
     xfail('inner', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('isclose', ''),  # The underlying op of 'aten.stride' has no overload name '_schema'
@@ -1257,6 +1259,9 @@ def f(a, b, c, d, e):
     xfail('scatter_reduce', 'sum'),  # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition
     xfail('searchsorted', ''),  # Could not run 'aten::searchsorted.Tensor' with arguments from the 'Meta' backend. ...
     xfail('segment_reduce', 'offsets'),  # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition
+    xfail('select', ''),  # aten.select.int - couldn't find symbolic meta function/decomposition
+    xfail('select_scatter', ''),  # aten.select_scatter.default - couldn't find symbolic meta function/decomposition
+    xfail('slice_scatter', ''),  # aten.slice_scatter.default - couldn't find symbolic meta function/decomposition
     xfail('sort', ''),  # aten.sort.default - couldn't find symbolic meta function/decomposition
     xfail('special.airy_ai', ''),  # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition
     xfail('special.bessel_y0', ''),  # aten.special_bessel_y0.default - couldn't find symbolic meta function/decomposition

tools/autograd/derivatives.yaml

Lines changed: 2 additions & 2 deletions
@@ -781,7 +781,7 @@
   other: -grad * exp((self - 1) * log(other) - other - lgamma(self))

 - name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
-  self: index_backward(grad.new_zeros_symint(self.sym_sizes(), self.options()), indices, grad)
+  self: index_backward(grad.new_zeros(self.sizes(), self.options()), indices, grad)
   result: auto_linear

 - name: index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor
@@ -1388,7 +1388,7 @@
 - name: select.int(Tensor(a) self, int dim, int index) -> Tensor(a)
   dispatch:
     Default:
-      self: select_backward_symint(grad, self.sym_sizes(), dim, index)
+      self: select_backward(grad, self.sizes(), dim, index)
       result: auto_linear
     AutogradNestedTensor:
       self: _nested_select_backward(grad, self, dim, index)
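
The index.Tensor formula again builds its zero-filled buffer with grad.new_zeros(self.sizes(), self.options()) and lets index_backward scatter the incoming gradient into it. Semantically this is a scatter-add of grad at the indexed positions, which the sketch below checks from Python; it illustrates what the formula computes, not the autograd kernel itself.

import torch

self_ = torch.randn(4, 5, requires_grad=True)
idx = torch.tensor([0, 2, 2])            # a duplicate index, so gradients accumulate
out = self_[idx]
grad_out = torch.ones_like(out)
out.backward(grad_out)

# Reference: zeros shaped like self, with grad scatter-added at the indices.
expected = torch.zeros(self_.shape).index_put_((idx,), grad_out, accumulate=True)
assert torch.allclose(self_.grad, expected)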

tools/autograd/load_derivatives.py

Lines changed: 2 additions & 14 deletions
@@ -814,24 +814,12 @@ def stride_expr(name: str) -> str:
     ),
     # replace self.size(2) with self_size_2
     (
-        r"{}.size\((-?\w+)\)",
+        r"{}.size\((\w+)\)",
         {
-            "suffix": lambda m: "_argsize_{}".format(
-                m.groups()[0].replace("-", "minus_")
-            ),
+            "suffix": lambda m: "_argsize_{}".format(*m.groups()),
             "nctype": lambda name: NamedCType(name, BaseCType(longT)),
         },
     ),
-    # replace self.sym_size(2) with self_sym_size_2
-    (
-        r"{}.sym_size\((-?\w+)\)",
-        {
-            "suffix": lambda m: "_sym_argsize_{}".format(
-                m.groups()[0].replace("-", "minus_")
-            ),
-            "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
-        },
-    ),
     # replace self.numel() with self_numel
     (
         r"{}.numel\(\)",
