diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py
new file mode 100644
index 0000000000000..7bba18e6bf8ce
--- /dev/null
+++ b/test/inductor/test_pattern_matcher.py
@@ -0,0 +1,117 @@
+# Owner(s): ["module: inductor"]
+import torch
+from torch._dynamo.test_case import run_tests, TestCase
+from torch._dynamo.utils import counters
+from torch.testing._internal.common_utils import IS_LINUX
+from torch.testing._internal.inductor_utils import HAS_CUDA
+
+
+class TestPatternMatcher(TestCase):
+    def test_mm_plus_mm(self):
+        def fn(a, b, c, d):
+            return torch.add(torch.mm(a, b), torch.mm(c, d))
+
+        args = [
+            torch.randn(16, 16, device="cuda"),
+            torch.randn(16, 16, device="cuda"),
+            torch.randn(16, 16, device="cuda"),
+            torch.randn(16, 16, device="cuda"),
+        ]
+        expected = fn(*args)
+        actual = torch.compile(fn)(*args)
+        torch.testing.assert_close(actual, expected)
+        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
+        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)
+
+    def test_addmm(self):
+        def fn(a, b, c):
+            return torch.add(a, torch.mm(b, c)), torch.mm(a, b) + c
+
+        args = [
+            torch.randn(16, 16, device="cuda"),
+            torch.randn(16, 16, device="cuda"),
+            torch.randn(16, 16, device="cuda"),
+        ]
+        e1, e2 = fn(*args)
+        a1, a2 = torch.compile(fn)(*args)
+        torch.testing.assert_close(a1, e1)
+        torch.testing.assert_close(a2, e2)
+        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
+        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
+
+    def test_cat_mm(self):
+        def fn(a, b, c):
+            return torch.cat(
+                [
+                    torch.mm(a, b),
+                    torch.mm(b, c),
+                    torch.mm(a, c),
+                ],
+                1,
+            )
+
+        args = [
+            torch.randn(16, 16, device="cuda"),
+            torch.randn(16, 16, device="cuda"),
+            torch.randn(16, 16, device="cuda"),
+        ]
+        expected = fn(*args)
+        actual = torch.compile(fn)(*args)
+        torch.testing.assert_close(actual, expected)
+        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
+        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
+
+    def test_cat_addmm(self):
+        def fn(a, b, c):
+            return torch.cat(
+                [
+                    torch.addmm(a, b, c),
+                    torch.addmm(b, c, a),
+                    torch.addmm(c, a, b),
+                ],
+                1,
+            )
+
+        args = [
+            torch.randn(16, 16, device="cuda"),
+            torch.randn(16, 16, device="cuda"),
+            torch.randn(16, 16, device="cuda"),
+        ]
+        expected = fn(*args)
+        actual = torch.compile(fn)(*args)
+        torch.testing.assert_close(actual, expected)
+        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
+        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
+
+    def test_cat_slice_cat(self):
+        def fn(a, b):
+            cat_1 = torch.ops.aten.cat.default([a, b], 1)
+            slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
+            slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
+            return torch.ops.aten.cat.default([cat_1, slice_2], 1)
+
+        args = [
+            torch.randn(2, 32, device="cuda"),
+            torch.randn(2, 16, device="cuda"),
+        ]
+        expected = fn(*args)
+        actual = torch.compile(fn)(*args)
+        torch.testing.assert_close(actual, expected)
+        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
+        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
+
+        counters.clear()
+        args = [
+            torch.randn(2, 8, device="cuda"),
+            torch.randn(2, 16, device="cuda"),
+        ]
+        expected = fn(*args)
+        actual = torch.compile(fn)(*args)
+        torch.testing.assert_close(actual, expected)
+        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
+        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
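+
+    # Note (illustrative): both counters asserted above come from
+    # torch/_inductor/pattern_matcher.py; "pattern_matcher_count" counts
+    # successful matches and "pattern_matcher_nodes" counts the FX nodes each
+    # match consumed, e.g. add + mm + mm == 3 nodes in test_mm_plus_mm.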
+
+
+if __name__ == "__main__":
+    if IS_LINUX and HAS_CUDA:
+        run_tests()
diff --git a/test/inductor/test_select_algorithm.py b/test/inductor/test_select_algorithm.py
index ffc0003e71125..008973ee23c1f 100644
--- a/test/inductor/test_select_algorithm.py
+++ b/test/inductor/test_select_algorithm.py
@@ -138,6 +138,21 @@ def foo(a, b, c):
         # Autotuning checks correctness of each version
         self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
+    @patches
+    def test_mm_plus_mm(self):
+        @torch.compile
+        def foo(a, b, c, d):
+            return (a @ b) + (c @ d)
+
+        foo(
+            torch.randn(32, 32, device="cuda"),
+            torch.randn(32, 32, device="cuda"),
+            torch.randn(32, 32, device="cuda"),
+            torch.randn(32, 32, device="cuda"),
+        )
+        # Autotuning checks correctness of each version
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
+
 
 if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index b62a0d0db3244..1a5d2a68e6cbb 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -19,7 +19,7 @@
 from torch._dynamo.utils import fake_mode_from_tensors
 from torch._functorch.aot_autograd import make_boxed_func
 from torch._subclasses.fake_tensor import FakeTensor
-from . import config, metrics, overrides
+from . import config, metrics, overrides, pattern_matcher
 from .debug import DebugContext
 from .decomposition import select_decomp_table
 from .graph import GraphLowering
@@ -131,24 +131,29 @@ def compile_fx_inner(
         f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
         f"graph {graph_id}",
     )
-    V.debug.fx_graph(gm, example_inputs)
 
     if cudagraphs is None:
         cudagraphs = config.triton.cudagraphs
 
     shape_env = _shape_env_from_inputs(example_inputs)
-    fake_mode = fake_mode_from_tensors(example_inputs)
-    graph = GraphLowering(
-        gm,
-        shape_env=shape_env,
-        num_static_inputs=num_fixed,
-        graph_id=graph_id,
-        fake_mode=fake_mode,
-    )
-    with V.set_graph_handler(graph):
-        graph.run(*example_inputs)
-        compiled_fn = graph.compile_to_fn()
+    fake_mode = fake_mode_from_tensors(
+        example_inputs
+    ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
+
+    with V.set_fake_mode(fake_mode):
+        pattern_matcher.fx_passes(gm)
+        V.debug.fx_graph_transformed(gm, example_inputs)
+
+        graph = GraphLowering(
+            gm,
+            shape_env=shape_env,
+            num_static_inputs=num_fixed,
+            graph_id=graph_id,
+        )
+        with V.set_graph_handler(graph):
+            graph.run(*example_inputs)
+            compiled_fn = graph.compile_to_fn()
 
     if cudagraphs:
         complex_memory_overlap_inputs = any(
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 7a7e17c70eb97..e1ff535fabe49 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -37,6 +37,12 @@
 # do epilogue fusions before other fusions
 epilogue_fusion_first = False
 
+# enable pattern match+replace optimizations
+pattern_matcher = True
+
+# enable reordering pass
+reordering = False
+
 # enable slow autotuning passes to select algorithms
 max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
 
@@ -181,9 +187,12 @@ class trace:
     # Save python logger call >=logging.INFO
     info_log = False
 
-    # Save input FX graph (post decomps)
+    # Save input FX graph (post decomps, pre optimization)
     fx_graph = True
 
+    # Save FX graph after transformations
+    fx_graph_transformed = True
+
     # Save TorchInductor IR before fusion pass
     ir_pre_fusion = True
diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py
index 111a21c23d8c2..5e51cbbaceadf 100644
--- a/torch/_inductor/debug.py
+++ b/torch/_inductor/debug.py
@@ -379,6 +379,12 @@ def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
         with self.fopen("fx_graph_readable.py") as fd:
             fd.write(gm.print_readable(print_output=False))
 
+    def fx_graph_transformed(
+        self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
+    ):
+        with self.fopen("fx_graph_transformed.py") as fd:
+            fd.write(gm.print_readable(print_output=False))
+
     def ir_pre_fusion(self, nodes: SchedulerNodeList):
         self._write_ir("ir_pre_fusion.txt", nodes)
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 50660f5cf6d0f..0a6f5bc78b1a9 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -91,13 +91,8 @@ def __init__(
         shape_env=None,
         num_static_inputs=None,
         graph_id=None,
-        fake_mode=None,
     ):
         super().__init__(gm)
-        if fake_mode is None:
-            self.fake_mode = torch._subclasses.FakeTensorMode()
-        else:
-            self.fake_mode = fake_mode
         if shape_env is None:
             shape_env = ShapeEnv()
             self.reuse_shape_env = False
@@ -133,6 +128,10 @@ def warn_fallback(self, name):
             self._warned_fallback.add(name)
             log.warning(f"Using FallbackKernel: {name}")
 
+    @property
+    def fake_mode(self):
+        return V.fake_mode
+
     def get_dtype(self, buffer_name: str):
         if buffer_name in self.constants:
             return self.constants[buffer_name].dtype
@@ -290,6 +289,10 @@ def call_function(self, target, args, kwargs):
         if target is operator.getitem and isinstance(args[0], (list, tuple)):
             return super().call_function(target, args, kwargs)
 
+        if hasattr(target, "_inductor_lowering_function"):
+            # passthrough lowerings from .pattern_matcher
+            return target(*args, **kwargs)
+
         if target not in lowerings:
             if config.implicit_fallbacks:
                 error = (
diff --git a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py
new file mode 100644
index 0000000000000..d7bd381d21a31
--- /dev/null
+++ b/torch/_inductor/kernel/mm_plus_mm.py
@@ -0,0 +1,174 @@
+import functools
+
+import torch
+from ..lowering import lowerings
+from ..select_algorithm import (
+    autotune_select_algorithm,
+    ExternKernelChoice,
+    TritonTemplate,
+)
+from ..utils import use_triton_template
+from ..virtualized import V
+from .mm_common import mm_args, mm_grid, mm_options
+
+aten = torch.ops.aten
+
+
+def ref_mm_plus_mm(a, b, c, d, out):
+    torch.mm(a, b, out=out)
+    out.addmm_(c, d)
+    return out
+
+
+aten_mm_plus_mm = ExternKernelChoice(ref_mm_plus_mm)
+
+mm_plus_mm_template = TritonTemplate(
+    name="mm_plus_mm",
+    grid=mm_grid,
+    debug=False,
+    source=r"""
+{{def_kernel("A", "B", "C", "D")}}
+    M = {{size("A", 0)}}
+    N = {{size("B", 1)}}
+    K1 = {{size("A", 1)}}
+    # K2 = {{size("C", 1)}}
+    stride_am = {{stride("A", 0)}}
+    stride_ak = {{stride("A", 1)}}
+    stride_bk = {{stride("B", 0)}}
+    stride_bn = {{stride("B", 1)}}
+    stride_cm = {{stride("C", 0)}}
+    stride_ck = {{stride("C", 1)}}
+    stride_dk = {{stride("D", 0)}}
+    stride_dn = {{stride("D", 1)}}
+
+    # based on triton.ops.matmul
+    pid = tl.program_id(0)
+    grid_m = (M + BLOCK_M - 1) // BLOCK_M
+    grid_n = (N + BLOCK_N - 1) // BLOCK_N
+
+    # re-order program ID for better L2 performance
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // (group_size)
+
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+    rk = tl.arange(0, BLOCK_K)
+    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+    C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck)
+    D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn)
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+    for k1 in range(K1, 0, -BLOCK_K):
+        # First matmul: A @ B
+        if EVEN_K:
+            a = tl.load(A)
+            b = tl.load(B)
+        else:
+            a = tl.load(A, mask=rk[None, :] < k1, other=0.)
+            b = tl.load(B, mask=rk[:, None] < k1, other=0.)
+        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
+        A += BLOCK_K * stride_ak
+        B += BLOCK_K * stride_bk
+
+        # Splitting this into two loops causes an internal triton LLVM error
+        # https://github.com/openai/triton/issues/967
+        # for k2 in range(K2, 0, -BLOCK_K):
+        k2 = k1
+
+        # Second matmul: C @ D
+        if EVEN_K:
+            c = tl.load(C)
+            d = tl.load(D)
+        else:
+            c = tl.load(C, mask=rk[None, :] < k2, other=0.)
+            d = tl.load(D, mask=rk[:, None] < k2, other=0.)
+        acc += tl.dot(c, d, allow_tf32=ALLOW_TF32)
+        C += BLOCK_K * stride_ck
+        D += BLOCK_K * stride_dk
+
+    # rematerialize rm and rn to save registers
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    idx_m = rm[:, None]
+    idx_n = rn[None, :]
+    mask = (idx_m < M) & (idx_n < N)
+
+    # inductor generates a suffix
+    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
+""",
+)
+
+
+@functools.lru_cache(None)
+def mm_configs():
+    import triton
+
+    # these have been tweaked to workaround register issues
+    return [
+        triton.Config(
+            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=2, num_warps=4
+        ),
+        triton.Config(
+            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=3, num_warps=8
+        ),
+        triton.Config(
+            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=16
+        ),
+        triton.Config(
+            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=4, num_warps=8
+        ),
+        triton.Config(
+            {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=8
+        ),
+        triton.Config(
+            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=1, num_warps=8
+        ),
+        triton.Config(
+            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=1, num_warps=8
+        ),
+        triton.Config(
+            {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=1, num_warps=8
+        ),
+        triton.Config(
+            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, num_stages=2, num_warps=4
+        ),
+        triton.Config(
+            {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16}, num_stages=1, num_warps=2
+        ),
+    ]
+
+
+def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
+    """
+    Computes mm(mat1, mat2) + mm(mat3, mat4)
+    """
+    if not V.graph.sizevars.maybe_guard_list_equals(
+        mat1.get_size(), mat3.get_size()
+    ) or not V.graph.sizevars.maybe_guard_list_equals(mat2.get_size(), mat4.get_size()):
+        # TODO(jansel): support different K values when this is fixed:
+        # https://github.com/openai/triton/issues/967
+        return lowerings[aten.addmm](lowerings[aten.mm](mat1, mat2), mat3, mat4)
+
+    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
+    m, n, k, layout, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
+
+    # options to tune from
+    choices = [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout)]
+    if use_triton_template(layout):
+        for config in mm_configs():
+            choices.append(
+                mm_plus_mm_template.generate(
+                    (mat1, mat2, mat3, mat4),
+                    layout,
+                    **mm_options(config, k, layout),
+                )
+            )
+
+    return autotune_select_algorithm(choices, [mat1, mat2, mat3, mat4], layout)
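+
+# Usage note (illustrative): this lowering is reached via the mm_plus_mm
+# pattern registered in torch/_inductor/pattern_matcher.py, i.e. compiling
+#
+#     def f(a, b, c, d):
+#         return a @ b + c @ d
+#
+# with torch.compile() routes the fused add(mm, mm) here, autotuning between
+# the ATen fallback and the Triton template above.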
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
new file mode 100644
index 0000000000000..db70da6a6d185
--- /dev/null
+++ b/torch/_inductor/pattern_matcher.py
@@ -0,0 +1,609 @@
+import dataclasses
+import functools
+import inspect
+import itertools
+import logging
+import operator
+import os
+from collections import defaultdict
+from typing import Any, Callable, List, Union
+
+import torch
+import torch._inductor as inductor
+import torch.fx
+import torch.utils._pytree as pytree
+from torch._dynamo.utils import counters
+from torch.fx.immutable_collections import immutable_dict, immutable_list
+
+from . import config, ir
+from .lowering import lowerings as L
+from .virtualized import V
+
+log = logging.getLogger(__name__)
+aten = torch.ops.aten
+
+Constant = Any
+NodeOrConstant = Union[Constant, torch.fx.Node]
+
+
+class Match:
+    """
+    Represents a successfully matched pattern.
+    """
+
+    def __init__(self, pattern, args=None, kwargs=None):
+        super().__init__()
+        self.pattern = pattern
+        # The input nodes that must be passed in to the result
+        self.args = args or []
+        self.kwargs = kwargs or {}
+        # The nodes matched in this expression
+        self.nodes = []
+        # Mapping CallFunction to the node.target
+        self.targets = {}
+
+    def extend(self, other):
+        if self.kwargs:
+            for key in set(self.kwargs.keys()) & set(other.kwargs.keys()):
+                if self.kwargs[key] != other.kwargs[key]:
+                    raise FailedMatch(f"kwarg mismatch: {key}")
+        self.args.extend(other.args)
+        self.nodes.extend(other.nodes)
+        self.kwargs.update(other.kwargs)
+        self.targets.update(other.targets)
+
+    def bundle(self):
+        # Wrap args in an extra list
+        self.args = [tuple(self.args)]
+        return self
+
+    def __repr__(self):
+        return f"Match(..., {self.args}, {self.kwargs})"
+
+    def erase_nodes(self, graph: torch.fx.Graph):
+        for n in reversed(self.nodes):
+            graph.erase_node(n)
+
+
+class FailedMatch(RuntimeError):
+    def __bool__(self):
+        return False
+
+
+class MatchContext:
+    """
+    State needed while running PatternExpr._match().
+    """
+
+    def __init__(self, outputs: List["PatternExpr"]):
+        self.outputs = outputs
+        self.pattern_to_node = {}
+
+    def match(self, pattern, node):
+        """wrapper to check reused nodes in patterns"""
+        if pattern in self.pattern_to_node:
+            if self.pattern_to_node[pattern] == node:
+                return Match(pattern)  # already checked this node
+            else:
+                return FailedMatch("repeated pattern differs")
+        m = pattern._match(node, self)
+        assert pattern not in self.pattern_to_node
+        self.pattern_to_node[pattern] = node if m else None
+        return m
+
+
+class PatternExpr:
+    """
+    Base class for types of patterns.
+    """
+
+    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> Union[Match, FailedMatch]:
+        raise NotImplementedError()
+
+    def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]:
+        try:
+            return MatchContext([self]).match(self, node)
+        except FailedMatch as e:
+            return e
+
+    def __repr__(self):
+        return self.__class__.__name__ + "()"
+
+
+class Arg(PatternExpr):
+    """
+    Capture an arg which will become an input to the handler.  Args are
+    passed in depth first order.
+    """
+
+    def _match(self, node: NodeOrConstant, ctx: MatchContext):
+        return Match(self, args=[node])  # matches anything
+
+
+class KeywordArg(PatternExpr):
+    """
+    Capture a kwarg which will become an input to the handler.
+    """
+
+    def __init__(self, name):
+        super().__init__()
+        self.name = name
+
+    def _match(self, node: NodeOrConstant, ctx: MatchContext):
+        return Match(self, kwargs={self.name: node})  # matches anything
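+
+
+# Example (illustrative, not one of the registered patterns below): a pattern
+# matching add(mm(a, b), x), capturing the matmul inputs positionally and `x`
+# by name, would be written as
+#
+#     CallFunction(aten.add, CallFunction(aten.mm, Arg(), Arg()), KeywordArg("x"))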
+ """ + + def __init__(self, name): + super().__init__() + self.name = name + + def _match(self, node: NodeOrConstant, ctx: MatchContext): + return Match(self, kwargs={self.name: node}) # matches anything + + +class CallFunction(PatternExpr): + """ + Matches a call_function node in the FX graps: `fns[i](*args, **kwargs)` + """ + + def __init__(self, fns, *args, _users=1, **kwargs): + super().__init__() + fns = [fns] if callable(fns) else list(fns) + for fn in list(fns): + if isinstance(fn, torch._ops.OpOverloadPacket): + fns.extend([getattr(fn, overload) for overload in fn.overloads()]) + + self.fns = fns + self.fns_set = set(fns) + self.args = tuple(args) + self.kwargs = dict(kwargs) + self.users = _users + if any( + isinstance(x, (dict, list, tuple)) + for x in itertools.chain(args, kwargs.values()) + ): + self.flatten = self.pytree_flatten + else: + self.flatten = self.simple_flatten + self.flat_args_kwargs = self.flatten(self.args, self.kwargs) + + @staticmethod + def simple_flatten(args, kwargs): + return (*args, *kwargs.values()), (len(args), *kwargs.keys()) + + @staticmethod + def pytree_flatten(args, kwargs): + def norm_spec(s: pytree.TreeSpec): + if s.type is None: + return s + mapping = {immutable_list: list, tuple: list, immutable_dict: dict} + return pytree.TreeSpec( + mapping.get(s.type, s.type), + s.context, + list(map(norm_spec, s.children_specs)), + ) + + flat, spec = pytree.tree_flatten([args, kwargs]) + spec = norm_spec(spec) + return flat, spec + + def __repr__(self): + args = [ + f"[{self.fns[0].__name__}, ...]", + *map(repr, self.args), + *[f"{k}={v}" for k, v in self.kwargs.items()], + ] + return f"{self.__class__.__name__}({', '.join(args)})" + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + if ( + not isinstance(node, torch.fx.Node) + or node.op != "call_function" + or node.target not in self.fns_set + or len(node.args) != len(self.args) + or len(node.kwargs) != len(self.kwargs) + ): + return FailedMatch("function_mismatch") + + if self not in ctx.outputs and len(node.users) != self.users: + return FailedMatch("multiple_users") + + node_items, node_spec = self.flatten(node.args, node.kwargs) + self_items, self_spec = self.flat_args_kwargs + if node_spec != self_spec: + return FailedMatch(f"args_stucture {node_spec} {self_spec}") + assert len(node_items) == len(self_items) + + m = Match(self) + for i, pattern, child_node in zip(itertools.count(), self_items, node_items): + if isinstance(pattern, PatternExpr): + child_match = ctx.match(pattern, child_node) + if not child_match: + return FailedMatch(f"arg[{i}]: {child_match}") + m.extend(child_match) + elif isinstance(child_node, torch.fx.Node) or child_node != pattern: + return FailedMatch("constant_args") + m.nodes.append(node) + m.targets[self] = node.target + return m + + +class ListOf(PatternExpr): + """ + Matches a repeated pattern + """ + + def __init__(self, pattern): + super().__init__() + assert isinstance(pattern, PatternExpr) + self.pattern = pattern + + def __repr__(self): + return f"{self.__class__.__name__}({self.pattern})" + + def _match(self, node: List[torch.fx.Node], ctx: MatchContext): + if not isinstance(node, (list, tuple)) or len(node) == 0: + return FailedMatch("non_list") + m = Match(self) + for i, child_node in enumerate(node): + child_match = MatchContext(ctx.outputs).match(self.pattern, child_node) + if not child_match: + return FailedMatch(f"list[{i}]: {child_match}") + m.extend(child_match.bundle()) + return m.bundle() + + +pass_patterns = [ + defaultdict(list), + 
+
+
+@dataclasses.dataclass
+class PatternEntry:
+    pattern: PatternExpr
+    extra_check: Callable[[Match], bool]
+
+    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
+        raise NotImplementedError()
+
+    def register(self, pass_number, target):
+        if isinstance(pass_number, int):
+            pass_patterns[pass_number][target].append(self)
+        else:
+            for x in pass_number:
+                self.register(x, target)
+
+
+@dataclasses.dataclass
+class LoweringPatternEntry(PatternEntry):
+    handler: Any
+
+    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
+        handler = functools.wraps(self.handler)(functools.partial(self.handler, match))
+        with graph.inserting_before(node):
+            replacement = graph.call_function(handler, tuple(match.args), match.kwargs)
+            replacement.meta.update(node.meta)
+            node.replace_all_uses_with(replacement)
+        assert match.nodes[-1] is node
+        match.erase_nodes(graph)
+
+
+@dataclasses.dataclass
+class ReplacementPatternEntry(PatternEntry):
+    replacement_graph: torch.fx.GraphModule
+    signature: inspect.Signature
+    propagate: bool = False
+
+    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
+        class Replacer(torch.fx.Interpreter):
+            call_method = None
+            call_module = None
+            get_attr = None
+
+            def call_function(self, target, args, kwargs):
+                result = graph.call_function(target, args, kwargs)
+                if propagate and V.fake_mode:
+                    fargs, fkwargs = torch.fx.map_arg(
+                        (args, kwargs), lambda n: n.meta["val"]
+                    )
+                    with V.fake_mode:
+                        result.meta["val"] = target(*fargs, **fkwargs)
+                return result
+
+        propagate = self.propagate
+        norm_args = self.signature.bind(*match.args, **match.kwargs)
+        with graph.inserting_before(node):
+            replacement = Replacer(self.replacement_graph).run(
+                *norm_args.arguments.values()
+            )
+            replacement.meta.update(node.meta)
+            node.replace_all_uses_with(replacement)
+        assert match.nodes[-1] is node
+        match.erase_nodes(graph)
+
+
+def _return_true(match):
+    return True
+
+
+def register_replacement_pattern(pattern, extra_check=_return_true, pass_number=1):
+    """
+    Register an aten to aten replacement pattern
+    """
+
+    def decorator(handler):
+        signature = inspect.signature(handler)
+        replacement_graph = torch.fx.symbolic_trace(handler)
+        for target in pattern.fns:
+            ReplacementPatternEntry(
+                pattern=pattern,
+                extra_check=extra_check,
+                replacement_graph=replacement_graph,
+                signature=signature,
+            ).register(pass_number, target)
+        return handler
+
+    assert isinstance(pattern, CallFunction)
+    return decorator
+
+
+def register_lowering_pattern(pattern, extra_check=_return_true, pass_number=1):
+    """
+    Register an aten to inductor IR replacement pattern
+    """
+
+    def decorator(handler):
+        assert callable(handler)
+        for target in pattern.fns:
+            LoweringPatternEntry(
+                pattern=pattern, extra_check=extra_check, handler=handler
+            ).register(pass_number, target)
+        handler._inductor_lowering_function = True
+        return handler
+
+    assert isinstance(pattern, CallFunction)
+    return decorator
+
+
+register_pattern = register_lowering_pattern
+
+
+def replace_matched_patterns(graph: torch.fx.Graph):
+    # the actual replacement work
+    for patterns in pass_patterns:
+        if not patterns:
+            continue
+        for node in reversed(graph.nodes):
+            if node.op == "call_function" and node.target in patterns:
+                for entry in patterns[node.target]:
+                    if node._erased:
+                        break
+                    m = entry.pattern.match(node)
+                    if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
+                        log.warning(f"{node}{node.args} {m} {entry.pattern}")
+                    if m and entry.extra_check(m):
+                        entry.apply(m, graph, node)
+                        counters["inductor"]["pattern_matcher_count"] += 1
+                        counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
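+
+
+# Debugging aid: set the TORCHINDUCTOR_PATTERN_MATCH_DEBUG environment
+# variable to an FX node name (e.g. "add_3", a hypothetical name) to log each
+# pattern tried against that node via the log.warning call above.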
{entry.pattern}") + if m and entry.extra_check(m): + entry.apply(m, graph, node) + counters["inductor"]["pattern_matcher_count"] += 1 + counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) + + +def reorder_for_locality(graph: torch.fx.Graph): + def visit(other_node): + if ( + other_node.op == "call_function" + and other_node.target != operator.getitem + and all((n in seen_nodes) for n in other_node.users) + ): + # move node's producers right before it + node.prepend(other_node) + + seen_nodes = set() + for node in reversed(graph.nodes): + seen_nodes.add(node) + torch.fx.map_arg((node.args, node.kwargs), visit) + + +def fx_passes(gm: torch.fx.GraphModule): + if config.dce: + # has some issues with mutation in inference mode + gm.graph.eliminate_dead_code() + + if config.reordering: + # has some issues with mutation in inference mode + reorder_for_locality(gm.graph) + + if config.pattern_matcher: + replace_matched_patterns(gm.graph) + + gm.graph.lint() + + +################################################################################ +# Actual patterns below this point. +# Priority of patterns is: +# - later output nodes first +# - order patterns are defined in +################################################################################ + + +@register_lowering_pattern( + CallFunction( + aten.add, + CallFunction(aten.mm, Arg(), Arg()), + CallFunction(aten.mm, Arg(), Arg()), + ) +) +def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4): + return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4) + + +@register_lowering_pattern( + CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()), +) +def cat_mm(match, inputs, dim): + def shape_of(a, b): + m, _ = a.get_size() + _, n = b.get_size() + return [m, n] + + return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of) + + +@register_lowering_pattern( + CallFunction( + aten.cat, ListOf(CallFunction(aten.addmm, Arg(), Arg(), Arg())), Arg() + ), +) +def cat_addmm(match, inputs, dim): + def shape_of(bias, a, b): + m, _ = a.get_size() + _, n = b.get_size() + return [m, n] + + return cat_tuned_op(match, inputs, dim, op=L[aten.addmm], shape_of=shape_of) + + +def cat_tuned_op(match, inputs, dim, *, op, shape_of): + """ + Memory planning to remove cat. We can't use the stock memory + planner since autotuning matmauls needs to know the output layout. + """ + # TODO(jansel): rewrite this as a bmm? 
+
+
+_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2)
+
+
+@register_lowering_pattern(
+    CallFunction(
+        aten.cat,
+        [
+            _cat_1,
+            CallFunction(
+                aten.slice,
+                CallFunction(aten.slice, _cat_1, 0, 0, 9223372036854775807),
+                1,
+                0,
+                KeywordArg("size"),
+            ),
+        ],
+        1,
+    )
+)
+def cat_slice_cat(match, cat_input, size, dim=1):
+    """
+    This is an example of a more complex pattern where cat_1 is used
+    multiple times inside the pattern.  We fold two calls to cat into one.
+
+    Matches:
+        cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1)
+        slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
+        slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
+        cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1)
+
+    Rewrites to:
+        slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19)
+        cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice_2], 1)
+    """
+    first, *rest = cat_input
+    if V.graph.sizevars.maybe_guard_leq(size, first.get_size()[dim]):
+        # fold 2 cats into 1 cat
+        return L[aten.cat](
+            [
+                first,
+                *rest,
+                L[aten.slice](first, dim, 0, size),
+            ],
+            dim,
+        )
+    else:
+        # don't expect to hit this case, just fall back
+        tmp = L[aten.cat](cat_input, dim)
+        return L[aten.cat](
+            [
+                tmp,
+                L[aten.slice](tmp, dim, 0, size),
+            ],
+            dim,
+        )
+
+
+@register_replacement_pattern(
+    CallFunction(
+        aten.add,
+        CallFunction(aten.mm, Arg(), Arg()),
+        KeywordArg("added"),
+    ),
+    pass_number=2,
+)
+@register_replacement_pattern(
+    CallFunction(
+        aten.add,
+        KeywordArg("added"),
+        CallFunction(aten.mm, Arg(), Arg()),
+    ),
+    pass_number=2,
+)
+def addmm(mat1, mat2, added):
+    return aten.addmm(added, mat1, mat2)
+
+
+# This slows things down:
+"""
+@register_replacement_pattern(
+    CallFunction(
+        aten.add,
+        CallFunction(aten.bmm, Arg(), Arg()),
+        KeywordArg("added"),
+    ),
+    pass_number=3
+)
+@register_replacement_pattern(
+    CallFunction(
+        aten.add,
+        KeywordArg("added"),
+        CallFunction(aten.bmm, Arg(), Arg()),
+    ),
+    pass_number=3
+)
+def baddbmm(mat1, mat2, added):
+    return aten.baddbmm(added, mat1, mat2)
+"""
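+
+# Note the split above: replacement patterns (e.g. addmm) rewrite the FX graph
+# at the aten level via a symbolically traced handler, while lowering patterns
+# (e.g. mm_plus_mm, cat_mm) run during lowering and emit inductor IR directly.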
diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py
index 8fc9206c9ef19..1b216a67c2e04 100644
--- a/torch/_inductor/virtualized.py
+++ b/torch/_inductor/virtualized.py
@@ -128,6 +128,7 @@ def __getattr__(self, item):
 
 ops = Virtualized("ops", MockHandler)
 _graph = Virtualized("graph", NullHandler)
+_fake_mode = Virtualized("fake_mode", NullHandler)
 _kernel = Virtualized("kernel", NullHandler)
 _debug = Virtualized("debug", NullHandler)
 
@@ -140,6 +141,7 @@ class _V:
     set_ops_handler = ops._set_handler
     get_ops_handler = ops._get_handler
     set_graph_handler = _graph._set_handler
+    set_fake_mode = _fake_mode._set_handler
     set_kernel_handler = _kernel._set_handler
     set_debug_handler = _debug._set_handler
 
@@ -153,6 +155,11 @@ def graph(self):
         """The graph currently being generated"""
         return _graph._get_handler()
 
+    @property
+    def fake_mode(self):
+        """The FakeTensorMode used during compilation"""
+        return _fake_mode._get_handler()
+
     @property
     def kernel(self):
         """The kernel currently being generated"""
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 79666e935a8fa..109c0168a2219 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -1005,7 +1005,11 @@ def gen_wrap_fn(self, func, args, kwargs):
         def wrap(e, device=None):
             nonlocal common_device
             nonlocal has_scalar_only_inputs
-            if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor):
+            if (
+                isinstance(e, torch.Tensor)
+                and not isinstance(e, FakeTensor)
+                and converter is not None
+            ):
                 if common_device is None:
                     (
                         common_device,