Skip to content

feat: Add support for TorchTensorRTModule in Dynamo [1 / x] #2003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion examples/fx/fx2trt_example_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
from torch_tensorrt import TRTModuleNext as TRTModule, Device
from torch_tensorrt.dynamo._TorchTensorRTModule import (
TorchTensorRTModule as TRTModule,
Device,
)

# The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
# model to TensorRT via FX with existing FX based tooling. The general lowering flow
Expand Down
1 change: 0 additions & 1 deletion py/torch_tensorrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def _find_lib(name, paths):
from torch_tensorrt import logging
from torch_tensorrt._Input import Input
from torch_tensorrt._Device import Device
from torch_tensorrt._TRTModuleNext import TRTModuleNext

from torch_tensorrt import fx

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
from operator import truediv
from typing import Any, List, Sequence, Tuple
from typing import Any, List, Tuple

import torch
from torch_tensorrt import _C
Expand All @@ -9,8 +8,8 @@
logger = logging.getLogger(__name__)


class TRTModuleNext(torch.nn.Module):
"""TRTModuleNext is a PyTorch module which encompasses an arbitrary TensorRT Engine.
class TorchTensorRTModule(torch.nn.Module):
"""TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.

This module is backed by the Torch-TensorRT runtime and is fully compatible with both
FX / Python deployments (just ``import torch_tensorrt`` as part of the application) as
Expand All @@ -20,7 +19,7 @@ class TRTModuleNext(torch.nn.Module):
The forward function is simply forward(*args: torch.Tensor) -> Tuple[torch.Tensor] where
the internal implementation is ``return Tuple(torch.ops.tensorrt.execute_engine(list(inputs), self.engine))``

> Note: TRTModuleNext only supports engines built with explict batch
> Note: TorchTensorRTModule only supports engines built with explicit batch

Attributes:
name (str): Name of module (for easier debugging)
Expand All @@ -37,7 +36,7 @@ def __init__(
output_binding_names: List[str] = [],
target_device: Device = Device._current_device(),
):
"""__init__ method for torch_tensorrt.TRTModuleNext
"""__init__ method for torch_tensorrt.TorchTensorRTModule

Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
a PyTorch ``torch.nn.Module`` around it.
Expand Down Expand Up @@ -70,10 +69,7 @@ def __init__(
)

