Remove many untested dynamo backends (pytorch#93382)
jansel authored and pytorchmergebot committed Feb 2, 2023
1 parent 653dc73 commit 569f2e3
Showing 4 changed files with 45 additions and 170 deletions.
13 changes: 13 additions & 0 deletions test/dynamo/test_optimizations.py
@@ -141,6 +141,19 @@ def test_ipex_bf16(self):
        self.assertTrue(same(r1, r2.float(), tol=0.1))
        self.assertEqual(r2.dtype, torch.bfloat16)

    def _check_backend_works(self, backend):
        model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1).eval()
        input = torch.randn(8, 3, 64, 64)
        r1 = model(input)
        r2 = torch.compile(model, backend=backend)(input)
        self.assertTrue(same(r1, r2.float(), tol=0.01))

    def test_eager(self):
        self._check_backend_works("eager")

    def test_torchscript(self):
        self._check_backend_works("ts")


class NormalizeIRTests(torch._dynamo.test_case.TestCase):
    @unittest.skipIf(not has_functorch(), "requires functorch")
4 changes: 2 additions & 2 deletions test/dynamo/test_verify_correctness.py
@@ -105,11 +105,11 @@ def compiler_fn(graph, example_inputs):
        self.assertEqual(r1.device, r2.device)
        self.assertEqual(r1.device, r3.device)

    def test_nnc(self):
    def test_torchscript(self):
        s = Seq()
        i = torch.randn(10)
        r1 = s(i)
        opt_s = torch._dynamo.optimize("nnc")(s)
        opt_s = torch._dynamo.optimize("ts")(s)
        r2 = opt_s(i)
        self.assertTrue(same(r1, r2))

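The rename mirrors the backend change below: the NNC-specific backend is removed and "ts" now maps to plain TorchScript. A quick sketch of the two equivalent entry points the updated tests use (illustrative; any small module works):

```python
import torch

model = torch.nn.Linear(10, 10).eval()
x = torch.randn(2, 10)

# Both forms resolve the registered "ts" backend after this commit:
r1 = torch._dynamo.optimize("ts")(model)(x)
r2 = torch.compile(model, backend="ts")(x)
assert torch.allclose(r1, r2)
```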
194 changes: 27 additions & 167 deletions torch/_dynamo/optimizations/backends.py
@@ -1,30 +1,41 @@
import copy
import functools
import io
import logging
import os
import subprocess
import tempfile

from typing import Dict
from typing import Dict, Optional

import torch
from ..output_graph import CompilerFn

from ..utils import identity
from .subgraph import SubGraph

log = logging.getLogger(__name__)
BACKENDS: Dict[str, CompilerFn] = dict()


def register_backend(fn):
    @functools.wraps(fn)
    def inner(gm, example_inputs, **kwargs):
        return fn(gm, example_inputs, **kwargs)

    BACKENDS[fn.__name__] = inner
    return inner
def register_backend(compiler_fn: CompilerFn = None, name: Optional[str] = None):
    """
    Decorator to add a given compiler to the BACKENDS registry to allow
    calling `torch.compile` with string shorthand:

        torch.compile(..., backend="name")

    Note: for projects not imported by default, it might be easier to
    pass a function directly as a backend and not use this:

        torch.compile(..., backend=compiler_fn)

    Args:
        compiler_fn: callable taking a FX graph and fake tensor inputs
        name: Optional name, defaults to `compiler_fn.__name__`
    """
    if compiler_fn is None:
        # @register_backend(name="") syntax
        return functools.partial(register_backend, name=name)
    BACKENDS[name or compiler_fn.__name__] = compiler_fn
    return compiler_fn
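A minimal sketch of both usage patterns the docstring describes; `my_backend` is a hypothetical name, and the import path reflects this module as of the commit:

```python
import torch
from torch._dynamo.optimizations.backends import register_backend

@register_backend(name="my_backend")
def my_backend(gm, fake_tensor_inputs):
    # A backend takes an FX GraphModule plus fake tensor inputs and returns
    # a callable; returning gm unchanged behaves like the "eager" backend.
    return gm

compiled = torch.compile(torch.nn.Linear(4, 4), backend="my_backend")

# Or skip the registry entirely and pass the function directly:
compiled2 = torch.compile(torch.nn.Linear(4, 4), backend=my_backend)
```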


def create_backend(fn):
@@ -61,75 +72,14 @@ def inductor(*args, **kwargs):
    return compile_fx(*args, **kwargs)


@create_backend
def eager(subgraph):
    return subgraph.model


@create_backend
def ts(subgraph):
    return subgraph.scripted


def reload_jit_model(subgraph, opt_fn=identity):
    tmp = io.BytesIO()
    torch.jit.save(subgraph.scripted, tmp)
    tmp.seek(0)
    model = torch.jit.load(tmp)
    model = opt_fn(model)
    # populate cache
    for _ in range(3):
        model(*subgraph.example_inputs)
    return model


def reload_jit_model_ofi(subgraph):
    return reload_jit_model(subgraph, torch.jit.optimize_for_inference)


@create_backend
def nnc(subgraph):
    with torch.jit.fuser("fuser1"):
        return reload_jit_model(subgraph)


@create_backend
def nnc_ofi(subgraph):
    with torch.jit.fuser("fuser1"):
        return reload_jit_model_ofi(subgraph)


@create_backend
def ts_nvfuser(subgraph):
    with torch.jit.fuser("fuser2"):
        return reload_jit_model(subgraph)


@create_backend
def ts_nvfuser_ofi(subgraph):
    with torch.jit.fuser("fuser2"):
        return reload_jit_model_ofi(subgraph)


@create_backend
def onednn(subgraph):
    with torch.jit.fuser("fuser3"):
        return reload_jit_model(subgraph)
@register_backend
def eager(gm, fake_tensor_inputs):
    return gm


@create_backend
def ofi(subgraph):
    return torch.jit.optimize_for_inference(subgraph.scripted)


@create_backend
def static_runtime(subgraph):
    scripted = subgraph.scripted
    if hasattr(scripted, "_c"):
        static_module = torch._C._jit_to_static_module(scripted._c)
    else:
        static_module = torch._C._jit_to_static_module(scripted.graph)
    return subgraph.wrap_returns(static_module)
@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
    return torch.jit.script(gm)


def onnxrt_common(subgraph, provider, onnx_filename=None):
@@ -242,70 +192,6 @@ def onnxrt(subgraph):
    return onnxrt_cpu(subgraph)


@functools.lru_cache(None)
def _init_tensorflow():
    import tensorflow as tf  # type: ignore[import]

    # prevent tensorflow from eating all the GPU memory
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    return tf


@create_backend
def onnx2tf(subgraph):
    import onnx  # type: ignore[import]
    from onnx_tf.backend import prepare  # type: ignore[import]

    tf = _init_tensorflow()
    filename = subgraph.filename("tensorflow")
    input_names = subgraph.input_names
    output_names = subgraph.output_names
    device = "/CPU:0" if subgraph.is_cpu else f"/GPU:{subgraph.device_index}"
    with tf.device(device):
        if not os.path.exists(filename):
            prepare(onnx.load(subgraph.onnx_filename)).export_graph(filename)
        tf_module = tf.saved_model.load(filename)
        tf_module = tf.function(tf_module, jit_compile=True)

    def run(*i_args):
        args = [a.contiguous() for a in i_args]
        with tf.device(device):
            outs = tf_module(
                **{
                    name: tf.experimental.dlpack.from_dlpack(
                        torch.utils.dlpack.to_dlpack(args[idx])
                    )
                    for idx, name in enumerate(input_names)
                }
            )
            return [
                torch.utils.dlpack.from_dlpack(
                    tf.experimental.dlpack.to_dlpack(outs[name])
                )
                for name in output_names
            ]

    return subgraph.wrap_returns(run)


@create_backend
def taso(subgraph):
    taso_filename = subgraph.filename("taso")
    subprocess.check_call(
        [
            os.path.expanduser("~/conda/envs/taso/bin/python"),
            "-c",
            "import taso,onnx; onnx.save(taso.export_onnx(taso.optimize("
            f"taso.load_onnx('{subgraph.onnx_filename}'))), '{taso_filename}')",
        ]
    )
    return onnxrt_common(
        subgraph, provider="CUDAExecutionProvider", onnx_filename=taso_filename
    )


@create_backend
def ipex(subgraph, **kwargs):
    import intel_extension_for_pytorch as ipex  # type: ignore[import]
@@ -466,32 +352,6 @@ def cudagraphs(subgraph):
    return subgraph.wrap_returns(cudagraphs_inner(model, inputs))


@create_backend
def cudagraphs_ts(subgraph):
    assert subgraph.is_cuda
    model = subgraph.scripted
    inputs = subgraph.example_inputs

    # warmup
    for _ in range(3):
        model(*inputs)

    return subgraph.wrap_returns(cudagraphs_inner(model, inputs))


@create_backend
def cudagraphs_ts_ofi(subgraph):
    assert subgraph.is_cuda
    model = torch.jit.optimize_for_inference(torch.jit.freeze(subgraph.scripted))
    inputs = subgraph.example_inputs

    # warmup
    for _ in range(3):
        model(*inputs)

    return subgraph.wrap_returns(cudagraphs_inner(model, inputs))


def cudagraphs_inner(model, inputs, copy_outputs=True):
    assert isinstance(inputs, (list, tuple))
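    # CUDA graphs replay fixed memory addresses, so inputs are staged in
    # static buffers and fresh data is copied in before each replay.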
    static_inputs = [torch.zeros_like(x) for x in inputs]
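The `cudagraphs` backend and the `cudagraphs_inner` helper survive this cleanup. For context, a minimal sketch of the capture/replay pattern `cudagraphs_inner` implements, assuming a CUDA device and a model that returns a single tensor (illustrative, not the file's exact code):

```python
import torch

def cuda_graph_runner(model, inputs):
    # Stage inputs in static buffers: graph replays reuse fixed addresses.
    static_inputs = [x.clone() for x in inputs]

    # Warm up on a side stream so lazy initialization stays out of capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(*static_inputs)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(*static_inputs)

    def run(*new_inputs):
        # Copy fresh data into the captured buffers, then replay the graph.
        for dst, src in zip(static_inputs, new_inputs):
            dst.copy_(src)
        graph.replay()
        return static_output.clone()  # copy out, mirroring copy_outputs=True

    return run
```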
4 changes: 3 additions & 1 deletion torch/_dynamo/test_case.py
@@ -50,7 +50,9 @@ def tearDownClass(cls):
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack = contextlib.ExitStack()
        cls._exit_stack.enter_context(config.patch(raise_on_ctx_manager_usage=True))
        cls._exit_stack.enter_context(
            config.patch(raise_on_ctx_manager_usage=True, suppress_errors=False),
        )

    def setUp(self):
        super().setUp()
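The added `suppress_errors=False` makes compile errors fail tests loudly instead of silently falling back to eager. For readers unfamiliar with the `ExitStack` pattern above, a standard-library-only sketch (the `patch_flag` helper is a hypothetical stand-in for `config.patch`):

```python
import contextlib

@contextlib.contextmanager
def patch_flag(flags, name, value):
    # Stand-in for config.patch: set a flag, restore it on exit.
    old = flags[name]
    flags[name] = value
    try:
        yield
    finally:
        flags[name] = old

flags = {"suppress_errors": True}
stack = contextlib.ExitStack()
# enter_context() runs __enter__ now and defers __exit__ until stack.close().
stack.enter_context(patch_flag(flags, "suppress_errors", False))
assert flags["suppress_errors"] is False
stack.close()  # analogous to the cleanup in tearDownClass
assert flags["suppress_errors"] is True
```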
