Add Config to Skip Cpp Codegen, Enable in FBCode (pytorch#97204)
eellison authored and pytorchmergebot committed Mar 28, 2023
1 parent c0e0fbb commit 6854fd7
Showing 10 changed files with 144 additions and 18 deletions.
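The core of the change is a new disable_cpp_codegen flag in torch._inductor.config, defaulting to on in fbcode builds. When it is set, Inductor falls back to the eager kernel for any op that produces a CPU tensor rather than generating C++ for it. A minimal usage sketch, assuming only the config.patch and torch.compile entry points that appear in the new test below:

    import torch
    from torch._inductor import config

    # Skip C++ codegen for CPU tensors; ops with CPU outputs fall back to eager.
    with config.patch({"disable_cpp_codegen": True}):
        fn = torch.compile(lambda x, y: x + y)
        out = fn(torch.ones(20), torch.rand(20))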
28 changes: 27 additions & 1 deletion test/inductor/test_torchinductor.py
@@ -32,7 +32,7 @@
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn import functional as F
-from torch.testing import make_tensor
+from torch.testing import FileCheck, make_tensor
from torch.testing._internal.common_dtype import all_types
from torch.testing._internal.common_utils import (
    IS_CI,
@@ -6386,6 +6386,32 @@ def fn(x):
            - metrics.generated_cpp_vec_kernel_count
        ) == 0

+    def test_skip_cpp_codegen(self):
+        with config.patch({"disable_cpp_codegen": True}):
+            inps = torch.ones([20]), torch.rand([20])
+
+            def f(x, y):
+                return x + y + torch.tensor(1)
+
+            f_opt = torch.compile()(f)
+
+            code = run_and_get_cpp_code(f_opt, (inps[0], inps[1]))
+            FileCheck().check_not("void kernel").run(code)
+
+            self.assertEqual(
+                f(*inps),
+                f_opt(*inps),
+            )
+
+            # constant needs to be propagated on fallback
+            def f(x):
+                return x[torch.tensor(1) :] * 2
+
+            f_opt = torch.compile()(f)
+            code = run_and_get_cpp_code(f_opt, (inps[0],))
+            FileCheck().check_not("void kernel").run(code)
+            self.assertEqual(f_opt(inps[0]), f(inps[0]))

    def test_redundant_to_node_elimination_bf16(self):
        def fn(x, y):
            res = x + y
18 changes: 14 additions & 4 deletions torch/_inductor/compile_fx.py
@@ -13,12 +13,13 @@

import torch.fx
import torch.utils._pytree as pytree

from torch._dynamo import logging as dynamo_logging, utils as dynamo_utils
from torch._dynamo.utils import fake_mode_from_tensors
from torch._functorch.aot_autograd import make_boxed_func
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
+from torch.fx.passes.fake_tensor_prop import FakeTensorProp

from .._dynamo.backends.common import aot_autograd
from ..fx.graph import _PyTreeCodeGen
from . import config, metrics, overrides, pattern_matcher
@@ -164,14 +165,23 @@ def compile_fx_inner(

    shape_env = _shape_env_from_inputs(example_inputs)

-    fake_mode = fake_mode_from_tensors(
-        example_inputs
-    ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
+    fake_mode = fake_mode_from_tensors(example_inputs)
+    if not fake_mode:
+        fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
+        FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
+    else:
+        FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
+            *example_inputs
+        )
+    # Pattern matcher passes might not preserve striding information on
+    # node.meta["val"]. If we ever rely on these being correct, this will
+    # need to be fixed.

    with V.set_fake_mode(fake_mode):
        pattern_matcher.fx_passes(gm)
        V.debug.fx_graph_transformed(gm, example_inputs)

    with V.set_fake_mode(fake_mode):
        graph = GraphLowering(
            gm,
            shape_env=shape_env,
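For context, fake_mode_from_tensors returns the FakeTensorMode shared by any fake example inputs, or None when every input is a real tensor, which is what the branch above keys on. A rough standalone illustration of that contract (a sketch, not taken from the diff):

    import torch
    from torch._dynamo.utils import fake_mode_from_tensors
    from torch._subclasses.fake_tensor import FakeTensorMode

    mode = FakeTensorMode(allow_non_fake_inputs=True)
    real_inputs = [torch.ones(2)]
    fake_inputs = [mode.from_tensor(t) for t in real_inputs]

    assert fake_mode_from_tensors(real_inputs) is None  # no fake tensors present
    assert fake_mode_from_tensors(fake_inputs) is mode  # shared mode recovered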
2 changes: 2 additions & 0 deletions torch/_inductor/config.py
@@ -157,6 +157,8 @@ def decide_compile_threads():
profile_bandwidth = _profile_var != ""
profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var

+disable_cpp_codegen = is_fbcode()


# config specific to codegen/cpp.py
class cpp:
19 changes: 17 additions & 2 deletions torch/_inductor/graph.py
@@ -33,10 +33,13 @@
from .ir import Constant, FixedLayout, InputBuffer, Pointwise, Reduction, TensorBox
from .lowering import (
    FALLBACK_ALLOW_LIST,
+    fallback_handler,
+    fallback_node_due_to_unsupported_type,
    layout_constraints,
    lowerings,
    make_fallback,
    needs_realized_inputs,
+    unsupported_output_tensor,
)
from .sizevars import SizeVarAllocator
from .utils import (
@@ -373,6 +376,10 @@ def call_function(self, target, args, kwargs):
    def get_attr(self, target, args, kwargs):
        # this is a constant
        value = getattr(self.module, target)
+
+        if unsupported_output_tensor(value):
+            return self.add_tensor_constant(value)
+
        with no_dispatch():
            if value.shape == ():
                return Constant(value.item(), value.dtype, value.device)
@@ -440,8 +447,11 @@ def run_node(self, n: torch.fx.Node):
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            origins |= gather_origins(args, kwargs)
        with ir.IRNode.current_origins(origins):
-            if n.op == "call_function" and n.target in layout_constraints:
-                args, kwargs = self.fetch_args_kwargs_from_env(n)
+            if n.op == "call_function" and fallback_node_due_to_unsupported_type(n):
+                result = fallback_handler(n.target, add_to_fallback_set=False)(
+                    *args, **kwargs
+                )
+            elif n.op == "call_function" and n.target in layout_constraints:
                args, kwargs = layout_constraints[n.target](n, *args, **kwargs)
                result = self.call_function(n.target, args, kwargs)
            elif is_magic_method(n.target):
@@ -530,6 +540,10 @@ def run_node(self, n: torch.fx.Node):

        return result

+    def check_cpp_codegen_disabled(self):
+        if config.disable_cpp_codegen:
+            self.disable_cpp_wrapper("cpp codegen disabled")
+
    def check_platform(self):
        if sys.platform != "linux":
            self.disable_cpp_wrapper("platform not linux")
@@ -551,6 +565,7 @@ def check_constant_for_cpp_buffer(self):
            self.disable_cpp_wrapper("Constants")

    def check_cpp_wrapper(self):
+        self.check_cpp_codegen_disabled()
        self.check_platform()
        self.check_device_for_cpp_buffer()
        self.check_input_for_cpp_buffer()
8 changes: 7 additions & 1 deletion torch/_inductor/ir.py
@@ -2509,8 +2509,14 @@ def unflatten_args(new_tensor_args, new_non_tensor_args):
        # TODO(jansel): replace this with dynamic shape formulas
        example_args = []

+        # We need to retain the constant values of fake tensors that we
+        # originally propagated the graph with, because for some operators,
+        # running without a constant would trigger an error / DataDependentException
        for x in tensor_args:
-            example_args.append(ir_node_to_tensor(x, guard_shape=True))
+            if x.get_name() in V.graph.constants:
+                example_args.append(V.graph.constants[x.get_name()])
+            else:
+                example_args.append(ir_node_to_tensor(x, guard_shape=True))

        new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
        example_output = kernel(*new_args, **new_kwargs)
36 changes: 34 additions & 2 deletions torch/_inductor/lowering.py
@@ -21,6 +21,7 @@
    type_to_dtype,
)
from torch.fx.experimental.symbolic_shapes import magic_methods, method_to_operator
+from torch.utils._pytree import tree_flatten
from .._dynamo.utils import import_submodule

from . import config, ir, overrides, test_operators  # NOQA: F401
@@ -1027,8 +1028,9 @@ def mkl_packed_linear(
register_onednn_fusion_ops()


-def fallback_handler(kernel):
-    fallbacks.add(kernel)
+def fallback_handler(kernel, add_to_fallback_set=True):
+    if add_to_fallback_set:
+        fallbacks.add(kernel)

    def handler(*args, **kwargs):
        return pytree.tree_map(
@@ -1038,6 +1040,36 @@ def handler(*args, **kwargs):
    return handler


+def unsupported_output_tensor(t: torch._subclasses.FakeTensor):
+    return t.is_cpu and config.disable_cpp_codegen
+
+
+def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True):
+    def check_skip_condition(node, is_output):
+        if not isinstance(node, torch.fx.Node):
+            return False
+
+        if "val" not in node.meta:
+            return False
+
+        for meta in tree_flatten(node.meta["val"])[0]:
+            if not isinstance(meta, torch._subclasses.FakeTensor):
+                continue
+
+            if is_output:
+                if unsupported_output_tensor(meta):
+                    return True
+
+        return False
+
+    # only skip codegen if there is a cpu output, not input
+    for arg in tree_flatten((node.args, node.kwargs))[0]:
+        if check_skip_condition(arg, is_output=(False or not allow_cpu_inputs)):
+            return True
+
+    return check_skip_condition(node, is_output=True)


def make_fallback(kernel, layout_constraint=None, warn=True):
    assert (
        kernel not in decompositions
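A sketch of how the new predicate could be exercised on its own; it assumes fake-tensor propagation has already populated node.meta["val"], which the compile_fx.py change above guarantees during compilation:

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.fake_tensor_prop import FakeTensorProp
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch._inductor.lowering import fallback_node_due_to_unsupported_type

    gm = symbolic_trace(lambda x: x * 2)
    mode = FakeTensorMode(allow_non_fake_inputs=True)
    FakeTensorProp(gm, mode=mode).propagate(torch.ones(4))

    for n in gm.graph.nodes:
        if n.op == "call_function":
            # True only for an unsupported (CPU) output, i.e. when
            # config.disable_cpp_codegen is set (the fbcode default).
            print(n, fallback_node_due_to_unsupported_type(n))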
9 changes: 8 additions & 1 deletion torch/_inductor/pattern_matcher.py
@@ -16,7 +16,7 @@
from torch.fx.immutable_collections import immutable_dict, immutable_list

from . import config, ir
-from .lowering import lowerings as L
+from .lowering import fallback_node_due_to_unsupported_type, lowerings as L
from .virtualized import V

log = logging.getLogger(__name__)
@@ -245,6 +245,7 @@ def _match(self, node: List[torch.fx.Node], ctx: MatchContext):
        return m.bundle()


+# First pass_patterns[0] are applied, then [1], then [2]
pass_patterns = [
    defaultdict(list),
    defaultdict(list),
@@ -369,6 +370,12 @@ def replace_matched_patterns(graph: torch.fx.Graph):
            continue
        for node in reversed(graph.nodes):
            if node.op == "call_function" and node.target in patterns:
+                # conservatively not applying pattern for cpu input,
+                # since some of the patterns induce codegen and split nodes.
+                # Note: we will only skip cpu compute if disable_cpp_codegen=True
+                if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False):
+                    continue
+
                for entry in patterns[node.target]:
                    if node._erased:
                        break
9 changes: 8 additions & 1 deletion torch/_subclasses/fake_tensor.py
@@ -1105,7 +1105,13 @@ def dispatch(self, func, types, args=(), kwargs=None):
            and not has_symbolic_sizes
            and not flat_arg_fake_tensors
        ):
-            out = func(*args, **kwargs)
+            assert all(
+                t.constant is not None for t in flat_arg_fake_tensors
+            ), f"{func} should not have fake inputs without constants"
+            const_args, const_kwargs = pytree.tree_map_only(
+                FakeTensor, lambda t: t.constant, (args, kwargs)
+            )
+            out = func(*const_args, **const_kwargs)
            if self.may_turn_const(out):
                # NB: not in_kernel_invocation_manager because we're doing real
                # compute here
@@ -1131,6 +1137,7 @@ def dispatch(self, func, types, args=(), kwargs=None):
            assert (
                len(kwargs) == 0 and len(args) == 1 and type(args[0]) is torch.Tensor
            ), f"{args} {kwargs}"
+
            return converter(self, args[0])

        args, kwargs = self.validate_and_convert_non_fake_tensors(
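In effect, constant-eligible ops now run their real compute on the stored FakeTensor.constant values instead of on the fake tensors themselves. That appears to be what lets the fallback path in the new test, x[torch.tensor(1) :], slice with an actual integer rather than raising a data-dependent exception.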
4 changes: 2 additions & 2 deletions torch/fx/experimental/proxy_tensor.py
@@ -124,8 +124,8 @@ def set_meta(proxy, val):
    elif isinstance(val, py_sym_types):
        proxy.node.meta['val'] = val
    elif isinstance(val, (list, tuple)):
-        if all(isinstance(x, FakeTensor) for x in val):
-            proxy.node.meta['val'] = [snapshot_fake(x) for x in val]
+        if any(isinstance(x, FakeTensor) for x in val):
+            proxy.node.meta['val'] = [snapshot_fake(x) if isinstance(x, FakeTensor) else None for x in val]
    elif isinstance(val, torch.Tensor):
        if not val.is_sparse:
            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
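The relaxation from all() to any() above means a list value that mixes fake tensors with other entries still records metadata, with None standing in for each non-fake element, instead of dropping node.meta['val'] entirely.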
29 changes: 25 additions & 4 deletions torch/fx/passes/fake_tensor_prop.py
@@ -3,7 +3,9 @@
import torch.fx
from torch.fx import Node
from torch.fx._compatibility import compatibility
-from torch._subclasses.fake_tensor import FakeTensorMode
+from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
+from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake
+from torch.fx.node import map_aggregate

__all__ = ['FakeTensorProp']

@@ -29,10 +31,29 @@ def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode]

    def run_node(self, n: Node):
        result = super().run_node(n)
-        n.meta['val'] = result
+
+        def extract_val(obj):
+            if isinstance(obj, FakeTensor):
+                return snapshot_fake(obj)
+            elif isinstance(obj, torch.Tensor):
+                return snapshot_fake(self._mode.from_tensor(obj))
+            elif isinstance(obj, py_sym_types):
+                return obj
+            else:
+                return None
+
+        meta = map_aggregate(result, extract_val)
+        if meta is not None:
+            n.meta['val'] = meta
        return result

    def propagate(self, *args):
+        fake_args = [
+            self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
+            for a in args
+        ]
+        return self.propagate_dont_convert_inputs(*fake_args)
+
+    def propagate_dont_convert_inputs(self, *args):
        with self._mode:
-            fake_args = [self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
-            return super().run(*fake_args)
+            return super().run(*args)
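With run_node now snapshotting fake values into node.meta['val'], the two entry points split the work: propagate fakeifies real inputs and delegates, while the new propagate_dont_convert_inputs (used by compile_fx.py above) runs already-fake inputs as-is. A brief usage sketch under the same assumptions as the diff:

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.fake_tensor_prop import FakeTensorProp
    from torch._subclasses.fake_tensor import FakeTensorMode

    gm = symbolic_trace(lambda x: x.relu() + 1)
    mode = FakeTensorMode(allow_non_fake_inputs=True)

    # Real inputs: propagate() converts them to fake tensors first.
    FakeTensorProp(gm, mode=mode).propagate(torch.randn(4))

    # Already-fake inputs: no conversion needed.
    fake_x = mode.from_tensor(torch.randn(4))
    FakeTensorProp(gm, mode=mode).propagate_dont_convert_inputs(fake_x)

    placeholder = next(iter(gm.graph.nodes))
    print(placeholder.meta["val"])  # a snapshot of the fake input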
