@@ -6,41 +6,31 @@
 import torch_tensorrt
 from torch.fx.passes.pass_manager import PassManager
 from torch.fx.passes.splitter_base import SplitResult
-from torch_tensorrt import Device, EngineCapability
+from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
+from torch_tensorrt._Device import Device
+from torch_tensorrt._enums import (
+    EngineCapability,
+)  # TODO: Should probably be the TRT EngineCapability Enum
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
+    TRUNCATE_LONG_AND_DOUBLE,
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
 )
 from torch_tensorrt.dynamo.backend.backends import _compile_module
 from torch_tensorrt.dynamo.conversion import convert_module
-<<<<<<< HEAD
-
-from torch_tensorrt.dynamo._defaults import (
-    PRECISION,
-    DEBUG,
-    WORKSPACE_SIZE,
-    MIN_BLOCK_SIZE,
-    PASS_THROUGH_BUILD_FAILURES,
-    MAX_AUX_STREAMS,
-    VERSION_COMPATIBLE,
-    OPTIMIZATION_LEVEL,
-    USE_PYTHON_RUNTIME,
-    TRUNCATE_LONG_AND_DOUBLE,
+from torch_tensorrt.dynamo.lowering._fusers import (
+    fuse_permute_linear,
+    fuse_permute_matmul,
 )
-
-=======
-from torch_tensorrt.dynamo.lowering import fuse_permute_linear, fuse_permute_matmul
 from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
-from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
->>>>>>> e39abb60d (chore: adding isort to pre-commit)
 
 logger = logging.getLogger(__name__)
 
@@ -89,7 +79,7 @@ def compile(
     if not isinstance(inputs, collections.abc.Sequence):
         inputs = [inputs]
 
-    torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
+    _, torch_inputs = prepare_inputs(inputs, prepare_device(device))
 
     if (
         torch.float16 in enabled_precisions
@@ -125,7 +115,7 @@ def compile(
         "truncate_long_and_double": truncate_long_and_double,
     }
 
-    settings = CompilationSettings(**compilation_options)  # type: ignore[arg-type]
+    settings = CompilationSettings(**compilation_options)
     if kwargs.get("use_capability_partitioner", None):
         model = lower_model(gm, torch_inputs)
         return _compile_module(model, torch_inputs, settings)
@@ -163,7 +153,7 @@ def lower_model_using_trt_splitter(
 ) -> SplitResult:
     # Perform basic lowering
     model = lower_model(model, inputs)
-    splitter_setting = TRTSplitterSetting()  # type: ignore[no-untyped-call]
+    splitter_setting = TRTSplitterSetting()
     splitter_setting.use_implicit_batch_dim = False
     splitter_setting.min_acc_module_size = 1
     splitter_setting.use_experimental_rt = False
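
For context on how the `TRTSplitterSetting` configured above typically feeds the splitter, here is a minimal sketch. The toy model, the acc-tracer step, and the `generate_split_results()` call mirror the usual FX splitter flow but are assumptions, not part of this change; verify the exact call sequence against `torch_tensorrt.fx.tools.trt_splitter` before relying on it.

```python
import torch
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting

# Toy model and sample inputs are illustrative assumptions.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
inputs = [torch.randn(2, 4)]

# The splitter operates on acc-traced graphs.
traced = acc_tracer.trace(model, inputs)

settings = TRTSplitterSetting()
settings.use_implicit_batch_dim = False  # same knobs as set above
settings.min_acc_module_size = 1

splitter = TRTSplitter(traced, inputs, settings=settings)
split_result = splitter.generate_split_results()  # SplitResult of acc/non-acc submodules
```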
@@ -177,7 +167,7 @@ def lower_model(
 def lower_model(
     model: torch.nn.Module, inputs: Any, **kwargs: Any
 ) -> torch.fx.GraphModule:
-    graph_optimization_pm = PassManager.build_from_passlist(  # type: ignore[no-untyped-call]
+    graph_optimization_pm = PassManager.build_from_passlist(
         [fuse_permute_matmul, fuse_permute_linear]
     )
     lowered_model: torch.fx.GraphModule = graph_optimization_pm(model)
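
For reviewers unfamiliar with the pass-list pattern in `lower_model`: a pass is any `GraphModule -> GraphModule` callable, and `PassManager.build_from_passlist` chains them so that calling the manager runs each pass in order. A minimal sketch follows; the `Toy` module and the use of `torch.fx.symbolic_trace` are illustrative assumptions, and since the fusers match specific call patterns produced by torch_tensorrt's tracers, they may simply leave this plain traced graph unchanged.

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.pass_manager import PassManager
from torch_tensorrt.dynamo.lowering._fusers import (
    fuse_permute_linear,
    fuse_permute_matmul,
)


class Toy(torch.nn.Module):
    """Illustrative module: a permute feeding a linear layer."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x.permute(0, 2, 1))


# build_from_passlist chains the passes; calling the manager applies
# them in order to the traced module, as lower_model does above.
pm = PassManager.build_from_passlist([fuse_permute_matmul, fuse_permute_linear])
lowered = pm(symbolic_trace(Toy()))
```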