Skip to content

Commit

Permalink
[inductor] Fix pattern replacements with multiple users (pytorch#129689)
Browse files Browse the repository at this point in the history
Fixes pytorch#129685

After matching a pattern, we currently try to remove all the nodes of that
pattern, which doesn't work if any intermediate node has users outside of the
pattern. In which case we can't delete those particular nodes.

Pull Request resolved: pytorch#129689
Approved by: https://github.com/shunting314
  • Loading branch information
peterbell10 authored and pytorchmergebot committed Jun 28, 2024
1 parent 7854d84 commit b019f38
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
14 changes: 14 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10607,6 +10607,20 @@ def fn(primals_5):
actual = compiled_fn(torch.ones(s0, s1))
self.assertTrue((actual == 1).all())

def test_pattern_matcher_multi_user(self):
# Reproducer for https://github.com/pytorch/pytorch/issues/129685

def forward(float_1, view_1):
logits = float_1 / 64.0
loss = torch.nn.functional.cross_entropy(logits, view_1, ignore_index=5)
logsumexp = logits.logsumexp(dim=-1)
return [loss, logsumexp]

a = torch.randn(512, 4096, requires_grad=True)
b = torch.randint(size=(512,), low=0, high=4095)

self.common(forward, (a, b))


@dataclasses.dataclass
class TestFailure:
Expand Down
2 changes: 1 addition & 1 deletion torch/_inductor/pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def __repr__(self) -> str:

def erase_nodes(self, graph: torch.fx.Graph) -> None:
for n in reversed(self.nodes):
if not n._erased:
if not n._erased and not n.users:
graph.erase_node(n)

def output_nodes(self) -> List[Optional[torch.fx.Node]]:
Expand Down

0 comments on commit b019f38

Please sign in to comment.