Skip to content

Commit 1d73b64

Browse files
anijain2305pytorchmergebot
authored andcommitted
[fake tensor cache] Support index with non bool/int8 indices (pytorch#151477)
Pull Request resolved: pytorch#151477 Approved by: https://github.com/zou3519, https://github.com/bdhirsh ghstack dependencies: pytorch#151409, pytorch#151633
1 parent 41285f2 commit 1d73b64

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

test/test_fake_tensor.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,6 +2158,31 @@ def test_cache_tuple_outputs(self):
21582158
extract_tensor_metadata(b),
21592159
)
21602160

2161+
2162+
def test_cache_aten_index(self):
2163+
with FakeTensorMode():
2164+
x = torch.randn(4, 4, 4)
2165+
idx_tensor1 = torch.tensor([0, 2, 3])
2166+
idx_tensor2 = torch.tensor([0, 1, 2])
2167+
2168+
FakeTensorMode.cache_clear()
2169+
self.assertHitsMisses(0, 0)
2170+
2171+
ref = torch.ops.aten.index(x, [None, idx_tensor1, idx_tensor2])
2172+
self.assertHitsMisses(0, 3)
2173+
2174+
res = torch.ops.aten.index(x, [None, idx_tensor1, idx_tensor2])
2175+
self.assertHitsMisses(1, 3)
2176+
self.assertEqual(extract_tensor_metadata(ref), extract_tensor_metadata(res))
2177+
2178+
with FakeTensorMode():
2179+
x = torch.randn(4, 4, 4)
2180+
idx_tensor1 = torch.tensor([True, True, False, True])
2181+
self.assertRaises(DynamicOutputShapeException, lambda: torch.ops.aten.index(x, [None, idx_tensor1]))
2182+
2183+
idx_tensor1 = torch.tensor([1, -2, 3, -4], dtype=torch.int8)
2184+
self.assertRaises(DynamicOutputShapeException, lambda: torch.ops.aten.index(x, [None, idx_tensor1]))
2185+
21612186
@skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
21622187
def test_invoke_subgraph(self):
21632188
"""

torch/_subclasses/fake_tensor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,6 +1539,21 @@ def _validate_cache_key(
15391539
raise _BypassDispatchCache("data dependent output")
15401540

15411541
if torch.Tag.dynamic_output_shape in func.tags:
1542+
if func is aten.index.Tensor:
1543+
_, new_kwargs = normalize_function( # type: ignore[misc]
1544+
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type]
1545+
)
1546+
for index in new_kwargs["indices"]:
1547+
# index calls nonzero for bool or int8 tensors, and
1548+
# therefore has a dynamic shape output. For other dtypes,
1549+
# the output shape depends on the input shape (and not data)
1550+
if isinstance(index, torch.Tensor) and index.dtype in (
1551+
torch.bool,
1552+
torch.int8,
1553+
):
1554+
raise _BypassDispatchCache("dynamic output shape")
1555+
return
1556+
15421557
raise _BypassDispatchCache("dynamic output shape")
15431558

15441559
if torch.Tag.inplace_view in func.tags:

0 commit comments

Comments
 (0)