Fix/refactor dynamo onnxrt backend (pytorch#93818)
Fixes pytorch#90352

Pull Request resolved: pytorch#93818
Approved by: https://github.com/voznesenskym
jansel authored and pytorchmergebot committed Feb 3, 2023
1 parent d9870d7 commit a5ff400
Showing 5 changed files with 151 additions and 187 deletions.
76 changes: 0 additions & 76 deletions benchmarks/dynamo/common.py
@@ -805,70 +805,6 @@ def try_script(model, example_inputs):
return None


def speedup_experiment_onnx(args, model_iter_fn, model, example_inputs):
"""
Measure baseline performance (without using TorchDynamo) of ONNXRT and TensorFlow.
Writes to ./baseline_onnx.csv
"""
if current_device == "cpu":
m_onnxrt = backends.onnxrt_cpu(
try_script(model, example_inputs), example_inputs
)
else:
m_onnxrt = backends.onnxrt_cuda(
try_script(model, example_inputs), example_inputs
)

if current_name != "timm_resnest":
m_onnx2tf = backends.onnx2tf(try_script(model, example_inputs), example_inputs)
else:
# this one takes 8+ hours to finish
m_onnx2tf = None

return baselines(
[
("eager", model),
("onnxrt", m_onnxrt),
("onnx2tf", m_onnx2tf),
],
model_iter_fn,
example_inputs,
args,
)


def speedup_experiment_trt(args, model_iter_fn, model, example_inputs):
"""
Measure baseline performance (without using TorchDynamo) of TensorRT.
Writes to ./baseline_trt.csv
"""
m_onnx2trt = backends.onnx2tensorrt(
try_script(model, example_inputs), example_inputs
)

m_torch2trt = backends.torch2trt(model, example_inputs)

if current_name != "opacus_cifar10":
m_fx2trt = backends.fx2trt(model, example_inputs)
else:
# fx2trt infinite loops on one model
m_fx2trt = None

return baselines(
[
("eager", model),
("onnx2trt", m_onnx2trt),
("torch2trt", m_torch2trt),
("fx2trt", m_fx2trt),
],
model_iter_fn,
example_inputs,
args,
)


def read_batch_size_from_file(args, filename, model_name):
batch_size = None
if os.path.exists("benchmarks"):
@@ -1780,12 +1716,6 @@ def get_example_inputs(self):
group.add_argument(
"--overhead", action="store_true", help=help(overhead_experiment)
)
group.add_argument(
"--speedup-onnx", action="store_true", help=help(speedup_experiment_onnx)
)
group.add_argument(
"--speedup-trt", action="store_true", help=help(speedup_experiment_trt)
)
group.add_argument(
"--speedup-dynamo-ts",
action="store_true",
@@ -2073,12 +2003,6 @@ def run(runner, args, original_dir=None):
optimize_ctx = torch._dynamo.optimize("inductor", nopython=args.nopython)
experiment = speedup_experiment
output_filename = "inductor.csv"
elif args.speedup_onnx:
experiment = speedup_experiment_onnx
output_filename = "baseline_onnx.csv"
elif args.speedup_trt:
experiment = speedup_experiment_trt
output_filename = "baseline_trt.csv"
elif args.speedup_dynamo_ts:
optimize_ctx = torch._dynamo.optimize("ts", nopython=args.nopython)
experiment = speedup_experiment
4 changes: 4 additions & 0 deletions test/dynamo/test_optimizations.py
@@ -180,6 +180,10 @@ def test_nvprims_nvfuser(self):
def test_nvprims_aten(self):
self._check_backend_works("nvprims_aten")

@unittest.skipIf(not has_onnxruntime(), "requires onnxruntime")
def test_onnxrt(self):
self._check_backend_works("onnxrt")


class NormalizeIRTests(torch._dynamo.test_case.TestCase):
def test_inplace_normalize(self):
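
With onnxruntime installed, the new test can presumably be run on its own with something like `python test/dynamo/test_optimizations.py -k test_onnxrt`; without onnxruntime it is skipped, per the skipIf guard above.
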
37 changes: 37 additions & 0 deletions torch/_dynamo/backends/common.py
@@ -1,9 +1,12 @@
import functools
import logging

import torch
from torch._dynamo import eval_frame
from torch._dynamo.utils import counters
from torch._functorch.aot_autograd import aot_module_simplified
from torch._subclasses import FakeTensor
from torch.utils._python_dispatch import _disable_current_modes

log = logging.getLogger(__name__)

@@ -70,3 +73,37 @@ def mem_efficient_fusion_kwargs(use_decomps):
kwargs["decompositions"] = default_decompositions

return kwargs


def fake_tensor_unsupported(fn):
"""
Decorator for backends that need real inputs. We swap out fake
tensors for zero tensors.
"""

def defake(x):
if not isinstance(x, FakeTensor):
return x
y = torch.empty_strided(
x.size(),
x.stride(),
dtype=x.dtype,
device=x.device,
requires_grad=x.requires_grad,
)
y.zero_()
return y

@functools.wraps(fn)
def wrapper(model, inputs, **kwargs):
with _disable_current_modes():
inputs = list(map(defake, inputs))
return fn(model, inputs, **kwargs)

return wrapper


def device_from_inputs(example_inputs) -> torch.device:
for x in example_inputs:
if hasattr(x, "device"):
return x.device
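
For context, a minimal sketch of how a third-party backend might use the two helpers added here. The backend name `zero_debug_backend` is hypothetical; the imports are the modules touched by this commit:

import torch
from torch._dynamo.backends.common import device_from_inputs, fake_tensor_unsupported
from torch._dynamo.backends.registry import register_backend


@register_backend
@fake_tensor_unsupported
def zero_debug_backend(gm, example_inputs):
    # By the time this body runs, fake_tensor_unsupported has swapped any
    # FakeTensor inputs for zero-filled real tensors with the same
    # shape/stride/dtype, so the captured graph can be executed for real.
    device = device_from_inputs(example_inputs)
    print(f"got {len(example_inputs)} example inputs on {device}")
    return gm.forward  # fall back to running the captured graph eagerly
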
109 changes: 109 additions & 0 deletions torch/_dynamo/backends/onnxrt.py
@@ -0,0 +1,109 @@
import os
import tempfile

import torch
from .common import device_from_inputs, fake_tensor_unsupported
from .registry import register_backend

try:
import numpy as np

_np_dtype = {
torch.float16: np.float16,
torch.float32: np.float32,
torch.float64: np.float64,
torch.uint8: np.uint8,
torch.int8: np.int8,
torch.int16: np.int16,
torch.int32: np.int32,
torch.int64: np.longlong,
torch.bool: np.bool_,
}

except ImportError:
_np_dtype = None


def default_provider(device_type):
if "ONNXRT_PROVIDER" in os.environ:
return os.environ["ONNXRT_PROVIDER"]
return {
"cpu": "CPUExecutionProvider",
"cuda": "CUDAExecutionProvider",
# "TensorrtExecutionProvider" is another option
}[device_type]


@register_backend
@fake_tensor_unsupported
def onnxrt(gm, example_inputs, *, filename=None, provider=None):
if filename is None:
with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp:
return onnxrt(gm, example_inputs, filename=tmp.name)

import onnxruntime # type: ignore[import]

assert _np_dtype, "requires numpy"

device_type = device_from_inputs(example_inputs).type
example_outputs = gm(*example_inputs)
output_spec = [
(o.shape, o.dtype, o.layout, o.device, o.requires_grad) for o in example_outputs
]
input_names = [f"i{i}" for i in range(len(example_inputs))]
output_names = [f"o{x}" for x in range(len(example_outputs))]

torch.onnx.export(
torch.jit.script(gm),
example_inputs,
filename,
input_names=input_names,
output_names=output_names,
)
del example_inputs, example_outputs

if provider is None:
provider = default_provider(device_type)
assert provider in onnxruntime.get_available_providers()
session = onnxruntime.InferenceSession(filename, providers=[provider])

def _call(*initial_args):
binding = session.io_binding()
args = [a.contiguous() for a in initial_args]
for name, value in zip(input_names, args):
dev = value.device
binding.bind_input(
name,
dev.type,
dev.index or 0,
_np_dtype[value.dtype],
value.size(),
value.data_ptr(),
)
outputs = [
torch.empty(
shape,
dtype=dtype,
layout=layout,
device=device,
requires_grad=requires_grad,
)
for shape, dtype, layout, device, requires_grad in output_spec
]

for name, value in zip(output_names, outputs):
dev = value.device
binding.bind_output(
name,
dev.type,
dev.index or 0,
_np_dtype[value.dtype],
value.size(),
value.data_ptr(),
)
session.run_with_iobinding(binding)
if device_type == "cpu":
binding.copy_outputs_to_cpu()
return outputs

return _call
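
Since the file above registers the backend under the name "onnxrt" (via @register_backend) and the new test exercises it through _check_backend_works("onnxrt"), a minimal usage sketch might look like the following, assuming onnxruntime is installed; SmallModel is a placeholder module:

import torch
import torch._dynamo


class SmallModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0


model = SmallModel().eval()
example = torch.randn(4, 8)

# torch._dynamo.optimize("onnxrt") looks the backend up in the registry;
# on builds that ship torch.compile, backend="onnxrt" should be equivalent.
opt_model = torch._dynamo.optimize("onnxrt")(model)
print(opt_model(example))

The execution provider is chosen from the input device (CPUExecutionProvider or CUDAExecutionProvider) and, per default_provider above, can be overridden with the ONNXRT_PROVIDER environment variable.
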
112 changes: 1 addition & 111 deletions torch/_dynamo/optimizations/backends.py
@@ -37,116 +37,6 @@ def inner(model, example_inputs=None, **kwargs):
return register_backend(inner)


def onnxrt_common(subgraph, provider, onnx_filename=None):
import numpy as np # type: ignore[import]
import onnxruntime # type: ignore[import]

_np_dtype = {
torch.float16: np.float16,
torch.float32: np.float32,
torch.float64: np.float64,
torch.uint8: np.uint8,
torch.int8: np.int8,
torch.int16: np.int16,
torch.int32: np.int32,
torch.int64: np.longlong,
torch.bool: np.bool_,
}

assert provider in onnxruntime.get_available_providers()
session = onnxruntime.InferenceSession(
onnx_filename or subgraph.onnx_filename, providers=[provider]
)
input_names = subgraph.input_names
output_names = subgraph.output_names
create_outputs = subgraph.empty_outputs_factory()
is_cpu = subgraph.is_cpu

def _call(*initial_args):
binding = session.io_binding()
args = [a.contiguous() for a in initial_args]
for name, value in zip(input_names, args):
dev = value.device
binding.bind_input(
name,
dev.type,
dev.index or 0,
_np_dtype[value.dtype],
value.size(),
value.data_ptr(),
)
outputs = create_outputs()
for name, value in zip(output_names, outputs):
dev = value.device
binding.bind_output(
name,
dev.type,
dev.index or 0,
_np_dtype[value.dtype],
value.size(),
value.data_ptr(),
)
session.run_with_iobinding(binding)
if is_cpu:
binding.copy_outputs_to_cpu()
return outputs

return subgraph.wrap_returns(_call)


@create_backend
def onnxrt_cpu(subgraph):
return onnxrt_common(subgraph, provider="CPUExecutionProvider")


@create_backend
def onnxrt_cuda(subgraph):
return onnxrt_common(subgraph, provider="CUDAExecutionProvider")


@create_backend
def onnx2tensorrt(subgraph):
if subgraph.will_tensorrt_barf():
# TensorRT fails violently with an abort() on this
return None

return onnxrt_common(subgraph, provider="TensorrtExecutionProvider")


@create_backend
def onnxrt_cpu_numpy(subgraph, provider="CPUExecutionProvider"):
"""Alternate version that integrates via numpy"""
import onnxruntime

assert provider in onnxruntime.get_available_providers()
ort_session = onnxruntime.InferenceSession(
subgraph.onnx_filename, providers=[provider]
)

def to_numpy(x):
try:
return x.numpy()
except RuntimeError:
return x.detach().numpy()

def _call(*args):
res = ort_session.run(
None, {f"i{i}": to_numpy(arg) for i, arg in enumerate(args)}
)
res = [torch.from_numpy(x) for x in res]
return res

return subgraph.wrap_returns(_call)


@create_backend
def onnxrt(subgraph):
if subgraph.is_cuda:
return onnxrt_cuda(subgraph)
else:
return onnxrt_cpu(subgraph)


def _raise_timeout(signum, frame):
raise TimeoutError()

@@ -272,7 +162,7 @@ def tensorrt(subgraph):
# TensorRT fails violently with an abort() on this
return None

model = onnx2tensorrt(subgraph)
model = fx2trt(subgraph)
if model is None:
model = torch2trt(subgraph)
return model