
Commit d718464

refactor
1 parent 271143c commit d718464

File tree

13 files changed: +202 -65 lines changed


py/torch_tensorrt/dynamo/_engine_cache.py

Lines changed: 4 additions & 4 deletions
@@ -107,7 +107,7 @@ def pack(
         input_specs: Sequence[Input],
         compilation_settings: CompilationSettings,
         weight_name_map: Optional[Dict[Any, Any]],
-        engine_is_dds: bool,
+        requires_output_allocator: bool,
     ) -> bytes:
         """Pack serialized engine, input names, output names, and weight map into a single blob

@@ -118,7 +118,7 @@ def pack(
             input_specs (Sequence[Input]): input specs of TRT engine
             compilation_settings (CompilationSettings): compilation settings of TRT engine
             weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting
-            engine_is_dds (bool): whether the engine is data-dependent shape
+            requires_output_allocator (bool): whether the engine requires output allocator
         Returns:
             bytes: packed blob
         """
@@ -132,7 +132,7 @@ def pack(
                 "input_specs": input_specs,
                 "compilation_settings": settings,
                 "weight_name_map": weight_name_map,
-                "engine_is_dds": engine_is_dds,
+                "requires_output_allocator": requires_output_allocator,
             }
         )

@@ -154,7 +154,7 @@ def unpack(packed_obj: bytes) -> UnpackedCacheHit:
             unpacked["input_specs"],
             unpacked["compilation_settings"],
             unpacked["weight_name_map"],
-            unpacked["engine_is_dds"],
+            unpacked["requires_output_allocator"],
         )

     def insert(
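For orientation, a minimal sketch of how the renamed field rides through the cache blob. The first three positional parameters are inferred from the pack() docstring above; all values here are placeholders for illustration and are not part of the commit.

from torch_tensorrt import Input
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache

serialized_engine = b""  # placeholder for a real serialized TRT engine

blob = BaseEngineCache.pack(
    serialized_engine,
    ["x"],                            # input names
    ["output0"],                      # output names
    [Input(shape=(1, 3, 224, 224))],  # input specs
    CompilationSettings(),            # compilation settings
    None,                             # weight name map (no refitting)
    True,                             # requires_output_allocator (was engine_is_dds)
)

# unpack() returns the same fields; requires_output_allocator is now the last element.
*_, requires_output_allocator = BaseEngineCache.unpack(blob)
assert requires_output_allocator is True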

py/torch_tensorrt/dynamo/conversion/_ConversionContext.py

Lines changed: 2 additions & 0 deletions
@@ -11,9 +11,11 @@ class ConversionContext:
     Args:
         net: TensorRT Network being built
         compilation_settings: Settings selected by the user for compilation
+        requires_output_allocator: Whether the network requires output allocator
     """

     net: TRTNetwork
     compilation_settings: CompilationSettings = field(
         default_factory=CompilationSettings
     )
+    requires_output_allocator: bool = False

py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py

Lines changed: 16 additions & 5 deletions
@@ -18,6 +18,7 @@
     cast,
 )

+import tensorrt as trt
 import torch
 from torch import SymBool, SymFloat, SymInt
 from torch._ops import OpOverloadPacket
@@ -26,8 +27,6 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS

-import tensorrt as trt
-
 logger = logging.getLogger(__name__)

 LegacyConverterImplSignature = Callable[
@@ -81,13 +80,15 @@ class ConverterSupport:
             whether that node can be supported by its companion converter. Note that
             this function must not modify the node or its graph
         supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic inputs.
+        requires_output_allocator: Boolean flag indicating if the converter requires to run in output allocator.
     """

     converter_implementation: ConverterImplSignature
     capability_validator: Callable[[Node, CompilationSettings], bool] = field(
         default=lambda node, compilation_settings: True
     )
     supports_dynamic_shapes: bool = False
+    requires_output_allocator: bool = False


 # Dictionary representing Dynamo aten-only converters
@@ -197,6 +198,7 @@ def dynamo_tensorrt_converter(
     capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
+    requires_output_allocator: bool = False,
 ) -> Callable[[ConverterImplSignature], ConverterImplSignature]:
     """Decorator for Dynamo TensorRT Converter

@@ -212,6 +214,8 @@ def dynamo_tensorrt_converter(
             this means all nodes of "key" kind can be supported by this converter
         priority: Converter's level of priority relative to other converters with the
             same target
+        supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic shapes.
+        requires_output_allocator: Boolean flag indicating if the converter requires to run in output allocator.
     Returns:
         The converter being decorated
     """
@@ -225,6 +229,7 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat
             converter_support = ConverterSupport(
                 converter_implementation=converter,
                 supports_dynamic_shapes=supports_dynamic_shapes,
+                requires_output_allocator=requires_output_allocator,
             )
         else:
             assert callable(
@@ -234,6 +239,7 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat
                 converter_implementation=converter,
                 capability_validator=capability_validator,
                 supports_dynamic_shapes=supports_dynamic_shapes,
+                requires_output_allocator=requires_output_allocator,
             )

     # OpOverloadPackets are only valid if they have a single overload, or
@@ -404,7 +410,7 @@ def __getitem_without_validation__(
     def __getitem__(
         self, node: Node
     ) -> Tuple[
-        Any, CallingConvention
+        Any, CallingConvention, bool
     ]:  # TODO: Narrow to ConverterImplSignature this when we can remove FX converters
         """Get the first-found validated converter in any registry

@@ -462,6 +468,7 @@ def __getitem__(
                             return (
                                 candidate.converter_implementation,
                                 calling_convention,
+                                candidate.requires_output_allocator,
                             )
                         else:
                             logger.debug(
@@ -471,7 +478,11 @@ def __getitem__(
                 else:
                     # Assuming FX converters don't have dynamic shapes supported
                     if not node_has_dynamic_shapes(node):
-                        return converters, calling_convention
+                        return (
+                            converters,
+                            calling_convention,
+                            candidate.requires_output_allocator,
+                        )

         raise KeyError(
             f"None of the converter registries have a validated entry for {key}, with node {node}"
@@ -495,7 +506,7 @@ def get_unvalidated(
     def get(
         self, node: Node, value: Optional[ConverterImplSignature] = None
     ) -> Union[
-        Any, Tuple[Any, CallingConvention]
+        Any, Tuple[Any, CallingConvention, bool]
     ]:  # TODO: Narrow to ConverterImplSignature this when we can remove FX converters
         """Get validated converter for input node with a default return"""
         try:
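Call sites of the registry now unpack a three-element tuple. A minimal sketch of the new contract, assuming the module-level DYNAMO_CONVERTERS registry instance and an already-validated torch.fx node; the ctx handling mirrors what _TRTInterpreter.py does below.

from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS

# node: a torch.fx.Node whose target has a validated converter (assumed to exist here).
converter, calling_convention, requires_output_allocator = DYNAMO_CONVERTERS[node]

if requires_output_allocator:
    # The interpreter latches this onto its ConversionContext so the runtime
    # module is later built with requires_output_allocator=True.
    ctx.requires_output_allocator = True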

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 11 additions & 23 deletions
@@ -64,7 +64,7 @@ class TRTInterpreterResult(NamedTuple):
     input_names: Sequence[str]
     output_names: Sequence[str]
     weight_name_map: Optional[dict[Any, Any]]
-    engine_is_dds: bool
+    requires_output_allocator: bool


 class TRTInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
@@ -139,9 +139,6 @@ def __init__(
         # Engine cache for storing and reusing TRT engines
         self.engine_cache = engine_cache

-        # Whether the engine is data-dependent shape (dds)
-        self.engine_is_dds: bool = False
-
     def validate_conversion(self) -> Set[str]:
         missing_converters: Set[str] = set()

@@ -581,7 +578,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No
                 self.input_specs,
                 self.compilation_settings,
                 self.weight_name_map,
-                self.engine_is_dds,
+                self.ctx.requires_output_allocator,
             ),
         )

@@ -596,7 +593,7 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
                 cached_engine_input_specs,
                 engine_compilation_settings,
                 self.weight_name_map,
-                self.engine_is_dds,
+                self.ctx.requires_output_allocator,
             ) = cached_data

             setting_compatiblity, incompattible_settings = settings_are_compatible(
@@ -658,20 +655,10 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
                 self._input_names,
                 self._output_names,
                 self.weight_name_map,
-                self.engine_is_dds,
+                self.ctx.requires_output_allocator,
             )
         return None

-    def check_dds(self, serialized_engine: bytes, output_names: List[str]) -> bool:
-        runtime = trt.Runtime(TRT_LOGGER)
-        engine = runtime.deserialize_cuda_engine(serialized_engine)
-
-        for output_name in output_names:
-            output_shape = engine.get_tensor_shape(output_name)
-            if -1 in output_shape:
-                return True
-        return False
-
     def run(
         self,
         strict_type_constraints: bool = False,
@@ -728,8 +715,6 @@ def run(
             )
             assert serialized_engine

-            self.engine_is_dds = self.check_dds(serialized_engine, self._output_names)
-
             _LOGGER.info(
                 f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
             )
@@ -756,7 +741,7 @@ def run(
             self._input_names,
             self._output_names,
             self.weight_name_map,
-            self.engine_is_dds,
+            self.ctx.requires_output_allocator,
         )

     def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
@@ -850,7 +835,7 @@ def call_module(
                 f"Conversion of module of type {submod_type} not currently supported!"
             )

-        converter, calling_convention = converter_packet
+        converter, calling_convention, requires_output_allocator = converter_packet

         assert self._cur_node_name is not None

@@ -867,7 +852,10 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
                 f"Conversion of function {torch.typename(target)} not currently supported!"
             )

-        converter, calling_convention = converter_packet
+        converter, calling_convention, requires_output_allocator = converter_packet
+        if requires_output_allocator:
+            self.ctx.requires_output_allocator = True
+            _LOGGER.debug(f"{target} requires output allocator")

         if calling_convention is CallingConvention.LEGACY:
             return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
@@ -897,7 +885,7 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
             raise UnsupportedOperatorException(
                 f"Conversion of method {target} not currently supported!"
             )
-        converter, calling_convention = converter_packet
+        converter, calling_convention, requires_output_allocator = converter_packet

         if calling_convention is CallingConvention.LEGACY:
             return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
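With check_dds() gone, the flag is set during conversion and surfaces on TRTInterpreterResult instead of being recomputed from the serialized engine. A short sketch of consuming it; interpreter construction is omitted and assumed to have succeeded.

# interpreter: a TRTInterpreter built for a partitioned submodule (setup omitted).
result = interpreter.run()

# requires_output_allocator replaces the old engine_is_dds field on the NamedTuple.
if result.requires_output_allocator:
    print("This engine will be executed through a DynamicOutputAllocator.")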

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 1 addition & 1 deletion
@@ -112,5 +112,5 @@ def convert_module(
         name=name,
         settings=settings,
         weight_name_map=interpreter_result.weight_name_map,
-        engine_is_dds=interpreter_result.engine_is_dds,
+        requires_output_allocator=interpreter_result.requires_output_allocator,
     )

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 5 additions & 1 deletion
@@ -3554,7 +3554,11 @@ def aten_ops_full(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.nonzero.default,
+    supports_dynamic_shapes=True,
+    requires_output_allocator=True,
+)
 def aten_ops_nonzero(
     ctx: ConversionContext,
     target: Target,

py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 4 additions & 2 deletions
@@ -31,8 +31,10 @@ def construct_dynamic_input(
         if isinstance(dim, torch.SymInt):
             min_max_opt = extract_var_range_info(dim)
             min_shape.append(min_max_opt["min"])
-            # opt might not exist
-            opt_shape.append(min_max_opt.get("opt"))
+            # if opt not exist, set it to the mean of min and max
+            opt_shape.append(
+                min_max_opt.get("opt", int(min_max_opt["min"] + min_max_opt["max"] / 2))
+            )
             max_shape.append(min_max_opt["max"])
         else:
             min_shape.append(dim)
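Per the new comment, the fallback is meant to be the midpoint of the symbolic range when no opt value is recorded. A standalone sketch of that intent (note the sum must be parenthesized before dividing to get the true mean); the helper name is illustrative, not part of the commit.

def opt_dim_fallback(min_max_opt: dict) -> int:
    # Prefer the recorded opt value; otherwise fall back to the midpoint of
    # [min, max], i.e. (min + max) / 2 with the sum parenthesized.
    return min_max_opt.get(
        "opt", int((min_max_opt["min"] + min_max_opt["max"]) / 2)
    )

# Example: a symbolic dim constrained to [1, 9] with no opt hint resolves to 5.
assert opt_dim_fallback({"min": 1, "max": 9}) == 5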

py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py

Lines changed: 0 additions & 5 deletions
@@ -79,11 +79,6 @@ def set_output_allocator_outputs(self, enable: bool) -> None:

     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
-        if cudagraphs_enabled and self.use_output_allocator_outputs:
-            raise RuntimeError(
-                "There are non-TRT submodules in the module. OutputAllocator is not compatible with modules with non-TRT submodules."
-            )
-
         if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
             need_cudagraphs_record = shape_changed or self.is_weight_streaming_set

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 8 additions & 10 deletions
@@ -127,7 +127,7 @@ def __init__(
         name: str = "",
         settings: CompilationSettings = CompilationSettings(),
         weight_name_map: Optional[dict[Any, Any]] = None,
-        engine_is_dds: bool = False,
+        requires_output_allocator: bool = False,
     ):
         """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
         a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine
@@ -141,7 +141,7 @@ def __init__(
             name (str): Name for module
             settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
             weight_name_map (dict): Mapping of engine weight name to state_dict weight name
-            engine_is_dds (bool): Whether the engine is Data Dependent Shape
+            requires_output_allocator (bool): Whether the engine requires an output allocator

         Example:

@@ -206,7 +206,7 @@ def __init__(
         self.pre_allocated_outputs: List[torch.Tensor] = []
         self.use_pre_allocated_outputs = False

-        self.engine_is_dds = engine_is_dds
+        self.requires_output_allocator = requires_output_allocator
         self.output_allocator: Optional[DynamicOutputAllocator] = None
         self.use_output_allocator_outputs = False

@@ -281,7 +281,7 @@ def setup_engine(self) -> None:
             for output_name in self.output_names
         ]

-        if self.engine_is_dds:
+        if self.requires_output_allocator:
             self.create_output_allocator()

         if torch_tensorrt.runtime.get_cudagraphs_mode():
@@ -678,22 +678,20 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 ]
                 logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            if self.engine_is_dds:
+            if self.requires_output_allocator:
                 if self.cudagraphs_enabled:
                     raise RuntimeError(
-                        "The module is Data-Dependent Shape (DDS). It has to be handled by OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs."
+                        "This module requires OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs."
                     )
-                logger.debug(
-                    "The module is Data-Dependent Shape (DDS). Using output allocator."
-                )
+                logger.debug("Using OutputAllocator in runtime.")
                 return run_output_allocator()
             else:
                 if self.cudagraphs_enabled and self.use_output_allocator_outputs:
                     raise RuntimeError(
                         "Both CUDA Graphs and OutputAllocator are enabled. Please disable either one."
                     )
                 if self.use_output_allocator_outputs:
-                    logger.debug("Using output allocator.")
+                    logger.debug("Using OutputAllocator in runtime.")
                     return run_output_allocator()
                 logger.debug(
                     f"Using standard execution with cudagraphs={self.cudagraphs_enabled}."

0 commit comments
