4
4
import torch_tensorrt
5
5
from functools import partial
6
6
7
- from typing import Any , Sequence
7
+ from typing import Any , Optional , Sequence
8
8
from torch_tensorrt import EngineCapability , Device
9
9
from torch_tensorrt .fx .utils import LowerPrecision
10
10
16
16
WORKSPACE_SIZE ,
17
17
MIN_BLOCK_SIZE ,
18
18
PASS_THROUGH_BUILD_FAILURES ,
19
+ MAX_AUX_STREAMS ,
20
+ VERSION_COMPATIBLE ,
21
+ OPTIMIZATION_LEVEL ,
19
22
USE_EXPERIMENTAL_RT ,
20
23
)
21
24
@@ -46,6 +49,9 @@ def compile(
46
49
torch_executed_ops = [],
47
50
torch_executed_modules = [],
48
51
pass_through_build_failures = PASS_THROUGH_BUILD_FAILURES ,
52
+ max_aux_streams = MAX_AUX_STREAMS ,
53
+ version_compatible = VERSION_COMPATIBLE ,
54
+ optimization_level = OPTIMIZATION_LEVEL ,
49
55
use_experimental_rt = USE_EXPERIMENTAL_RT ,
50
56
** kwargs ,
51
57
):
@@ -98,6 +104,9 @@ def compile(
98
104
min_block_size = min_block_size ,
99
105
torch_executed_ops = torch_executed_ops ,
100
106
pass_through_build_failures = pass_through_build_failures ,
107
+ max_aux_streams = max_aux_streams ,
108
+ version_compatible = version_compatible ,
109
+ optimization_level = optimization_level ,
101
110
use_experimental_rt = use_experimental_rt ,
102
111
** kwargs ,
103
112
)
@@ -122,6 +131,9 @@ def create_backend(
122
131
min_block_size : int = MIN_BLOCK_SIZE ,
123
132
torch_executed_ops : Sequence [str ] = set (),
124
133
pass_through_build_failures : bool = PASS_THROUGH_BUILD_FAILURES ,
134
+ max_aux_streams : Optional [int ] = MAX_AUX_STREAMS ,
135
+ version_compatible : bool = VERSION_COMPATIBLE ,
136
+ optimization_level : Optional [int ] = OPTIMIZATION_LEVEL ,
125
137
use_experimental_rt : bool = USE_EXPERIMENTAL_RT ,
126
138
** kwargs ,
127
139
):
@@ -134,6 +146,10 @@ def create_backend(
134
146
min_block_size: Minimum number of operators per TRT-Engine Block
135
147
torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
136
148
pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
149
+ max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
150
+ version_compatible: Provide version forward-compatibility for engine plan files
151
+ optimization_level: Builder optimization 0-5, higher levels imply longer build time,
152
+ searching for more optimization options. TRT defaults to 3
137
153
use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines
138
154
Returns:
139
155
Backend for torch.compile
@@ -146,5 +162,8 @@ def create_backend(
146
162
min_block_size = min_block_size ,
147
163
torch_executed_ops = torch_executed_ops ,
148
164
pass_through_build_failures = pass_through_build_failures ,
165
+ max_aux_streams = max_aux_streams ,
166
+ version_compatible = version_compatible ,
167
+ optimization_level = optimization_level ,
149
168
use_experimental_rt = use_experimental_rt ,
150
169
)
0 commit comments