Renamed passes to options in torch.compile (pytorch#94500)
@jansel expressed a preference for this (as most of our options are *not* passes), and I agree. I also think that `fullgraph` could be changed, but I don't know what I'd change it to. I considered `strict`, but some folks objected to that.

Pull Request resolved: pytorch#94500
Approved by: https://github.com/voznesenskym, https://github.com/soumith, https://github.com/jansel
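
For reference, the user-visible effect of the rename is only the keyword argument's name; a minimal before/after sketch (assuming a PyTorch build that includes this commit):

    import torch

    def foo(x):
        return torch.sin(x) + torch.cos(x)

    # Before this commit, backend config was passed as `passes`:
    #   compiled = torch.compile(foo, passes={"triton.cudagraphs": False})
    # After it, the same dict is passed as `options`:
    compiled = torch.compile(foo, options={"triton.cudagraphs": False})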
Chillee authored and pytorchmergebot committed Feb 10, 2023
1 parent 59e8756 commit 3a12b16
Showing 2 changed files with 16 additions and 16 deletions.
test/inductor/test_config.py (2 changes: 1 addition & 1 deletion)
@@ -106,7 +106,7 @@ def test_compile_api(self):
             {"mode": "reduce-overhead"},
             {"mode": "max-autotune"},
             {
-                "passes": {
+                "options": {
                     "max-fusion-size": 128,
                     "unroll_reductions_threshold": 32,
                     "triton.cudagraphs": False,
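
The test above drives backend configuration through the public API; as a hedged sketch, an equivalent direct call using the same option keys the test exercises (the function fn is illustrative):

    import torch

    def fn(x):
        return x.relu() + 1.0

    compiled = torch.compile(
        fn,
        options={
            "max-fusion-size": 128,
            "unroll_reductions_threshold": 32,
            "triton.cudagraphs": False,
        },
    )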
torch/__init__.py (30 changes: 15 additions & 15 deletions)
@@ -1319,19 +1319,19 @@ def compiled_with_cxx11_abi():
 class _TorchCompileInductorWrapper:
     compiler_name = "inductor"
 
-    def __init__(self, mode, passes, dynamic):
+    def __init__(self, mode, options, dynamic):
         from torch._inductor.compile_fx import compile_fx
 
         self.compile_fn = compile_fx
         self._torchdynamo_orig_callable = compile_fx
         self.config = dict()
         self.apply_mode(mode)
-        self.apply_passes(passes)
+        self.apply_options(options)
         if dynamic:
             # cudagraphs conflicts with dynamic shapes
             self.config["triton.cudagraphs"] = False
             assert "triton.cudagraphs" not in (
-                passes or ()
+                options or ()
             ), "triton.cudagraphs does not support dynamic shapes"
 
     def apply_mode(self, mode: Optional[str]):
@@ -1349,18 +1349,18 @@ def apply_mode(self, mode: Optional[str]):
                 f"Unrecognized mode={mode}, should be one of: default, reduce-overhead, max-autotune"
             )
 
-    def apply_passes(self, passes: Optional[Dict[str, Any]]):
-        if not passes:
+    def apply_options(self, options: Optional[Dict[str, Any]]):
+        if not options:
             return
 
         from torch._inductor import config
         current_config: Dict[str, Any] = config.to_dict()  # type: ignore[attr-defined]
 
-        for key, val in passes.items():
+        for key, val in options.items():
             attr_name = key.replace("-", "_")
             if attr_name not in current_config:
                 raise RuntimeError(
-                    f"Unexpected optimization pass {key}, known passes are {list(current_config.keys())}"
+                    f"Unexpected optimization option {key}, known options are {list(current_config.keys())}"
                 )
             if type(val) is not type(current_config[attr_name]):
                 val_type_str = type(val).__name__
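
As the hunk above shows, option keys are normalized (dashes become underscores), unknown keys raise a RuntimeError, and values are type-checked against the existing inductor config entry. A minimal standalone sketch of that lookup logic; the current_config dict and the type-mismatch message here are stand-ins, not the real inductor config:

    from typing import Any, Dict

    def resolve_option(key: str, val: Any, current_config: Dict[str, Any]) -> str:
        # User-facing keys may use dashes; config attribute names use underscores.
        attr_name = key.replace("-", "_")
        if attr_name not in current_config:
            raise RuntimeError(
                f"Unexpected optimization option {key}, known options are {list(current_config.keys())}"
            )
        # Values must match the type of the existing config entry.
        if type(val) is not type(current_config[attr_name]):
            raise RuntimeError(f"Expected {type(current_config[attr_name]).__name__} for option {key}")
        return attr_name

    # "max-fusion-size" resolves to the "max_fusion_size" config entry.
    print(resolve_option("max-fusion-size", 128, {"max_fusion_size": 64}))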
@@ -1379,7 +1379,7 @@ def compile(model: Optional[Callable] = None, *,
             dynamic: builtins.bool = False,
             backend: Union[str, Callable] = "inductor",
             mode: Union[str, None] = None,
-            passes: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None,
+            options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None,
             disable: builtins.bool = False) -> Callable:
     """
     Optimizes given model/function using TorchDynamo and specified backend.
@@ -1390,12 +1390,12 @@
        dynamic (bool): Use dynamic shape tracing
        backend (str or Callable): backend to be used
        mode (str): Can be either "default", "reduce-overhead" or "max-autotune"
-       passes (dict): A dictionary of options to pass to the backend.
+       options (dict): A dictionary of options to pass to the backend.
        disable (bool): Turn torch.compile() into a no-op for testing
 
     Example::
 
-        @torch.compile(passes={"matmul-padding": True}, fullgraph=True)
+        @torch.compile(options={"matmul-padding": True}, fullgraph=True)
         def foo(x):
             return torch.sin(x) + torch.cos(x)
@@ -1411,17 +1411,17 @@ def fn(model: Callable):
                        dynamic=dynamic,
                        backend=backend,
                        mode=mode,
-                       passes=passes,
+                       options=options,
                        disable=disable)
         return fn
 
     import torch._dynamo
-    if mode is not None and passes is not None:
-        raise RuntimeError("Either mode or passes can be specified, but both can't be specified at the same time.")
-    if mode is None and passes is None:
+    if mode is not None and options is not None:
+        raise RuntimeError("Either mode or options can be specified, but both can't be specified at the same time.")
+    if mode is None and options is None:
         mode = "default"
     if backend == "inductor":
-        backend = _TorchCompileInductorWrapper(mode, passes, dynamic)
+        backend = _TorchCompileInductorWrapper(mode, options, dynamic)
     return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, disable=disable)(model)
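
To illustrate the mutual exclusion enforced above, mode and options each work on their own, but combining them raises (again assuming a build that includes this commit):

    import torch

    def foo(x):
        return torch.sin(x) + torch.cos(x)

    torch.compile(foo, mode="reduce-overhead")                 # ok
    torch.compile(foo, options={"triton.cudagraphs": False})   # ok

    try:
        torch.compile(foo, mode="max-autotune", options={"triton.cudagraphs": False})
    except RuntimeError as e:
        print(e)  # "Either mode or options can be specified, but both can't be ..."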


