
Commit 28b27c5

resolve comments
1 parent 0987146 commit 28b27c5

8 files changed: +28 -22 lines changed

py/torch_tensorrt/dynamo/_engine_cache.py (+1 -1)

@@ -118,7 +118,7 @@ def pack(
             input_specs (Sequence[Input]): input specs of TRT engine
             compilation_settings (CompilationSettings): compilation settings of TRT engine
             weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting
-            requires_output_allocator (bool): whether the engine requires output allocator
+            requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators)
         Returns:
             bytes: packed blob
         """

py/torch_tensorrt/dynamo/conversion/_ConversionContext.py (+1 -1)

@@ -11,7 +11,7 @@ class ConversionContext:
     Args:
         net: TensorRT Network being built
         compilation_settings: Settings selected by the user for compilation
-        requires_output_allocator: Whether the network requires output allocator
+        requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators)
     """

     net: TRTNetwork

py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py (+12 -6)

@@ -80,7 +80,7 @@ class ConverterSupport:
             whether that node can be supported by its companion converter. Note that
             this function must not modify the node or its graph
         supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic inputs.
-        requires_output_allocator: Boolean flag indicating if the converter requires to run in output allocator.
+        requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators).
     """

     converter_implementation: ConverterImplSignature

@@ -215,7 +215,7 @@ def dynamo_tensorrt_converter(
         priority: Converter's level of priority relative to other converters with the
             same target
         supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic shapes.
-        requires_output_allocator: Boolean flag indicating if the converter requires to run in output allocator.
+        requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators).
     Returns:
         The converter being decorated
     """

@@ -410,7 +410,7 @@ def __getitem_without_validation__(
     def __getitem__(
         self, node: Node
     ) -> Tuple[
-        Any, CallingConvention, bool
+        Any, CallingConvention, Dict[str, bool]
     ]:  # TODO: Narrow to ConverterImplSignature this when we can remove FX converters
         """Get the first-found validated converter in any registry

@@ -468,7 +468,10 @@ def __getitem__(
                 return (
                     candidate.converter_implementation,
                     calling_convention,
-                    candidate.requires_output_allocator,
+                    {
+                        "supports_dynamic_shapes": candidate.supports_dynamic_shapes,
+                        "requires_output_allocator": candidate.requires_output_allocator,
+                    },
                 )
             else:
                 logger.debug(

@@ -481,7 +484,10 @@ def __getitem__(
                 return (
                     converters,
                     calling_convention,
-                    False,
+                    {
+                        "supports_dynamic_shapes": False,
+                        "requires_output_allocator": False,
+                    },
                 )

         raise KeyError(

@@ -506,7 +512,7 @@ def get_unvalidated(
     def get(
         self, node: Node, value: Optional[ConverterImplSignature] = None
     ) -> Union[
-        Any, Tuple[Any, CallingConvention, bool]
+        Any, Tuple[Any, CallingConvention, Dict[str, bool]]
     ]:  # TODO: Narrow to ConverterImplSignature this when we can remove FX converters
         """Get validated converter for input node with a default return"""
         try:
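
Note: a converter opts into these flags at registration time. A minimal sketch of how that would look with the decorator documented above; the aten::nonzero target and the empty body are illustrative assumptions, not part of this commit:

from typing import Any, Dict, Sequence, Tuple, Union

import torch
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)


@dynamo_tensorrt_converter(
    torch.ops.aten.nonzero.default,
    supports_dynamic_shapes=True,
    requires_output_allocator=True,  # output shape depends on the input values
)
def aten_ops_nonzero(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[Any, Sequence[Any]]:
    # Illustrative body only; a real converter would build TensorRT layers here.
    ...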

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (+4 -4)

@@ -835,7 +835,7 @@ def call_module(
                 f"Conversion of module of type {submod_type} not currently supported!"
             )

-        converter, calling_convention, requires_output_allocator = converter_packet
+        converter, calling_convention, _ = converter_packet

         assert self._cur_node_name is not None

@@ -852,8 +852,8 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
                 f"Conversion of function {torch.typename(target)} not currently supported!"
             )

-        converter, calling_convention, requires_output_allocator = converter_packet
-        if requires_output_allocator:
+        converter, calling_convention, converter_info = converter_packet
+        if converter_info.get("requires_output_allocator", False):
             self.ctx.requires_output_allocator = True
             _LOGGER.debug(f"{target} requires output allocator")

@@ -885,7 +885,7 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
             raise UnsupportedOperatorException(
                 f"Conversion of method {target} not currently supported!"
             )
-        converter, calling_convention, requires_output_allocator = converter_packet
+        converter, calling_convention, _ = converter_packet

         if calling_convention is CallingConvention.LEGACY:
             return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
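
Note: callers outside the interpreter that index the registry now receive the same dict-shaped packet. A short sketch of reading it, assuming DYNAMO_CONVERTERS is the global registry instance exported by _ConverterRegistry.py:

from torch.fx.node import Node
from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS


def needs_output_allocator(node: Node) -> bool:
    # The third element of the packet is now a capability dict instead of a bare bool.
    _converter, _calling_convention, converter_info = DYNAMO_CONVERTERS[node]
    return bool(converter_info.get("requires_output_allocator", False))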

py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py (+3 -4)

@@ -13,16 +13,15 @@ def remove_num_users_is_0_nodes(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
     """Remove ops that [num_users=0] in the graph"""
-    output_node = list(gm.graph.nodes)[-1]
+    nodes = list(gm.graph.nodes)
+    output_node = nodes[-1]

-    for node in gm.graph.nodes:
+    for node in nodes[::-1]:
         if (
             node != output_node
             and len(node.users) == 0
             and len(node.all_input_nodes) > 0
         ):
-            node_input = node.all_input_nodes[0]
-            node.replace_all_uses_with(node_input)
             gm.graph.erase_node(node)
     gm = clean_up_graph_after_modifications(gm)
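
Note: the reverse walk is what lets whole dead chains disappear in a single pass, since erasing a dead consumer drops its producer to zero users before the sweep reaches it. A self-contained sketch of that effect on a hypothetical FX module (not part of Torch-TensorRT):

import torch
import torch.fx


class DeadChain(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = x + 1  # dead: only feeds `b`
        b = a * 2  # dead: has no users at all
        return x - 1


gm = torch.fx.symbolic_trace(DeadChain())
nodes = list(gm.graph.nodes)
output_node = nodes[-1]

# Walking in reverse, `b` (num_users=0) is erased first, which drops `a` to zero
# users so it is erased in the same sweep.
for node in nodes[::-1]:
    if (
        node != output_node
        and len(node.users) == 0
        and len(node.all_input_nodes) > 0
    ):
        gm.graph.erase_node(node)

gm.graph.lint()
print(gm.graph)  # the add and mul nodes are gone; only `x - 1` remains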

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py (+1 -1)

@@ -141,7 +141,7 @@ def __init__(
             name (str): Name for module
             settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
             weight_name_map (dict): Mapping of engine weight name to state_dict weight name
-            requires_output_allocator (bool): Whether the engine requires an output allocator
+            requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators)

         Example:

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py (+1 -1)

@@ -98,7 +98,7 @@ def __init__(
             name (str): Name for module
             settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
             weight_name_map (dict): Mapping of engine weight name to state_dict weight name
-            requires_output_allocator (bool): Whether the engine requires an output allocator
+            requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators)

         Example:

py/torch_tensorrt/runtime/_cudagraphs.py (+5 -4)

@@ -74,20 +74,21 @@ def __enter__(self) -> torch.nn.Module:

         num_torch_module = 0
         num_trt_module = 0
-        disable_cudagraphs = False
         for name, module in self.compiled_module.named_children():
-            # disable cudagraphs if any model requires output allocator
+            # need to disable cudagraphs if any model requires output allocator
             if (
                 hasattr(module, "requires_output_allocator")
                 and module.requires_output_allocator
             ):
-                disable_cudagraphs = True
+                raise RuntimeError(
+                    "There are converters that require Output Allocator. Please disable CUDA Graphs."
+                )
             if "_run_on_acc" in name:
                 num_trt_module += 1
             elif "_run_on_gpu" in name:
                 num_torch_module += 1

-        if num_torch_module > 0 and not disable_cudagraphs:
+        if num_torch_module > 0:
             # Set whole cudagraphs mode and returns wrapped module
             _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS
             # Set new mode for C++
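
Note: the user-visible effect is that combining CUDA Graphs with an Output Allocator engine now fails fast instead of silently skipping graph capture. A usage sketch, assuming a data dependent op such as torch.nonzero compiles to an engine flagged with requires_output_allocator and that a CUDA + TensorRT environment is available:

import torch
import torch_tensorrt


class DataDependent(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Output shape depends on the input values, so an Output Allocator is required.
        return torch.nonzero(x > 0)


model = DataDependent().cuda().eval()
inputs = [torch.randn(8, 8, device="cuda")]
trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

try:
    with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cg_module:
        cg_module(*inputs)
except RuntimeError as err:
    # Expected after this change:
    # "There are converters that require Output Allocator. Please disable CUDA Graphs."
    print(err)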
