
Commit 8c09a00
[inductor] Pattern matching engine (copy) (pytorch#93291)
This is an exact duplicate of pytorch#90739

The fbcode workflow for landing that diff seems buggy.  The github-export-checks task is failing with credentials errors.  Plan to try to land it using GH1.

Pull Request resolved: pytorch#93291
Approved by: https://github.com/desertfire
jansel authored and pytorchmergebot committed Jan 31, 2023
1 parent aee5f84 commit 8c09a00
Showing 10 changed files with 969 additions and 20 deletions.
117 changes: 117 additions & 0 deletions test/inductor/test_pattern_matcher.py
@@ -0,0 +1,117 @@
# Owner(s): ["module: inductor"]
import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA


class TestPatternMatcher(TestCase):
    def test_mm_plus_mm(self):
        def fn(a, b, c, d):
            return torch.add(torch.mm(a, b), torch.mm(c, d))

        args = [
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
        ]
        expected = fn(*args)
        actual = torch.compile(fn)(*args)
        torch.testing.assert_close(actual, expected)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)

    def test_addmm(self):
        def fn(a, b, c):
            return torch.add(a, torch.mm(b, c)), torch.mm(a, b) + c

        args = [
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
        ]
        e1, e2 = fn(*args)
        a1, a2 = torch.compile(fn)(*args)
        torch.testing.assert_close(a1, e1)
        torch.testing.assert_close(a2, e2)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

    def test_cat_mm(self):
        def fn(a, b, c):
            return torch.cat(
                [
                    torch.mm(a, b),
                    torch.mm(b, c),
                    torch.mm(a, c),
                ],
                1,
            )

        args = [
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
        ]
        expected = fn(*args)
        actual = torch.compile(fn)(*args)
        torch.testing.assert_close(actual, expected)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

    def test_cat_addmm(self):
        def fn(a, b, c):
            return torch.cat(
                [
                    torch.addmm(a, b, c),
                    torch.addmm(b, c, a),
                    torch.addmm(c, a, b),
                ],
                1,
            )

        args = [
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
        ]
        expected = fn(*args)
        actual = torch.compile(fn)(*args)
        torch.testing.assert_close(actual, expected)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

    def test_cat_slice_cat(self):
        def fn(a, b):
            cat_1 = torch.ops.aten.cat.default([a, b], 1)
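            # 9223372036854775807 == 2**63 - 1 (INT64_MAX), i.e. slice dim 0 to its end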
            slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
            slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
            return torch.ops.aten.cat.default([cat_1, slice_2], 1)

        args = [
            torch.randn(2, 32, device="cuda"),
            torch.randn(2, 16, device="cuda"),
        ]
        expected = fn(*args)
        actual = torch.compile(fn)(*args)
        torch.testing.assert_close(actual, expected)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

        counters.clear()
        args = [
            torch.randn(2, 8, device="cuda"),
            torch.randn(2, 16, device="cuda"),
        ]
        expected = fn(*args)
        actual = torch.compile(fn)(*args)
        torch.testing.assert_close(actual, expected)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)


if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA:
        run_tests()
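
These tests drive the matcher end to end through torch.compile and then assert on torch._dynamo.utils.counters, which, judging by the assertions, track the number of replaced patterns (pattern_matcher_count) and the number of FX nodes those patterns covered (pattern_matcher_nodes). A minimal sketch, not part of the commit, of inspecting those counters interactively; it assumes a CUDA device and this patch applied:

# Sketch: run one compiled function and dump the inductor counters.
import torch
from torch._dynamo.utils import counters

def fn(a, b, c, d):
    return torch.add(torch.mm(a, b), torch.mm(c, d))

args = [torch.randn(16, 16, device="cuda") for _ in range(4)]
counters.clear()
torch.compile(fn)(*args)
# Per test_mm_plus_mm above: pattern_matcher_count == 1, pattern_matcher_nodes == 3
print(dict(counters["inductor"]))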
15 changes: 15 additions & 0 deletions test/inductor/test_select_algorithm.py
@@ -138,6 +138,21 @@ def foo(a, b, c):
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

+    @patches
+    def test_mm_plus_mm(self):
+        @torch.compile
+        def foo(a, b, c, d):
+            return (a @ b) + (c @ d)
+
+        foo(
+            torch.randn(32, 32, device="cuda"),
+            torch.randn(32, 32, device="cuda"),
+            torch.randn(32, 32, device="cuda"),
+            torch.randn(32, 32, device="cuda"),
+        )
+        # Autotuning checks correctness of each version
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)


if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu
31 changes: 18 additions & 13 deletions torch/_inductor/compile_fx.py
@@ -19,7 +19,7 @@
from torch._dynamo.utils import fake_mode_from_tensors
from torch._functorch.aot_autograd import make_boxed_func
from torch._subclasses.fake_tensor import FakeTensor
-from . import config, metrics, overrides
+from . import config, metrics, overrides, pattern_matcher
from .debug import DebugContext
from .decomposition import select_decomp_table
from .graph import GraphLowering
@@ -131,24 +131,29 @@ def compile_fx_inner(
f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
f"graph {graph_id}",
)

V.debug.fx_graph(gm, example_inputs)

if cudagraphs is None:
cudagraphs = config.triton.cudagraphs

shape_env = _shape_env_from_inputs(example_inputs)
fake_mode = fake_mode_from_tensors(example_inputs)
graph = GraphLowering(
gm,
shape_env=shape_env,
num_static_inputs=num_fixed,
graph_id=graph_id,
fake_mode=fake_mode,
)
with V.set_graph_handler(graph):
graph.run(*example_inputs)
compiled_fn = graph.compile_to_fn()
fake_mode = fake_mode_from_tensors(
example_inputs
) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)

with V.set_fake_mode(fake_mode):
pattern_matcher.fx_passes(gm)
V.debug.fx_graph_transformed(gm, example_inputs)

graph = GraphLowering(
gm,
shape_env=shape_env,
num_static_inputs=num_fixed,
graph_id=graph_id,
)
with V.set_graph_handler(graph):
graph.run(*example_inputs)
compiled_fn = graph.compile_to_fn()

if cudagraphs:
complex_memory_overlap_inputs = any(
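One behavioral note on the hunk above: when fake_mode_from_tensors cannot infer a mode from the example inputs, the code now falls back to a FakeTensorMode that tolerates real tensors, and lowering runs under V.set_fake_mode so the new FX passes can trace without allocating real storage. A small sketch, not from this commit, of what that fallback mode does:

# Sketch: FakeTensorMode(allow_non_fake_inputs=True), as constructed above.
# Tensors created under the mode are FakeTensors: shape/dtype/device only.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
with fake_mode:
    t = torch.empty(8, 8)  # no real storage is allocated
print(type(t).__name__)  # FakeTensor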
11 changes: 10 additions & 1 deletion torch/_inductor/config.py
@@ -37,6 +37,12 @@
# do epilogue fusions before other fusions
epilogue_fusion_first = False

+# enable pattern match+replace optimizations
+pattern_matcher = True

+# enable reordering pass
+reordering = False

# enable slow autotuning passes to select algorithms
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"

@@ -181,9 +187,12 @@ class trace:
    # Save python logger call >=logging.INFO
    info_log = False

-    # Save input FX graph (post decomps)
+    # Save input FX graph (post decomps, pre optimization)
    fx_graph = True

+    # Save FX graph after transformations
+    fx_graph_transformed = True

    # Save TorchInductor IR before fusion pass
    ir_pre_fusion = True

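The two new flags are plain module attributes, so callers can opt out of the new pass (or into the reordering pass) before compiling. A usage sketch, assuming this commit is applied:

# Sketch: toggle the new inductor config flags before the first compile.
import torch
import torch._inductor.config as inductor_config

inductor_config.pattern_matcher = False  # skip the pattern match+replace pass
inductor_config.reordering = True  # opt in to the off-by-default reordering pass

@torch.compile
def f(x, y):
    return (x @ y) + (x @ y)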
6 changes: 6 additions & 0 deletions torch/_inductor/debug.py
@@ -379,6 +379,12 @@ def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))

+    def fx_graph_transformed(
+        self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
+    ):
+        with self.fopen("fx_graph_transformed.py") as fd:
+            fd.write(gm.print_readable(print_output=False))

    def ir_pre_fusion(self, nodes: SchedulerNodeList):
        self._write_ir("ir_pre_fusion.txt", nodes)

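This hook mirrors the existing fx_graph one, writing the post-pass graph beside the pre-pass graph so the two files can be diffed. A sketch of turning the artifacts on; it assumes trace.enabled is the master switch gating the trace flags shown in config.py above:

# Sketch: enable debug tracing so fx_graph_readable.py (pre-pass) and
# fx_graph_transformed.py (post-pass) are written for each compiled graph.
# Assumption: config.trace.enabled gates all of the trace.* flags.
import torch._inductor.config as inductor_config

inductor_config.trace.enabled = True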
13 changes: 8 additions & 5 deletions torch/_inductor/graph.py
@@ -91,13 +91,8 @@ def __init__(
        shape_env=None,
        num_static_inputs=None,
        graph_id=None,
-        fake_mode=None,
    ):
        super().__init__(gm)
-        if fake_mode is None:
-            self.fake_mode = torch._subclasses.FakeTensorMode()
-        else:
-            self.fake_mode = fake_mode
        if shape_env is None:
            shape_env = ShapeEnv()
        self.reuse_shape_env = False
@@ -133,6 +128,10 @@ def warn_fallback(self, name):
            self._warned_fallback.add(name)
            log.warning(f"Using FallbackKernel: {name}")

+    @property
+    def fake_mode(self):
+        return V.fake_mode

    def get_dtype(self, buffer_name: str):
        if buffer_name in self.constants:
            return self.constants[buffer_name].dtype
@@ -290,6 +289,10 @@ def call_function(self, target, args, kwargs):
        if target is operator.getitem and isinstance(args[0], (list, tuple)):
            return super().call_function(target, args, kwargs)

+        if hasattr(target, "_inductor_lowering_function"):
+            # passthrough lowerings from .pattern_matcher
+            return target(*args, **kwargs)
+
        if target not in lowerings:
            if config.implicit_fallbacks:
                error = (
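The passthrough added to call_function gives the pattern matcher a way to splice replacement callables straight into the FX graph: any call target carrying the _inductor_lowering_function attribute is invoked directly instead of being looked up in the lowerings table. A hypothetical sketch of the producing side of that contract; only the attribute name comes from this diff, and the function below is an illustrative assumption (the real registrations live in the pattern_matcher module, whose diff did not render on this page):

# Hypothetical sketch: a replacement lowering the pattern matcher might
# install as an FX call target. The attribute name matches the check in
# GraphLowering.call_function above; everything else is illustrative.
def fused_mm_plus_mm(a, b, c, d):
    ...  # would emit inductor IR (e.g. an autotuned fused kernel)

fused_mm_plus_mm._inductor_lowering_function = True
# A match-and-replace pass can now rewrite add(mm(a, b), mm(c, d)) nodes
# to call this function, and lowering will pass it through unchanged.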
(Diffs for the remaining four changed files, including the new torch/_inductor/pattern_matcher.py, did not load on this page.)
