Skip to content

Commit f93a732

Browse files
committed
fix/feat: Add support for multiple TRT Build Args (#2510)
1 parent fdd6bad commit f93a732

File tree

7 files changed

+204
-54
lines changed

7 files changed

+204
-54
lines changed

py/torch_tensorrt/dynamo/_compiler.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,21 @@
1616
from torch_tensorrt.dynamo._defaults import (
1717
DEBUG,
1818
DEVICE,
19+
DISABLE_TF32,
20+
DLA_GLOBAL_DRAM_SIZE,
21+
DLA_LOCAL_DRAM_SIZE,
22+
DLA_SRAM_SIZE,
1923
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
24+
ENGINE_CAPABILITY,
2025
MAX_AUX_STREAMS,
2126
MIN_BLOCK_SIZE,
27+
NUM_AVG_TIMING_ITERS,
2228
OPTIMIZATION_LEVEL,
2329
PASS_THROUGH_BUILD_FAILURES,
2430
PRECISION,
31+
REFIT,
2532
REQUIRE_FULL_COMPILATION,
33+
SPARSE_WEIGHTS,
2634
TRUNCATE_LONG_AND_DOUBLE,
2735
USE_FAST_PARTITIONER,
2836
USE_PYTHON_RUNTIME,
@@ -51,17 +59,18 @@ def compile(
5159
inputs: Any,
5260
*,
5361
device: Optional[Union[Device, torch.device, str]] = DEVICE,
54-
disable_tf32: bool = False,
55-
sparse_weights: bool = False,
62+
disable_tf32: bool = DISABLE_TF32,
63+
sparse_weights: bool = SPARSE_WEIGHTS,
5664
enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
57-
refit: bool = False,
65+
engine_capability: EngineCapability = ENGINE_CAPABILITY,
66+
refit: bool = REFIT,
5867
debug: bool = DEBUG,
5968
capability: EngineCapability = EngineCapability.default,
60-
num_avg_timing_iters: int = 1,
69+
num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
6170
workspace_size: int = WORKSPACE_SIZE,
62-
dla_sram_size: int = 1048576,
63-
dla_local_dram_size: int = 1073741824,
64-
dla_global_dram_size: int = 536870912,
71+
dla_sram_size: int = DLA_SRAM_SIZE,
72+
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
73+
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
6574
calibrator: object = None,
6675
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
6776
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
@@ -200,6 +209,13 @@ def compile(
200209
"use_fast_partitioner": use_fast_partitioner,
201210
"enable_experimental_decompositions": enable_experimental_decompositions,
202211
"require_full_compilation": require_full_compilation,
212+
"disable_tf32": disable_tf32,
213+
"sparse_weights": sparse_weights,
214+
"refit": refit,
215+
"engine_capability": engine_capability,
216+
"dla_sram_size": dla_sram_size,
217+
"dla_local_dram_size": dla_local_dram_size,
218+
"dla_global_dram_size": dla_global_dram_size,
203219
}
204220

205221
settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/_defaults.py

+9
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11
import torch
2+
from tensorrt import EngineCapability
23
from torch_tensorrt._Device import Device
34

45
PRECISION = torch.float32
56
DEBUG = False
67
DEVICE = None
8+
DISABLE_TF32 = False
9+
DLA_LOCAL_DRAM_SIZE = 1073741824
10+
DLA_GLOBAL_DRAM_SIZE = 536870912
11+
DLA_SRAM_SIZE = 1048576
12+
ENGINE_CAPABILITY = EngineCapability.STANDARD
713
WORKSPACE_SIZE = 0
814
MIN_BLOCK_SIZE = 5
915
PASS_THROUGH_BUILD_FAILURES = False
1016
MAX_AUX_STREAMS = None
17+
NUM_AVG_TIMING_ITERS = 1
1118
VERSION_COMPATIBLE = False
1219
OPTIMIZATION_LEVEL = None
20+
SPARSE_WEIGHTS = False
1321
TRUNCATE_LONG_AND_DOUBLE = False
1422
USE_PYTHON_RUNTIME = False
1523
USE_FAST_PARTITIONER = True
1624
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
25+
REFIT = False
1726
REQUIRE_FULL_COMPILATION = False
1827

1928

py/torch_tensorrt/dynamo/_settings.py

+25
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,25 @@
22
from typing import Optional, Set
33

44
import torch
5+
from tensorrt import EngineCapability
56
from torch_tensorrt._Device import Device
67
from torch_tensorrt.dynamo._defaults import (
78
DEBUG,
9+
DISABLE_TF32,
10+
DLA_GLOBAL_DRAM_SIZE,
11+
DLA_LOCAL_DRAM_SIZE,
12+
DLA_SRAM_SIZE,
813
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
14+
ENGINE_CAPABILITY,
915
MAX_AUX_STREAMS,
1016
MIN_BLOCK_SIZE,
17+
NUM_AVG_TIMING_ITERS,
1118
OPTIMIZATION_LEVEL,
1219
PASS_THROUGH_BUILD_FAILURES,
1320
PRECISION,
21+
REFIT,
1422
REQUIRE_FULL_COMPILATION,
23+
SPARSE_WEIGHTS,
1524
TRUNCATE_LONG_AND_DOUBLE,
1625
USE_FAST_PARTITIONER,
1726
USE_PYTHON_RUNTIME,
@@ -46,6 +55,14 @@ class CompilationSettings:
4655
device (Device): GPU to compile the model on
4756
require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
4857
Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
58+
disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
59+
sparse_weights (bool): Whether to allow the builder to use sparse weights
60+
refit (bool): Whether to build a refittable engine
61+
engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
62+
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
63+
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
64+
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
65+
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
4966
"""
5067

5168
precision: torch.dtype = PRECISION
@@ -63,3 +80,11 @@ class CompilationSettings:
6380
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
6481
device: Device = field(default_factory=default_device)
6582
require_full_compilation: bool = REQUIRE_FULL_COMPILATION
83+
disable_tf32: bool = DISABLE_TF32
84+
sparse_weights: bool = SPARSE_WEIGHTS
85+
refit: bool = REFIT
86+
engine_capability: EngineCapability = ENGINE_CAPABILITY
87+
num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS
88+
dla_sram_size: int = DLA_SRAM_SIZE
89+
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE
90+
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+47-31
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
55

66
import numpy as np
7-
8-
# @manual=//deeplearning/trt/python:py_tensorrt
97
import tensorrt as trt
108
import torch
119
import torch.fx
@@ -97,6 +95,7 @@ def __init__(
9795
self._itensor_to_tensor_meta: Dict[
9896
trt.tensorrt.ITensor, TensorMetadata
9997
] = dict()
98+
self.compilation_settings = compilation_settings
10099

101100
# Data types for TRT Module output Tensors
102101
self.output_dtypes = output_dtypes
@@ -119,40 +118,25 @@ def validate_conversion(self) -> Set[str]:
119118

120119
def run(
121120
self,
122-
workspace_size: int = 0,
123-
precision: torch.dtype = torch.float32, # TODO: @peri044 Needs to be expanded to set
124-
sparse_weights: bool = False,
125-
disable_tf32: bool = False,
126121
force_fp32_output: bool = False,
127122
strict_type_constraints: bool = False,
128123
algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
129124
timing_cache: Optional[trt.ITimingCache] = None,
130-
profiling_verbosity: Optional[trt.ProfilingVerbosity] = None,
131125
tactic_sources: Optional[int] = None,
132-
max_aux_streams: Optional[int] = None,
133-
version_compatible: bool = False,
134-
optimization_level: Optional[int] = None,
135126
) -> TRTInterpreterResult:
136127
"""
137128
Build TensorRT engine with some configs.
138129
Args:
139-
workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation.
140-
precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
141-
sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
142130
force_fp32_output: force output to be fp32
143131
strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
144132
algorithm_selector: set up algorithm selection for certain layer
145133
timing_cache: enable timing cache for TensorRT
146-
profiling_verbosity: TensorRT logging level
147-
max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
148-
version_compatible: Provide version forward-compatibility for engine plan files
149-
optimization_level: Builder optimization 0-5, higher levels imply longer build time,
150-
searching for more optimization options. TRT defaults to 3
151134
Return:
152135
TRTInterpreterResult
153136
"""
154137
TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
155138

139+
precision = self.compilation_settings.precision
156140
# For float outputs, we set their dtype to fp16 only if precision == torch.float16 and
157141
# force_fp32_output=False. Overriden by specifying output_dtypes
158142
self.output_fp16 = not force_fp32_output and precision == torch.float16
@@ -173,9 +157,9 @@ def run(
173157

174158
builder_config = self.builder.create_builder_config()
175159

176-
if workspace_size != 0:
160+
if self.compilation_settings.workspace_size != 0:
177161
builder_config.set_memory_pool_limit(
178-
trt.MemoryPoolType.WORKSPACE, workspace_size
162+
trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size
179163
)
180164

181165
cache = None
@@ -188,34 +172,66 @@ def run(
188172

189173
if version.parse(trt.__version__) >= version.parse("8.2"):
190174
builder_config.profiling_verbosity = (
191-
profiling_verbosity
192-
if profiling_verbosity
175+
trt.ProfilingVerbosity.VERBOSE
176+
if self.compilation_settings.debug
193177
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
194178
)
195179

196180
if version.parse(trt.__version__) >= version.parse("8.6"):
197-
if max_aux_streams is not None:
198-
_LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
199-
builder_config.max_aux_streams = max_aux_streams
200-
if version_compatible:
181+
if self.compilation_settings.max_aux_streams is not None:
182+
_LOGGER.info(
183+
f"Setting max aux streams to {self.compilation_settings.max_aux_streams}"
184+
)
185+
builder_config.max_aux_streams = (
186+
self.compilation_settings.max_aux_streams
187+
)
188+
if self.compilation_settings.version_compatible:
201189
_LOGGER.info("Using version compatible")
202190
builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
203-
if optimization_level is not None:
204-
_LOGGER.info(f"Using optimization level {optimization_level}")
205-
builder_config.builder_optimization_level = optimization_level
191+
if self.compilation_settings.optimization_level is not None:
192+
_LOGGER.info(
193+
f"Using optimization level {self.compilation_settings.optimization_level}"
194+
)
195+
builder_config.builder_optimization_level = (
196+
self.compilation_settings.optimization_level
197+
)
198+
199+
builder_config.engine_capability = self.compilation_settings.engine_capability
200+
builder_config.avg_timing_iterations = (
201+
self.compilation_settings.num_avg_timing_iters
202+
)
203+
204+
if self.compilation_settings.device.device_type == trt.DeviceType.DLA:
205+
builder_config.DLA_core = self.compilation_settings.device.dla_core
206+
_LOGGER.info(f"Using DLA core {self.compilation_settings.device.dla_core}")
207+
builder_config.set_memory_pool_limit(
208+
trt.MemoryPoolType.DLA_MANAGED_SRAM,
209+
self.compilation_settings.dla_sram_size,
210+
)
211+
builder_config.set_memory_pool_limit(
212+
trt.MemoryPoolType.DLA_LOCAL_DRAM,
213+
self.compilation_settings.dla_local_dram_size,
214+
)
215+
builder_config.set_memory_pool_limit(
216+
trt.MemoryPoolType.DLA_GLOBAL_DRAM,
217+
self.compilation_settings.dla_global_dram_size,
218+
)
206219

207220
if precision == torch.float16:
208221
builder_config.set_flag(trt.BuilderFlag.FP16)
209222

210223
if precision == torch.int8:
211224
builder_config.set_flag(trt.BuilderFlag.INT8)
212225

213-
if sparse_weights:
226+
if self.compilation_settings.sparse_weights:
214227
builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
215228

216-
if disable_tf32:
229+
if self.compilation_settings.disable_tf32:
217230
builder_config.clear_flag(trt.BuilderFlag.TF32)
218231

232+
if self.compilation_settings.refit:
233+
builder_config.set_flag(trt.BuilderFlag.REFIT)
234+
219235
if strict_type_constraints:
220236
builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
221237

py/torch_tensorrt/dynamo/conversion/_conversion.py

+1-12
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,7 @@ def convert_module(
5353
output_dtypes=output_dtypes,
5454
compilation_settings=settings,
5555
)
56-
interpreter_result = interpreter.run(
57-
workspace_size=settings.workspace_size,
58-
precision=settings.precision,
59-
profiling_verbosity=(
60-
trt.ProfilingVerbosity.VERBOSE
61-
if settings.debug
62-
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
63-
),
64-
max_aux_streams=settings.max_aux_streams,
65-
version_compatible=settings.version_compatible,
66-
optimization_level=settings.optimization_level,
67-
)
56+
interpreter_result = interpreter.run()
6857

6958
if settings.use_python_runtime:
7059
return PythonTorchTensorRTModule(

tests/py/dynamo/conversion/harness.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ def run_test(
5050
interpreter,
5151
rtol,
5252
atol,
53-
precision=torch.float,
5453
check_dtype=True,
5554
):
5655
with torch.no_grad():
@@ -60,7 +59,7 @@ def run_test(
6059

6160
mod.eval()
6261
start = time.perf_counter()
63-
interpreter_result = interpreter.run(precision=precision)
62+
interpreter_result = interpreter.run()
6463
sec = time.perf_counter() - start
6564
_LOGGER.info(f"Interpreter run time(s): {sec}")
6665
trt_mod = PythonTorchTensorRTModule(
@@ -234,7 +233,9 @@ def run_test(
234233

235234
# Previous instance of the interpreter auto-casted 64-bit inputs
236235
# We replicate this behavior here
237-
compilation_settings = CompilationSettings(truncate_long_and_double=True)
236+
compilation_settings = CompilationSettings(
237+
precision=precision, truncate_long_and_double=True
238+
)
238239

239240
interp = TRTInterpreter(
240241
mod,
@@ -248,7 +249,6 @@ def run_test(
248249
interp,
249250
rtol,
250251
atol,
251-
precision,
252252
check_dtype,
253253
)
254254

0 commit comments

Comments
 (0)