10
10
import torch_tensorrt .fx .tracer .dispatch_tracer .aten_tracer as aten_tracer
11
11
from torch .fx .passes .splitter_base import SplitResult
12
12
13
- from . fx2trt import TRTInterpreter , TRTInterpreterResult
13
+ from torch_tensorrt . dynamo . common import TRTInterpreter , TRTInterpreterResult
14
14
from .lower_setting import LowerSetting
15
15
from .passes .lower_pass_manager_builder import LowerPassManagerBuilder
16
16
from .passes .pass_utils import PassFunc , validate_inference
21
21
from torch_tensorrt .fx .trt_module import TRTModule
22
22
from torch_tensorrt .fx .utils import LowerPrecision
23
23
from torch_tensorrt ._Device import Device
24
+ from torch_tensorrt .dynamo ._defaults import (
25
+ PRECISION ,
26
+ DEBUG ,
27
+ WORKSPACE_SIZE ,
28
+ MIN_BLOCK_SIZE ,
29
+ PASS_THROUGH_BUILD_FAILURES ,
30
+ MAX_AUX_STREAMS ,
31
+ VERSION_COMPATIBLE ,
32
+ OPTIMIZATION_LEVEL ,
33
+ USE_EXPERIMENTAL_RT ,
34
+ )
24
35
25
36
logger = logging .getLogger (__name__ )
26
37
@@ -34,24 +45,25 @@ def compile(
34
45
disable_tf32 = False ,
35
46
sparse_weights = False ,
36
47
enabled_precisions = set (),
37
- min_block_size : int = 3 ,
38
- workspace_size = 0 ,
48
+ min_block_size : int = MIN_BLOCK_SIZE ,
49
+ workspace_size = WORKSPACE_SIZE ,
39
50
dla_sram_size = 1048576 ,
40
51
dla_local_dram_size = 1073741824 ,
41
52
dla_global_dram_size = 536870912 ,
42
53
calibrator = None ,
43
54
truncate_long_and_double = False ,
44
55
require_full_compilation = False ,
45
- debug = False ,
56
+ explicit_batch_dimension = False ,
57
+ debug = DEBUG ,
46
58
refit = False ,
47
59
timing_cache_prefix = "" ,
48
60
save_timing_cache = False ,
49
61
cuda_graph_batch_size = - 1 ,
50
62
is_aten = False ,
51
- use_experimental_fx_rt = False ,
52
- max_aux_streams = None ,
53
- version_compatible = False ,
54
- optimization_level = None ,
63
+ use_experimental_rt = USE_EXPERIMENTAL_RT ,
64
+ max_aux_streams = MAX_AUX_STREAMS ,
65
+ version_compatible = VERSION_COMPATIBLE ,
66
+ optimization_level = OPTIMIZATION_LEVEL ,
55
67
num_avg_timing_iters = 1 ,
56
68
torch_executed_ops = [],
57
69
torch_executed_modules = [],
@@ -70,7 +82,7 @@ def compile(
70
82
timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
71
83
save_timing_cache: Update timing cache with current timing cache data if set to True.
72
84
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
73
- use_experimental_fx_rt : Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
85
+ use_experimental_rt : Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
74
86
max_aux_streams: max number of aux stream to use
75
87
version_compatible: enable version compatible feature
76
88
optimization_level: builder optimization level
@@ -123,7 +135,7 @@ def compile(
123
135
save_timing_cache = save_timing_cache ,
124
136
cuda_graph_batch_size = cuda_graph_batch_size ,
125
137
is_aten = is_aten ,
126
- use_experimental_rt = use_experimental_fx_rt ,
138
+ use_experimental_rt = use_experimental_rt ,
127
139
max_aux_streams = max_aux_streams ,
128
140
version_compatible = version_compatible ,
129
141
optimization_level = optimization_level ,
0 commit comments