"""
logger.warning(
"TRTModuleNext should be considered experimental stability, APIs are subject to change. Note: TRTModuleNext only supports engines built with explict batch"
)
super(TRTModuleNext, self).__init__()
super(TorchTensorRTModule, self).__init__()

if not isinstance(serialized_engine, bytearray):
ValueError("Expected serialized engine as bytearray")
Expand All @@ -89,8 +85,8 @@ def __init__(
self.name + "_engine" if self.name != "" else "tensorrt_engine",
target_device._to_serialized_rt_device(),
serialized_engine,
TRTModuleNext._pack_binding_names(self.input_binding_names),
TRTModuleNext._pack_binding_names(self.output_binding_names),
TorchTensorRTModule._pack_binding_names(self.input_binding_names),
TorchTensorRTModule._pack_binding_names(self.output_binding_names),
]
)
else:
Expand Down Expand Up @@ -154,7 +150,7 @@ def is_non_tensor(i: Tuple[Any, bool]) -> bool:

non_tensors = [i[0] for i in filter(zip(inputs, types), is_non_tensor)]
raise RuntimeError(
f"TRTModuleNext expects a flattened list of tensors as input, found non tensors: {non_tensors}"
f"TorchTensorRTModule expects a flattened list of tensors as input, found non tensors: {non_tensors}"
)

outputs = torch.ops.tensorrt.execute_engine(list(inputs), self.engine)
Expand Down
30 changes: 29 additions & 1 deletion py/torch_tensorrt/dynamo/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch_tensorrt
from functools import partial

from typing import Any, Sequence
from typing import Any, Optional, Sequence
from torch_tensorrt import EngineCapability, Device
from torch_tensorrt.fx.utils import LowerPrecision

Expand All @@ -16,6 +16,10 @@
WORKSPACE_SIZE,
MIN_BLOCK_SIZE,
PASS_THROUGH_BUILD_FAILURES,
MAX_AUX_STREAMS,
VERSION_COMPATIBLE,
OPTIMIZATION_LEVEL,
USE_PYTHON_RUNTIME,
)


Expand Down Expand Up @@ -45,6 +49,10 @@ def compile(
torch_executed_ops=[],
torch_executed_modules=[],
pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
max_aux_streams=MAX_AUX_STREAMS,
version_compatible=VERSION_COMPATIBLE,
optimization_level=OPTIMIZATION_LEVEL,
use_python_runtime=USE_PYTHON_RUNTIME,
**kwargs,
):
if debug:
Expand Down Expand Up @@ -91,6 +99,10 @@ def compile(
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
pass_through_build_failures=pass_through_build_failures,
max_aux_streams=max_aux_streams,
version_compatible=version_compatible,
optimization_level=optimization_level,
use_python_runtime=use_python_runtime,
**kwargs,
)

Expand All @@ -114,6 +126,10 @@ def create_backend(
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Sequence[str] = set(),
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
version_compatible: bool = VERSION_COMPATIBLE,
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
**kwargs,
):
"""Create torch.compile backend given specified arguments
Expand All @@ -125,6 +141,13 @@ def create_backend(
min_block_size: Minimum number of operators per TRT-Engine Block
torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
version_compatible: Provide version forward-compatibility for engine plan files
optimization_level: Builder optimization 0-5, higher levels imply longer build time,
searching for more optimization options. TRT defaults to 3
use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
argument as None
Returns:
Backend for torch.compile
"""
Expand All @@ -136,4 +159,9 @@ def create_backend(
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
pass_through_build_failures=pass_through_build_failures,
max_aux_streams=max_aux_streams,
version_compatible=version_compatible,
optimization_level=optimization_level,
use_python_runtime=use_python_runtime,
**kwargs,
)
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/backend/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@
WORKSPACE_SIZE = 0
MIN_BLOCK_SIZE = 5
PASS_THROUGH_BUILD_FAILURES = False
MAX_AUX_STREAMS = None
VERSION_COMPATIBLE = False
OPTIMIZATION_LEVEL = None
USE_PYTHON_RUNTIME = None
12 changes: 10 additions & 2 deletions py/torch_tensorrt/dynamo/backend/_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Sequence
from typing import Optional, Sequence

from torch_tensorrt.fx.utils import LowerPrecision
from torch_tensorrt.dynamo.backend._defaults import (
Expand All @@ -8,14 +8,22 @@
WORKSPACE_SIZE,
MIN_BLOCK_SIZE,
PASS_THROUGH_BUILD_FAILURES,
MAX_AUX_STREAMS,
VERSION_COMPATIBLE,
OPTIMIZATION_LEVEL,
USE_PYTHON_RUNTIME,
)


@dataclass(frozen=True)
@dataclass
class CompilationSettings:
precision: LowerPrecision = PRECISION
debug: bool = DEBUG
workspace_size: int = WORKSPACE_SIZE
min_block_size: int = MIN_BLOCK_SIZE
torch_executed_ops: Sequence[str] = field(default_factory=set)
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
max_aux_streams: Optional[int] = MAX_AUX_STREAMS
version_compatible: bool = VERSION_COMPATIBLE
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def _compile_module(
submodule,
submodule_inputs,
settings=settings,
name=name,
)

trt_modules[name] = trt_mod
Expand Down
33 changes: 26 additions & 7 deletions py/torch_tensorrt/dynamo/backend/conversion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Sequence, Union
import torch
import io
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt import TRTModuleNext
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
InputTensorSpec,
Expand All @@ -15,12 +15,14 @@ def convert_module(
module: torch.fx.GraphModule,
inputs: Sequence[torch.Tensor],
settings: CompilationSettings = CompilationSettings(),
) -> Union[TRTModuleNext, TRTModule]:
name: str = "",
):
"""Convert an FX module to a TRT module
Args:
module: FX GraphModule to convert
inputs: Sequence of Tensors representing inputs to the module
settings: Compilation settings
name: TRT engine name
Returns:
TRTModule or TorchTensorRTModule
"""
Expand Down Expand Up @@ -48,10 +50,27 @@ def convert_module(
if settings.debug
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
),
max_aux_streams=settings.max_aux_streams,
version_compatible=settings.version_compatible,
optimization_level=settings.optimization_level,
)

return TRTModule(
engine=interpreter_result.engine,
input_names=interpreter_result.input_names,
output_names=interpreter_result.output_names,
)
if settings.use_python_runtime:
return TRTModule(
engine=interpreter_result.engine,
input_names=interpreter_result.input_names,
output_names=interpreter_result.output_names,
)

else:
from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule

with io.BytesIO() as engine_bytes:
engine_bytes.write(interpreter_result.engine.serialize())
engine_str = engine_bytes.getvalue()
return TorchTensorRTModule(
serialized_engine=engine_str,
name=name,
input_binding_names=interpreter_result.input_names,
output_binding_names=interpreter_result.output_names,
)
Loading