[inductor] Pattern matching engine (copy) #93291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed · wants to merge 1 commit
117 changes: 117 additions & 0 deletions test/inductor/test_pattern_matcher.py
@@ -0,0 +1,117 @@
# Owner(s): ["module: inductor"]
import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA


class TestPatternMatcher(TestCase):
def test_mm_plus_mm(self):
def fn(a, b, c, d):
return torch.add(torch.mm(a, b), torch.mm(c, d))

args = [
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)

def test_addmm(self):
def fn(a, b, c):
return torch.add(a, torch.mm(b, c)), torch.mm(a, b) + c

args = [
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
e1, e2 = fn(*args)
a1, a2 = torch.compile(fn)(*args)
torch.testing.assert_close(a1, e1)
torch.testing.assert_close(a2, e2)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

def test_cat_mm(self):
def fn(a, b, c):
return torch.cat(
[
torch.mm(a, b),
torch.mm(b, c),
torch.mm(a, c),
],
1,
)

args = [
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

def test_cat_addmm(self):
def fn(a, b, c):
return torch.cat(
[
torch.addmm(a, b, c),
torch.addmm(b, c, a),
torch.addmm(c, a, b),
],
1,
)

args = [
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

def test_cat_slice_cat(self):
def fn(a, b):
cat_1 = torch.ops.aten.cat.default([a, b], 1)
            slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)  # 2**63 - 1, i.e. slice dim 0 to the end
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
return torch.ops.aten.cat.default([cat_1, slice_2], 1)

args = [
torch.randn(2, 32, device="cuda"),
torch.randn(2, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

counters.clear()
args = [
torch.randn(2, 8, device="cuda"),
torch.randn(2, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)


if __name__ == "__main__":
if IS_LINUX and HAS_CUDA:
run_tests()
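
Read together, the two counters assert that a rewrite fired (pattern_matcher_count) and how many FX nodes it absorbed (pattern_matcher_nodes). A minimal standalone version of the first test, outside the test harness (assumes a CUDA build, as the tests do):

import torch
from torch._dynamo.utils import counters

def fn(a, b, c, d):
    return torch.add(torch.mm(a, b), torch.mm(c, d))

args = [torch.randn(16, 16, device="cuda") for _ in range(4)]
counters.clear()
torch.compile(fn)(*args)
# one match covering two mms plus one add = 3 nodes
print(counters["inductor"]["pattern_matcher_count"])  # expect 1
print(counters["inductor"]["pattern_matcher_nodes"])  # expect 3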
15 changes: 15 additions & 0 deletions test/inductor/test_select_algorithm.py
@@ -138,6 +138,21 @@ def foo(a, b, c):
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@patches
def test_mm_plus_mm(self):
@torch.compile
def foo(a, b, c, d):
return (a @ b) + (c @ d)

foo(
torch.randn(32, 32, device="cuda"),
torch.randn(32, 32, device="cuda"),
torch.randn(32, 32, device="cuda"),
torch.randn(32, 32, device="cuda"),
)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)


if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu
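The single select_algorithm_autotune hit suggests the pattern matcher first collapses (a @ b) + (c @ d) into one fused mm-plus-mm candidate, which the autotuner then evaluates as a unit, mirroring test_mm_plus_mm in test_pattern_matcher.py above. A sketch of driving the same path by hand via the environment flag wired up in torch/_inductor/config.py below (the variable must be set before torch._inductor.config is first imported):

import os
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE"] = "1"  # read once by torch._inductor.config

import torch

@torch.compile
def foo(a, b, c, d):
    return (a @ b) + (c @ d)

foo(*[torch.randn(32, 32, device="cuda") for _ in range(4)])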
31 changes: 18 additions & 13 deletions torch/_inductor/compile_fx.py
@@ -19,7 +19,7 @@
from torch._dynamo.utils import fake_mode_from_tensors
from torch._functorch.aot_autograd import make_boxed_func
from torch._subclasses.fake_tensor import FakeTensor
-from . import config, metrics, overrides
+from . import config, metrics, overrides, pattern_matcher
from .debug import DebugContext
from .decomposition import select_decomp_table
from .graph import GraphLowering
@@ -131,24 +131,29 @@ def compile_fx_inner(
f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
f"graph {graph_id}",
)

V.debug.fx_graph(gm, example_inputs)

if cudagraphs is None:
cudagraphs = config.triton.cudagraphs

shape_env = _shape_env_from_inputs(example_inputs)
-    fake_mode = fake_mode_from_tensors(example_inputs)
-    graph = GraphLowering(
-        gm,
-        shape_env=shape_env,
-        num_static_inputs=num_fixed,
-        graph_id=graph_id,
-        fake_mode=fake_mode,
-    )
-    with V.set_graph_handler(graph):
-        graph.run(*example_inputs)
-        compiled_fn = graph.compile_to_fn()
+    fake_mode = fake_mode_from_tensors(
+        example_inputs
+    ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
+
+    with V.set_fake_mode(fake_mode):
+        pattern_matcher.fx_passes(gm)
+        V.debug.fx_graph_transformed(gm, example_inputs)
+
+        graph = GraphLowering(
+            gm,
+            shape_env=shape_env,
+            num_static_inputs=num_fixed,
+            graph_id=graph_id,
+        )
+        with V.set_graph_handler(graph):
+            graph.run(*example_inputs)
+            compiled_fn = graph.compile_to_fn()

if cudagraphs:
complex_memory_overlap_inputs = any(
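The ordering change above is the heart of the patch: pattern_matcher.fx_passes(gm) now runs on the FX graph before lowering, inside a with V.set_fake_mode(fake_mode) scope (falling back to a fresh FakeTensorMode(allow_non_fake_inputs=True) when no fake mode can be recovered from the inputs), and GraphLowering picks the mode up from that virtualized global instead of a constructor argument. A rough standalone approximation of the new flow, with my_fx_pass standing in (hypothetical name) for pattern_matcher.fx_passes:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def my_fx_pass(gm: torch.fx.GraphModule) -> None:
    # stand-in for pattern_matcher.fx_passes: rewrite gm.graph in place
    gm.graph.lint()

def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
    with fake_mode:  # any tensors created while matching/replacing are fake
        my_fx_pass(gm)
    gm.recompile()
    return gm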
11 changes: 10 additions & 1 deletion torch/_inductor/config.py
@@ -37,6 +37,12 @@
# do epilogue fusions before other fusions
epilogue_fusion_first = False

# enable pattern match+replace optimizations
pattern_matcher = True

# enable reordering pass
reordering = False

# enable slow autotuning passes to select algorithms
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
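
Both new flags are plain module-level attributes, so the pass can be switched off per process when bisecting a suspected bad rewrite (a sketch using only the flags added here):

import torch._inductor.config as inductor_config

inductor_config.pattern_matcher = False  # skip the match+replace pass
inductor_config.reordering = False       # reordering pass; already off by default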

@@ -181,9 +187,12 @@ class trace:
# Save python logger call >=logging.INFO
info_log = False

-    # Save input FX graph (post decomps)
+    # Save input FX graph (post decomps, pre optimization)
     fx_graph = True
 
+    # Save FX graph after transformations
+    fx_graph_transformed = True

# Save TorchInductor IR before fusion pass
ir_pre_fusion = True

6 changes: 6 additions & 0 deletions torch/_inductor/debug.py
@@ -379,6 +379,12 @@ def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
with self.fopen("fx_graph_readable.py") as fd:
fd.write(gm.print_readable(print_output=False))

def fx_graph_transformed(
self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
):
with self.fopen("fx_graph_transformed.py") as fd:
fd.write(gm.print_readable(print_output=False))

def ir_pre_fusion(self, nodes: SchedulerNodeList):
self._write_ir("ir_pre_fusion.txt", nodes)

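fx_graph_transformed mirrors the existing fx_graph hook, so a traced run now leaves fx_graph_readable.py and fx_graph_transformed.py side by side for a before/after comparison of the pattern passes. A sketch of enabling the dumps, assuming the pre-existing trace.enabled switch in torch/_inductor/config.py:

import torch._inductor.config as inductor_config

inductor_config.trace.enabled = True               # write the debug trace directory
inductor_config.trace.fx_graph_transformed = True  # on by default per the config change above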
13 changes: 8 additions & 5 deletions torch/_inductor/graph.py
@@ -91,13 +91,8 @@ def __init__(
shape_env=None,
num_static_inputs=None,
graph_id=None,
-        fake_mode=None,
    ):
        super().__init__(gm)
-        if fake_mode is None:
-            self.fake_mode = torch._subclasses.FakeTensorMode()
-        else:
-            self.fake_mode = fake_mode
if shape_env is None:
shape_env = ShapeEnv()
self.reuse_shape_env = False
@@ -133,6 +128,10 @@ def warn_fallback(self, name):
self._warned_fallback.add(name)
log.warning(f"Using FallbackKernel: {name}")

@property
def fake_mode(self):
return V.fake_mode

def get_dtype(self, buffer_name: str):
if buffer_name in self.constants:
return self.constants[buffer_name].dtype
@@ -290,6 +289,10 @@ def call_function(self, target, args, kwargs):
if target is operator.getitem and isinstance(args[0], (list, tuple)):
return super().call_function(target, args, kwargs)

if hasattr(target, "_inductor_lowering_function"):
# passthrough lowerings from .pattern_matcher
return target(*args, **kwargs)

if target not in lowerings:
if config.implicit_fallbacks:
error = (
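The _inductor_lowering_function passthrough is the hook that lets replacement callables produced by the pattern matcher bypass the normal lowerings table: call_function invokes them directly during lowering. A minimal illustration of the contract (the helper name mark_inductor_lowering is hypothetical; the attribute is the only thing call_function checks):

def mark_inductor_lowering(fn):
    # GraphLowering.call_function will call fn directly instead of
    # looking it up in the `lowerings` table.
    fn._inductor_lowering_function = True
    return fn

@mark_inductor_lowering
def my_replacement(*args, **kwargs):
    # hypothetical replacement body emitted by a matched pattern
    ...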