
Commit 8c09a00

jansel authored and pytorchmergebot committed
[inductor] Pattern matching engine (copy) (pytorch#93291)
This is an exact duplicate of pytorch#90739. The fbcode workflow for landing that diff seems buggy; the github-export-checks task is failing with credentials errors. Plan to try to land it using GH1.

Pull Request resolved: pytorch#93291
Approved by: https://github.com/desertfire
1 parent aee5f84 commit 8c09a00
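
The engine itself lives in the new torch._inductor.pattern_matcher module (imported in the compile_fx.py diff below, but not reproduced in this excerpt). As a rough orientation for what a match-and-replace fx pass does, here is a minimal sketch using the stock torch.fx.subgraph_rewriter API; this illustrates the general idea only and is not the new engine's API:

    import torch
    from torch.fx import symbolic_trace, subgraph_rewriter

    def pattern(a, b, c, d):
        return torch.add(torch.mm(a, b), torch.mm(c, d))

    def replacement(a, b, c, d):
        # stand-in for a fused mm_plus_mm kernel choice
        return torch.addmm(torch.mm(c, d), a, b)

    def fn(a, b, c, d):
        return torch.add(torch.mm(a, b), torch.mm(c, d))

    gm = symbolic_trace(fn)
    subgraph_rewriter.replace_pattern(gm, pattern, replacement)
    print(gm.code)  # the add(mm, mm) subgraph is now a single addmm-based expression

The tests below exercise the real engine through torch.compile and verify how many patterns fired via torch._dynamo.utils.counters.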

File tree

10 files changed: +969 -20 lines


test/inductor/test_pattern_matcher.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
# Owner(s): ["module: inductor"]
import torch

from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA


class TestPaternMatcher(TestCase):
    def test_mm_plus_mm(self):
        def fn(a, b, c, d):
            return torch.add(torch.mm(a, b), torch.mm(c, d))

        args = [
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
        ]
        expected = fn(*args)
        actual = torch.compile(fn)(*args)
        torch.testing.assert_close(actual, expected)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)

    def test_addmm(self):
        def fn(a, b, c):
            return torch.add(a, torch.mm(b, c)), torch.mm(a, b) + c

        args = [
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
        ]
        e1, e2 = fn(*args)
        a1, a2 = torch.compile(fn)(*args)
        torch.testing.assert_close(a1, e1)
        torch.testing.assert_close(a2, e2)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

    def test_cat_mm(self):
        def fn(a, b, c):
            return torch.cat(
                [
                    torch.mm(a, b),
                    torch.mm(b, c),
                    torch.mm(a, c),
                ],
                1,
            )

        args = [
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
        ]
        expected = fn(*args)
        actual = torch.compile(fn)(*args)
        torch.testing.assert_close(actual, expected)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

    def test_cat_addmm(self):
        def fn(a, b, c):
            return torch.cat(
                [
                    torch.addmm(a, b, c),
                    torch.addmm(b, c, a),
                    torch.addmm(c, a, b),
                ],
                1,
            )

        args = [
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
        ]
        expected = fn(*args)
        actual = torch.compile(fn)(*args)
        torch.testing.assert_close(actual, expected)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

    def test_cat_slice_cat(self):
        def fn(a, b):
            cat_1 = torch.ops.aten.cat.default([a, b], 1)
            slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
            slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
            return torch.ops.aten.cat.default([cat_1, slice_2], 1)

        args = [
            torch.randn(2, 32, device="cuda"),
            torch.randn(2, 16, device="cuda"),
        ]
        expected = fn(*args)
        actual = torch.compile(fn)(*args)
        torch.testing.assert_close(actual, expected)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

        counters.clear()
        args = [
            torch.randn(2, 8, device="cuda"),
            torch.randn(2, 16, device="cuda"),
        ]
        expected = fn(*args)
        actual = torch.compile(fn)(*args)
        torch.testing.assert_close(actual, expected)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)


if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA:
        run_tests()
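
Outside the test harness, the same counters can be inspected directly; a minimal sketch, assuming a CUDA build with config.pattern_matcher left at its new default of True:

    import torch
    from torch._dynamo.utils import counters

    def fn(a, b, c, d):
        return torch.add(torch.mm(a, b), torch.mm(c, d))

    args = [torch.randn(16, 16, device="cuda") for _ in range(4)]
    counters.clear()  # counters are process-global, so reset first
    torch.compile(fn)(*args)
    print(counters["inductor"]["pattern_matcher_count"])  # 1: one pattern fired
    print(counters["inductor"]["pattern_matcher_nodes"])  # 3: two mm nodes plus the add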

test/inductor/test_select_algorithm.py

Lines changed: 15 additions & 0 deletions
@@ -138,6 +138,21 @@ def foo(a, b, c):
         # Autotuning checks correctness of each version
         self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
+    @patches
+    def test_mm_plus_mm(self):
+        @torch.compile
+        def foo(a, b, c, d):
+            return (a @ b) + (c @ d)
+
+        foo(
+            torch.randn(32, 32, device="cuda"),
+            torch.randn(32, 32, device="cuda"),
+            torch.randn(32, 32, device="cuda"),
+            torch.randn(32, 32, device="cuda"),
+        )
+        # Autotuning checks correctness of each version
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
+
 
 if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu
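
This test relies on the @patches helper used throughout this file to force autotuning on; in ordinary use the equivalent switch is the max_autotune config flag (see the config.py diff below). A hedged usage sketch, assuming a CUDA device:

    import torch
    import torch._inductor.config as inductor_config

    # normally driven by TORCHINDUCTOR_MAX_AUTOTUNE=1; set before the first compile
    inductor_config.max_autotune = True

    @torch.compile
    def foo(a, b, c, d):
        return (a @ b) + (c @ d)

    xs = [torch.randn(32, 32, device="cuda") for _ in range(4)]
    foo(*xs)  # candidate kernels for the fused mm+mm are benchmarked during compile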

torch/_inductor/compile_fx.py

Lines changed: 18 additions & 13 deletions
@@ -19,7 +19,7 @@
 from torch._dynamo.utils import fake_mode_from_tensors
 from torch._functorch.aot_autograd import make_boxed_func
 from torch._subclasses.fake_tensor import FakeTensor
-from . import config, metrics, overrides
+from . import config, metrics, overrides, pattern_matcher
 from .debug import DebugContext
 from .decomposition import select_decomp_table
 from .graph import GraphLowering
@@ -131,24 +131,29 @@ def compile_fx_inner(
         f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
         f"graph {graph_id}",
     )
-
     V.debug.fx_graph(gm, example_inputs)
 
     if cudagraphs is None:
         cudagraphs = config.triton.cudagraphs
 
     shape_env = _shape_env_from_inputs(example_inputs)
-    fake_mode = fake_mode_from_tensors(example_inputs)
-    graph = GraphLowering(
-        gm,
-        shape_env=shape_env,
-        num_static_inputs=num_fixed,
-        graph_id=graph_id,
-        fake_mode=fake_mode,
-    )
-    with V.set_graph_handler(graph):
-        graph.run(*example_inputs)
-        compiled_fn = graph.compile_to_fn()
+    fake_mode = fake_mode_from_tensors(
+        example_inputs
+    ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
+
+    with V.set_fake_mode(fake_mode):
+        pattern_matcher.fx_passes(gm)
+        V.debug.fx_graph_transformed(gm, example_inputs)
+
+        graph = GraphLowering(
+            gm,
+            shape_env=shape_env,
+            num_static_inputs=num_fixed,
+            graph_id=graph_id,
+        )
+        with V.set_graph_handler(graph):
+            graph.run(*example_inputs)
+            compiled_fn = graph.compile_to_fn()
 
     if cudagraphs:
         complex_memory_overlap_inputs = any(
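
Two behavioral notes on this hunk: fake_mode now falls back to a fresh FakeTensorMode(allow_non_fake_inputs=True) when fake_mode_from_tensors finds none, and GraphLowering no longer receives fake_mode directly, picking it up instead from the V.set_fake_mode context (see the graph.py diff below). A minimal illustration of fake tensor mode in isolation, outside inductor:

    import torch
    from torch._subclasses import FakeTensorMode

    # fake tensors carry shape/dtype/device metadata but allocate no storage,
    # which lets graph passes reason about ops without running real kernels
    with FakeTensorMode(allow_non_fake_inputs=True):
        a = torch.empty(16, 16)
        b = torch.empty(16, 16)
        out = torch.mm(a, b)  # dispatches to the meta kernel
        print(out.shape)      # torch.Size([16, 16])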

torch/_inductor/config.py

Lines changed: 10 additions & 1 deletion
@@ -37,6 +37,12 @@
 # do epilogue fusions before other fusions
 epilogue_fusion_first = False
 
+# enable pattern match+replace optimizations
+pattern_matcher = True
+
+# enable reordering pass
+reordering = False
+
 # enable slow autotuning passes to select algorithms
 max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
 
@@ -181,9 +187,12 @@ class trace:
     # Save python logger call >=logging.INFO
     info_log = False
 
-    # Save input FX graph (post decomps)
+    # Save input FX graph (post decomps, pre optimization)
     fx_graph = True
 
+    # Save FX graph after transformations
+    fx_graph_transformed = True
+
     # Save TorchInductor IR before fusion pass
     ir_pre_fusion = True
 
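
Both new flags are plain module-level booleans read when inductor's fx passes run, so they can be toggled at runtime; a small sketch (the pattern_matcher and reordering names come straight from this diff):

    import torch
    import torch._inductor.config as inductor_config

    # set before the first torch.compile call in the process
    inductor_config.pattern_matcher = False  # opt out of the new pass
    inductor_config.reordering = True        # opt in to the default-off reordering pass

    @torch.compile
    def f(a, b, c, d):
        return torch.mm(a, b) + torch.mm(c, d)

    f(*[torch.randn(16, 16) for _ in range(4)])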

torch/_inductor/debug.py

Lines changed: 6 additions & 0 deletions
@@ -379,6 +379,12 @@ def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
         with self.fopen("fx_graph_readable.py") as fd:
             fd.write(gm.print_readable(print_output=False))
 
+    def fx_graph_transformed(
+        self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
+    ):
+        with self.fopen("fx_graph_transformed.py") as fd:
+            fd.write(gm.print_readable(print_output=False))
+
     def ir_pre_fusion(self, nodes: SchedulerNodeList):
         self._write_ir("ir_pre_fusion.txt", nodes)
 
384390

torch/_inductor/graph.py

Lines changed: 8 additions & 5 deletions
@@ -91,13 +91,8 @@ def __init__(
         shape_env=None,
         num_static_inputs=None,
         graph_id=None,
-        fake_mode=None,
     ):
         super().__init__(gm)
-        if fake_mode is None:
-            self.fake_mode = torch._subclasses.FakeTensorMode()
-        else:
-            self.fake_mode = fake_mode
         if shape_env is None:
             shape_env = ShapeEnv()
             self.reuse_shape_env = False
@@ -133,6 +128,10 @@ def warn_fallback(self, name):
         self._warned_fallback.add(name)
         log.warning(f"Using FallbackKernel: {name}")
 
+    @property
+    def fake_mode(self):
+        return V.fake_mode
+
     def get_dtype(self, buffer_name: str):
         if buffer_name in self.constants:
             return self.constants[buffer_name].dtype
@@ -290,6 +289,10 @@ def call_function(self, target, args, kwargs):
         if target is operator.getitem and isinstance(args[0], (list, tuple)):
             return super().call_function(target, args, kwargs)
 
+        if hasattr(target, "_inductor_lowering_function"):
+            # passthrough lowerings from .pattern_matcher
+            return target(*args, **kwargs)
+
         if target not in lowerings:
             if config.implicit_fallbacks:
                 error = (
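
The hasattr check above is the hook that lets the pattern matcher's replacement functions bypass the normal lowering table: any callable carrying an _inductor_lowering_function attribute is invoked directly during lowering. A sketch of the marking side; the decorator name here is invented for illustration, and the real registration lives in the new pattern_matcher module:

    def mark_as_inductor_lowering(fn):
        # GraphLowering.call_function only tests for the attribute's presence
        fn._inductor_lowering_function = True
        return fn

    @mark_as_inductor_lowering
    def mm_plus_mm_replacement(a, b, c, d):
        # would emit the fused kernel choice instead of two separate mm lowerings
        raise NotImplementedError("illustrative stub")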
