4
4
from typing import Any , Callable , Dict , List , NamedTuple , Optional , Sequence , Set
5
5
6
6
import numpy as np
7
-
8
- # @manual=//deeplearning/trt/python:py_tensorrt
9
7
import tensorrt as trt
10
8
import torch
11
9
import torch .fx
@@ -97,6 +95,7 @@ def __init__(
97
95
self ._itensor_to_tensor_meta : Dict [
98
96
trt .tensorrt .ITensor , TensorMetadata
99
97
] = dict ()
98
+ self .compilation_settings = compilation_settings
100
99
101
100
# Data types for TRT Module output Tensors
102
101
self .output_dtypes = output_dtypes
@@ -119,40 +118,25 @@ def validate_conversion(self) -> Set[str]:
119
118
120
119
def run (
121
120
self ,
122
- workspace_size : int = 0 ,
123
- precision : torch .dtype = torch .float32 , # TODO: @peri044 Needs to be expanded to set
124
- sparse_weights : bool = False ,
125
- disable_tf32 : bool = False ,
126
121
force_fp32_output : bool = False ,
127
122
strict_type_constraints : bool = False ,
128
123
algorithm_selector : Optional [trt .IAlgorithmSelector ] = None ,
129
124
timing_cache : Optional [trt .ITimingCache ] = None ,
130
- profiling_verbosity : Optional [trt .ProfilingVerbosity ] = None ,
131
125
tactic_sources : Optional [int ] = None ,
132
- max_aux_streams : Optional [int ] = None ,
133
- version_compatible : bool = False ,
134
- optimization_level : Optional [int ] = None ,
135
126
) -> TRTInterpreterResult :
136
127
"""
137
128
Build TensorRT engine with some configs.
138
129
Args:
139
- workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation.
140
- precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
141
- sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
142
130
force_fp32_output: force output to be fp32
143
131
strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layers for numerical reasons.
144
132
algorithm_selector: set up algorithm selection for certain layers
145
133
timing_cache: enable timing cache for TensorRT
146
- profiling_verbosity: TensorRT logging level
147
- max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
148
- version_compatible: Provide version forward-compatibility for engine plan files
149
- optimization_level: Builder optimization 0-5, higher levels imply longer build time,
150
- searching for more optimization options. TRT defaults to 3
151
134
Return:
152
135
TRTInterpreterResult
153
136
"""
154
137
TRT_INTERPRETER_CALL_PRE_OBSERVER .observe (self .module )
155
138
139
+ precision = self .compilation_settings .precision
156
140
# For float outputs, we set their dtype to fp16 only if precision == torch.float16 and
157
141
# force_fp32_output=False. Overridden by specifying output_dtypes
158
142
self .output_fp16 = not force_fp32_output and precision == torch .float16
@@ -173,9 +157,9 @@ def run(
173
157
174
158
builder_config = self .builder .create_builder_config ()
175
159
176
- if workspace_size != 0 :
160
+ if self . compilation_settings . workspace_size != 0 :
177
161
builder_config .set_memory_pool_limit (
178
- trt .MemoryPoolType .WORKSPACE , workspace_size
162
+ trt .MemoryPoolType .WORKSPACE , self . compilation_settings . workspace_size
179
163
)
180
164
181
165
cache = None
@@ -188,34 +172,66 @@ def run(
188
172
189
173
if version .parse (trt .__version__ ) >= version .parse ("8.2" ):
190
174
builder_config .profiling_verbosity = (
191
- profiling_verbosity
192
- if profiling_verbosity
175
+ trt . ProfilingVerbosity . VERBOSE
176
+ if self . compilation_settings . debug
193
177
else trt .ProfilingVerbosity .LAYER_NAMES_ONLY
194
178
)
195
179
196
180
if version .parse (trt .__version__ ) >= version .parse ("8.6" ):
197
- if max_aux_streams is not None :
198
- _LOGGER .info (f"Setting max aux streams to { max_aux_streams } " )
199
- builder_config .max_aux_streams = max_aux_streams
200
- if version_compatible :
181
+ if self .compilation_settings .max_aux_streams is not None :
182
+ _LOGGER .info (
183
+ f"Setting max aux streams to { self .compilation_settings .max_aux_streams } "
184
+ )
185
+ builder_config .max_aux_streams = (
186
+ self .compilation_settings .max_aux_streams
187
+ )
188
+ if self .compilation_settings .version_compatible :
201
189
_LOGGER .info ("Using version compatible" )
202
190
builder_config .set_flag (trt .BuilderFlag .VERSION_COMPATIBLE )
203
- if optimization_level is not None :
204
- _LOGGER .info (f"Using optimization level { optimization_level } " )
205
- builder_config .builder_optimization_level = optimization_level
191
+ if self .compilation_settings .optimization_level is not None :
192
+ _LOGGER .info (
193
+ f"Using optimization level { self .compilation_settings .optimization_level } "
194
+ )
195
+ builder_config .builder_optimization_level = (
196
+ self .compilation_settings .optimization_level
197
+ )
198
+
199
+ builder_config .engine_capability = self .compilation_settings .engine_capability
200
+ builder_config .avg_timing_iterations = (
201
+ self .compilation_settings .num_avg_timing_iters
202
+ )
203
+
204
+ if self .compilation_settings .device .device_type == trt .DeviceType .DLA :
205
+ builder_config .DLA_core = self .compilation_settings .device .dla_core
206
+ _LOGGER .info (f"Using DLA core { self .compilation_settings .device .dla_core } " )
207
+ builder_config .set_memory_pool_limit (
208
+ trt .MemoryPoolType .DLA_MANAGED_SRAM ,
209
+ self .compilation_settings .dla_sram_size ,
210
+ )
211
+ builder_config .set_memory_pool_limit (
212
+ trt .MemoryPoolType .DLA_LOCAL_DRAM ,
213
+ self .compilation_settings .dla_local_dram_size ,
214
+ )
215
+ builder_config .set_memory_pool_limit (
216
+ trt .MemoryPoolType .DLA_GLOBAL_DRAM ,
217
+ self .compilation_settings .dla_global_dram_size ,
218
+ )
206
219
207
220
if precision == torch .float16 :
208
221
builder_config .set_flag (trt .BuilderFlag .FP16 )
209
222
210
223
if precision == torch .int8 :
211
224
builder_config .set_flag (trt .BuilderFlag .INT8 )
212
225
213
- if sparse_weights :
226
+ if self . compilation_settings . sparse_weights :
214
227
builder_config .set_flag (trt .BuilderFlag .SPARSE_WEIGHTS )
215
228
216
- if disable_tf32 :
229
+ if self . compilation_settings . disable_tf32 :
217
230
builder_config .clear_flag (trt .BuilderFlag .TF32 )
218
231
232
+ if self .compilation_settings .refit :
233
+ builder_config .set_flag (trt .BuilderFlag .REFIT )
234
+
219
235
if strict_type_constraints :
220
236
builder_config .set_flag (trt .BuilderFlag .STRICT_TYPES )
221
237
0 commit comments