Skip to content

Commit 7df445e

Browse files
committed
feat: Add support for TorchTensorRTModule in Dynamo
- Rename `TRTModuleNext` to `TorchTensorRTModule` across the repository, and move the source directory to `dynamo`
- Update imports across the repository
- Refactor `convert_module` code to support conversion to a `TorchTensorRTModule`
- Add tests for `TorchTensorRTModule` functionality in Dynamo
1 parent 81d488a commit 7df445e

18 files changed

+246
-42
lines changed

examples/fx/fx2trt_example_next.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
99
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter
1010
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
11-
from torch_tensorrt import TRTModuleNext as TRTModule, Device
11+
from torch_tensorrt import TorchTensorRTModule as TRTModule, Device
1212

1313
# The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
1414
# model to TensorRT via FX with existing FX based tooling. The general lowering flow

py/torch_tensorrt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _find_lib(name, paths):
9191
from torch_tensorrt import logging
9292
from torch_tensorrt._Input import Input
9393
from torch_tensorrt._Device import Device
94-
from torch_tensorrt._TRTModuleNext import TRTModuleNext
94+
from torch_tensorrt.dynamo import TorchTensorRTModule
9595

9696
from torch_tensorrt import fx
9797

py/torch_tensorrt/_TRTModuleNext.py renamed to py/torch_tensorrt/dynamo/_TorchTensorRTModule.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
2-
from operator import truediv
3-
from typing import Any, List, Sequence, Tuple
2+
from typing import Any, List, Tuple
43

54
import torch
65
from torch_tensorrt import _C
@@ -9,8 +8,8 @@
98
logger = logging.getLogger(__name__)
109

1110

12-
class TRTModuleNext(torch.nn.Module):
13-
"""TRTModuleNext is a PyTorch module which encompasses an arbitrary TensorRT Engine.
11+
class TorchTensorRTModule(torch.nn.Module):
12+
"""TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
1413
1514
This module is backed by the Torch-TensorRT runtime and is fully compatible with both
1615
FX / Python deployments (just ``import torch_tensorrt`` as part of the application) as
@@ -20,7 +19,7 @@ class TRTModuleNext(torch.nn.Module):
2019
The forward function is simply forward(*args: torch.Tensor) -> Tuple[torch.Tensor] where
2120
the internal implementation is ``return Tuple(torch.ops.tensorrt.execute_engine(list(inputs), self.engine))``
2221
23-
> Note: TRTModuleNext only supports engines built with explict batch
22+
> Note: TorchTensorRTModule only supports engines built with explicit batch
2423
2524
Attributes:
2625
name (str): Name of module (for easier debugging)
@@ -37,7 +36,7 @@ def __init__(
3736
output_binding_names: List[str] = [],
3837
target_device: Device = Device._current_device(),
3938
):
40-
"""__init__ method for torch_tensorrt.TRTModuleNext
39+
"""__init__ method for torch_tensorrt.TorchTensorRTModule
4140
4241
Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
4342
a PyTorch ``torch.nn.Module`` around it.
@@ -71,9 +70,9 @@ def __init__(
7170
7271
"""
7372
logger.warning(
74-
"TRTModuleNext should be considered experimental stability, APIs are subject to change. Note: TRTModuleNext only supports engines built with explict batch"
73+
"TorchTensorRTModule should be considered experimental stability, APIs are subject to change. Note: TorchTensorRTModule only supports engines built with explict batch"
7574
)
76-
super(TRTModuleNext, self).__init__()
75+
super(TorchTensorRTModule, self).__init__()
7776

7877
if not isinstance(serialized_engine, bytearray):
7978
ValueError("Expected serialized engine as bytearray")
@@ -89,8 +88,8 @@ def __init__(
8988
self.name + "_engine" if self.name != "" else "tensorrt_engine",
9089
target_device._to_serialized_rt_device(),
9190
serialized_engine,
92-
TRTModuleNext._pack_binding_names(self.input_binding_names),
93-
TRTModuleNext._pack_binding_names(self.output_binding_names),
91+
TorchTensorRTModule._pack_binding_names(self.input_binding_names),
92+
TorchTensorRTModule._pack_binding_names(self.output_binding_names),
9493
]
9594
)
9695
else:
@@ -154,7 +153,7 @@ def is_non_tensor(i: Tuple[Any, bool]) -> bool:
154153

155154
non_tensors = [i[0] for i in filter(zip(inputs, types), is_non_tensor)]
156155
raise RuntimeError(
157-
f"TRTModuleNext expects a flattened list of tensors as input, found non tensors: {non_tensors}"
156+
f"TorchTensorRTModule expects a flattened list of tensors as input, found non tensors: {non_tensors}"
158157
)
159158

160159
outputs = torch.ops.tensorrt.execute_engine(list(inputs), self.engine)

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from torch_tensorrt.dynamo import fx_ts_compat
22
from .backend import compile
3+
from ._TorchTensorRTModule import TorchTensorRTModule

py/torch_tensorrt/dynamo/backend/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
WORKSPACE_SIZE,
1818
MIN_BLOCK_SIZE,
1919
PASS_THROUGH_BUILD_FAILURES,
20+
USE_EXPERIMENTAL_RT,
2021
)
2122

2223

@@ -45,6 +46,7 @@ def compile(
4546
min_block_size=MIN_BLOCK_SIZE,
4647
torch_executed_ops=[],
4748
torch_executed_modules=[],
49+
use_experimental_rt=USE_EXPERIMENTAL_RT,
4850
**kwargs,
4951
):
5052
if debug:
@@ -57,6 +59,13 @@ def compile(
5759
+ "torch_executed_ops, pass_through_build_failures}"
5860
)
5961

62+
if "use_experimental_fx_rt" in kwargs:
63+
logger.info(
64+
"Detected option 'use_experimental_fx_rt' in kwargs, "
65+
+ "overwriting the 'use_experimental_rt' argument."
66+
)
67+
use_experimental_rt = kwargs["use_experimental_fx_rt"]
68+
6069
if not isinstance(inputs, collections.abc.Sequence):
6170
inputs = [inputs]
6271

@@ -86,6 +95,7 @@ def compile(
8695
workspace_size=workspace_size,
8796
min_block_size=min_block_size,
8897
torch_executed_ops=torch_executed_ops,
98+
use_experimental_rt=use_experimental_rt,
8999
**kwargs,
90100
)
91101

@@ -109,6 +119,7 @@ def create_backend(
109119
min_block_size: int = MIN_BLOCK_SIZE,
110120
torch_executed_ops: Sequence[str] = set(),
111121
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
122+
use_experimental_rt: bool = USE_EXPERIMENTAL_RT,
112123
**kwargs,
113124
):
114125
"""Create torch.compile backend given specified arguments
@@ -120,6 +131,7 @@ def create_backend(
120131
min_block_size: Minimum number of operators per TRT-Engine Block
121132
torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
122133
pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
134+
use_experimental_rt: Whether to use the new experimental TorchTensorRTModule for TRT engines
123135
Returns:
124136
Backend for torch.compile
125137
"""
@@ -133,6 +145,7 @@ def create_backend(
133145
min_block_size=min_block_size,
134146
torch_executed_ops=torch_executed_ops,
135147
pass_through_build_failures=pass_through_build_failures,
148+
use_experimental_rt=use_experimental_rt,
136149
)
137150

138151
return partial(

py/torch_tensorrt/dynamo/backend/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
WORKSPACE_SIZE = 0
77
MIN_BLOCK_SIZE = 5
88
PASS_THROUGH_BUILD_FAILURES = False
9+
USE_EXPERIMENTAL_RT = False

py/torch_tensorrt/dynamo/backend/_settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
WORKSPACE_SIZE,
99
MIN_BLOCK_SIZE,
1010
PASS_THROUGH_BUILD_FAILURES,
11+
USE_EXPERIMENTAL_RT,
1112
)
1213

1314

@@ -19,3 +20,4 @@ class CompilationSettings:
1920
min_block_size: int = MIN_BLOCK_SIZE
2021
torch_executed_ops: Sequence[str] = field(default_factory=set)
2122
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
23+
use_experimental_rt: bool = USE_EXPERIMENTAL_RT

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def _compile_module(
135135
submodule,
136136
submodule_inputs,
137137
settings=settings,
138+
name=name,
138139
)
139140

140141
# Replace FX Module with TRT Module

py/torch_tensorrt/dynamo/backend/conversion.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import Sequence, Union
22
import torch
3+
import io
34
from torch_tensorrt.fx.trt_module import TRTModule
4-
from torch_tensorrt import TRTModuleNext
5+
from torch_tensorrt.dynamo import TorchTensorRTModule
56
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
67
from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
78
InputTensorSpec,
@@ -15,12 +16,14 @@ def convert_module(
1516
module: torch.fx.GraphModule,
1617
inputs: Sequence[torch.Tensor],
1718
settings: CompilationSettings = CompilationSettings(),
18-
) -> Union[TRTModuleNext, TRTModule]:
19+
name: str = "",
20+
) -> Union[TorchTensorRTModule, TRTModule]:
1921
"""Convert an FX module to a TRT module
2022
Args:
2123
module: FX GraphModule to convert
2224
inputs: Sequence of Tensors representing inputs to the module
2325
settings: Compilation settings
26+
name: TRT engine name
2427
Returns:
2528
TRTModule or TorchTensorRTModule
2629
"""
@@ -41,8 +44,19 @@ def convert_module(
4144
),
4245
)
4346

44-
return TRTModule(
45-
engine=interpreter_result.engine,
46-
input_names=interpreter_result.input_names,
47-
output_names=interpreter_result.output_names,
48-
)
47+
if settings.use_experimental_rt:
48+
with io.BytesIO() as engine_bytes:
49+
engine_bytes.write(interpreter_result.engine.serialize())
50+
engine_str = engine_bytes.getvalue()
51+
return TorchTensorRTModule(
52+
serialized_engine=engine_str,
53+
name=name,
54+
input_binding_names=interpreter_result.input_names,
55+
output_binding_names=interpreter_result.output_names,
56+
)
57+
else:
58+
return TRTModule(
59+
engine=interpreter_result.engine,
60+
input_names=interpreter_result.input_names,
61+
output_names=interpreter_result.output_names,
62+
)

0 commit comments

Comments
 (0)