
Commit b56b2ae (1 parent: 6717ddc)

feat: exclude refit sensitive ops from TRT compilation (#3159)

13 files changed: +174 −56 lines

py/torch_tensorrt/dynamo/_compiler.py (+4 −6)

@@ -314,11 +314,9 @@ def compile_module(
     dryrun_tracker = DryRunTracker()
     if sample_kwarg_inputs is None:
         sample_kwarg_inputs = {}
-    # Assume converters support dynamic shapes and disable validation
-    CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
 
-    # Set torch-executed ops
-    CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)
+    # Configure user compilation settings to converters.
+    CONVERTERS.set_compilation_settings(settings)
 
     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(

@@ -670,8 +668,8 @@ def convert_exported_program_to_serialized_trt_engine(
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
 
-    # Assume converters support dynamic shapes and disable validation
-    CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
+    # Configure user compilation settings to converters.
+    CONVERTERS.set_compilation_settings(settings)
 
     try:
         interpreter_result = interpret_module_to_result(
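
Both entry points now hand the full CompilationSettings object to the converter registry in a single call. A minimal sketch of the consolidated pattern (the registry import alias and the example field values are assumptions for illustration, not part of this commit):

    from torch_tensorrt.dynamo._settings import CompilationSettings
    # assumed import alias; _compiler.py refers to the registry as CONVERTERS
    from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
        DYNAMO_CONVERTERS as CONVERTERS,
    )

    settings = CompilationSettings(
        assume_dynamic_shape_support=False,
        torch_executed_ops={"torch.ops.aten.cumsum.default"},  # ops forced to run in PyTorch
    )
    # Replaces the former set_dynamic_shape_support(...) / set_disallowed_targets(...) pair;
    # torch_executed_ops are registered as disallowed targets internally.
    CONVERTERS.set_compilation_settings(settings)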

py/torch_tensorrt/dynamo/_refit.py (+1 −2)

@@ -6,6 +6,7 @@
 from typing import Any, List, Optional, Sequence, Tuple
 
 import numpy as np
+import tensorrt as trt
 import torch
 from torch.export import ExportedProgram
 from torch_tensorrt._enums import dtype

@@ -42,8 +43,6 @@
 )
 from torch_tensorrt.logging import TRT_LOGGER
 
-import tensorrt as trt
-
 logger = logging.getLogger(__name__)
 
py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py (+27 −13)

@@ -23,6 +23,7 @@
 from torch import SymBool, SymFloat, SymInt
 from torch._ops import OpOverloadPacket
 from torch.fx.node import Argument, Node, Target, _get_qualified_name
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS
 

@@ -82,7 +83,9 @@ class ConverterSupport:
     """
 
     converter_implementation: ConverterImplSignature
-    capability_validator: Callable[[Node], bool] = field(default=lambda node: True)
+    capability_validator: Callable[[Node, CompilationSettings], bool] = field(
+        default=lambda node, compilation_settings: True
+    )
     supports_dynamic_shapes: bool = False
 

@@ -112,18 +115,20 @@ def has_dynamic_shapes_in_args(
 
 def has_static_shapes_in_args(
     arg_positions_to_check: Optional[List[int]] = None,
-) -> Callable[[torch.fx.Node], bool]:
+) -> Callable[[torch.fx.Node, CompilationSettings], bool]:
     """Returns True if a node has static inputs in node.args at specified positions"""
-    _has_static_shapes = lambda node, arg_positions_to_check: not _has_dynamic_shapes(
-        node, arg_positions_to_check
+    _has_static_shapes = lambda node, compilation_settings, arg_positions_to_check: not _has_dynamic_shapes(
+        node, compilation_settings, arg_positions_to_check
     )
     return functools.partial(
         _has_static_shapes, arg_positions_to_check=arg_positions_to_check
     )
 
 
 def _has_dynamic_shapes(
-    node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
+    node: torch.fx.Node,
+    compilation_settings: CompilationSettings = None,
+    arg_positions_to_check: Optional[List[int]] = None,
 ) -> bool:
     # Validate that none of the inputs to the node have Dynamic shapes
     assert isinstance(

@@ -188,7 +193,7 @@ def dynamo_tensorrt_converter(
     key: Target,
     *,
     enabled: bool = True,
-    capability_validator: Optional[Callable[[Node], bool]] = None,
+    capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
 ) -> Callable[[ConverterImplSignature], ConverterImplSignature]:

@@ -297,7 +302,6 @@ def __init__(
         ],
         registry_names: Optional[Sequence[str]] = None,
         registry_calling_conventions: Optional[Sequence[CallingConvention]] = None,
-        assume_dynamic_shape_support: bool = False,
     ):
         # Copy reference to each dictionary object into attribute list
         self.registries = list(registries)

@@ -318,12 +322,16 @@ def __init__(
             CallingConvention.CTX for _ in range(len(self.registries))
         ]
 
+        self.compilation_settings: CompilationSettings = None
         self.disallowed_targets: Collection[Target] = set()
-        self.assume_dynamic_shape_support = assume_dynamic_shape_support
         self.validate_invariants()
 
-    def set_dynamic_shape_support(self, assume_dynamic_shape_support: bool) -> None:
-        self.assume_dynamic_shape_support = assume_dynamic_shape_support
+    def set_compilation_settings(
+        self, compilation_settings: CompilationSettings
+    ) -> None:
+        self.compilation_settings = compilation_settings
+        # set torch executed ops as disallowed targets
+        self.set_disallowed_targets(compilation_settings.torch_executed_ops)
 
     def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None:
         self.disallowed_targets = torch_executed_ops

@@ -412,7 +420,11 @@ def __getitem__(
 
         self.validate_invariants()
         key = node.target
-
+        assume_dynamic_shape_support = False
+        if self.compilation_settings:
+            assume_dynamic_shape_support = (
+                self.compilation_settings.assume_dynamic_shape_support
+            )
         if (
             key in self.disallowed_targets
             or self.qualified_name_or_str(key) in self.disallowed_targets

@@ -436,8 +448,10 @@ def __getitem__(
                 # 2) Assume dynamic_shape support is True
                 # 3) Node only has static shaped inputs
                 # 4) Node has dynamic inputs and the converter has supports_dynamic_shapes=True
-                if candidate.capability_validator(node) and (
-                    self.assume_dynamic_shape_support
+                if candidate.capability_validator(
+                    node, self.compilation_settings
+                ) and (
+                    assume_dynamic_shape_support
                     or not node_has_dynamic_shapes(node)
                     or candidate.supports_dynamic_shapes
                 ):
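
With this change, every capability validator receives the active CompilationSettings alongside the FX node, so a converter can opt out of TRT conversion based on user options. A hypothetical sketch of the new contract (the target op and import paths are illustrative assumptions, not converters added by this commit):

    import torch
    from torch.fx.node import Node
    from torch_tensorrt.dynamo._settings import CompilationSettings
    from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
        dynamo_tensorrt_converter,
    )

    def refit_sensitive_validator(node: Node, settings: CompilationSettings = None) -> bool:
        # Reject conversion when the user asked for a refittable engine, so the
        # op is executed by PyTorch instead of being baked into the TRT engine.
        if settings and settings.make_refittable:
            return False
        return True

    @dynamo_tensorrt_converter(
        torch.ops.aten.cumsum.default,  # illustrative target only
        capability_validator=refit_sensitive_validator,
        supports_dynamic_shapes=True,
    )
    def my_cumsum_converter(ctx, target, args, kwargs, name):
        ...  # converter implementation omitted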

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (+6 −2)

@@ -18,6 +18,7 @@
 )
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name

@@ -43,7 +44,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
-import tensorrt as trt
 from packaging import version
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -89,6 +89,11 @@ def __init__(
             self.builder.create_network(flag), compilation_settings
         )
 
+        self.compilation_settings = compilation_settings
+        if not CONVERTERS.compilation_settings:
+            # Configure user compilation settings to converters.
+            CONVERTERS.set_compilation_settings(compilation_settings)
+
         assert TRTInterpreter._all_precisions_supported(
             compilation_settings.enabled_precisions
         ), f"Attempted to enable kernel precisions that are not supported (got: {compilation_settings.enabled_precisions}, support: {_defaults.SUPPORTED_KERNEL_PRECISIONS})"

@@ -117,7 +122,6 @@ def __init__(
         self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
             dict()
         )
-        self.compilation_settings = compilation_settings
 
         # Data types for TRT Module output Tensors
         self.output_dtypes = (

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+47 −19)

@@ -7,6 +7,7 @@
 import numpy as np
 import torch
 from torch.fx.node import Argument, Node, Target
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext

@@ -48,7 +49,7 @@ def get_ir(target: Target) -> SourceIR:
     return SourceIR.UNKNOWN
 
 
-def one_user_validator(node: Node) -> bool:
+def one_user_validator(node: Node, settings: CompilationSettings = None) -> bool:
     # Validate only one user, which is a getitem node that accesses the first element in the list
     return (
         len(node.users) == 1

@@ -270,7 +271,11 @@ def aten_ops_embedding(
     )
 
 
-def embedding_bag_validator(node: Node) -> bool:
+def embedding_bag_validator(node: Node, settings: CompilationSettings = None) -> bool:
+    # Embedding bag op is not refitable
+    if settings.make_refittable:
+        return False
+
     if not one_user_validator(node):
         return False
     meta = node.args[1].meta

@@ -416,7 +421,7 @@ def aten_ops_symsize_int(
     return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])
 
 
-def index_dtype_validator(node: Node) -> bool:
+def index_dtype_validator(node: Node, settings: CompilationSettings = None) -> bool:
     index = node.args[1]
     for ind in index:
         if ind is not None:

@@ -837,7 +842,7 @@ def aten_ops_select(
     )
 
 
-def index_put_validator(node: Node) -> bool:
+def index_put_validator(node: Node, settings: CompilationSettings = None) -> bool:
     if args_bounds_check(node.args, 3, False):  # Check if accumulate is valid
         _LOGGER.debug("We do not support accumulate=True for aten.index_put operation")
         accumulate_valid = False

@@ -924,7 +929,18 @@ def aten_ops_slice(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default, supports_dynamic_shapes=True)
+def refit_validator(node: Node, settings: CompilationSettings = None) -> bool:
+    # cumsum op is not refitable
+    if settings and settings.make_refittable:
+        return False
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.cumsum.default,
+    capability_validator=refit_validator,
+    supports_dynamic_shapes=True,
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),

@@ -970,7 +986,7 @@ def aten_ops_tile(
     )
 
 
-def zero_output_validator(node: Node) -> bool:
+def zero_output_validator(node: Node, settings: CompilationSettings = None) -> bool:
     if 0 in node.args[1]:
         _LOGGER.debug(
             f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation."

@@ -1027,7 +1043,9 @@ def aten_ops_permute(
     )
 
 
-def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]:
+def to_copy_dtype_validator(
+    placeholder_only: bool, settings: CompilationSettings = None
+) -> Callable[[Node, CompilationSettings], bool]:
     """Return validator for to_copy node with placeholder restrictions"""
 
     def validate_dtype(to_copy_node: Node) -> bool:

@@ -1059,7 +1077,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
             )
             return False
 
-    def validator(to_copy_node: Node) -> bool:
+    def validator(to_copy_node: Node, settings: CompilationSettings = None) -> bool:
         """Returns true if the to_copy node can be converted to TRT
         and the placeholder restriction is satisfied
         """

@@ -1074,7 +1092,9 @@ def validator(to_copy_node: Node) -> bool:
 
 @dynamo_tensorrt_converter(
     torch.ops.aten.clone.default,
-    capability_validator=lambda node: not is_only_operator_on_placeholder(node),
+    capability_validator=lambda node, settings: not is_only_operator_on_placeholder(
+        node, settings
+    ),
     supports_dynamic_shapes=True,
 )
 @dynamo_tensorrt_converter(

@@ -2128,7 +2148,7 @@ def aten_ops_logical_xor(
     )
 
 
-def bitwise_type_validator(node: Node) -> bool:
+def bitwise_type_validator(node: Node, settings: CompilationSettings = None) -> bool:
     supported_type = [torch.bool, bool]
 
     tensor_targets = [

@@ -2271,7 +2291,9 @@ def aten_ops_bitwise_xor(
     )
 
 
-def bitwise_not_type_validator(node: Node) -> bool:
+def bitwise_not_type_validator(
+    node: Node, settings: CompilationSettings = None
+) -> bool:
     val = node.args[0]
     val_meta = val.meta.get("tensor_meta")
 

@@ -2453,7 +2475,7 @@ def aten_ops_le(
     )
 
 
-def conv_param_validator(conv_node: Node) -> bool:
+def conv_param_validator(conv_node: Node, settings: CompilationSettings = None) -> bool:
     return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])
 

@@ -2549,7 +2571,9 @@ def aten_ops_cdist_forward(
     )
 
 
-def avg_pool_param_validator(pool_node: Node) -> bool:
+def avg_pool_param_validator(
+    pool_node: Node, settings: CompilationSettings = None
+) -> bool:
     ceil_mode = args_bounds_check(pool_node.args, 4, False)
     divisor_override = args_bounds_check(pool_node.args, 6)
 

@@ -2665,12 +2689,12 @@ def aten_ops_adaptive_avg_poolNd(
     )
 
 
-def topk_validator(node: Node) -> bool:
+def topk_validator(node: Node, settings: CompilationSettings = None) -> bool:
     k = node.args[1]
     return topk_sort_validator(k)
 
 
-def sort_validator(node: Node) -> bool:
+def sort_validator(node: Node, settings: CompilationSettings = None) -> bool:
     meta_data = node.args[0].meta.get("tensor_meta")
     if meta_data is None:
         return False

@@ -2692,7 +2716,9 @@ def topk_sort_validator(k: int) -> bool:
     return True
 
 
-def max_pool_param_validator(pool_node: Node) -> bool:
+def max_pool_param_validator(
+    pool_node: Node, settings: CompilationSettings = None
+) -> bool:
     dilation = args_bounds_check(pool_node.args, 4, 1)
     ceil_mode = args_bounds_check(pool_node.args, 5, False)
 

@@ -2746,7 +2772,7 @@ def aten_ops_max_pool(
     )
 
 
-def attention_validator(node: Node) -> bool:
+def attention_validator(node: Node, settings: CompilationSettings = None) -> bool:
     # Currently, `attn_mask` is not supported
     return args_bounds_check(node.args, 3) is None
 

@@ -3637,7 +3663,7 @@ def aten_ops_flip(
     )
 
 
-def zero_diag_size_validator(node: Node) -> bool:
+def zero_diag_size_validator(node: Node, settings: CompilationSettings = None) -> bool:
     meta = node.args[0].meta.get("tensor_meta")
     if meta:
         input_shape = meta.shape

@@ -3765,7 +3791,9 @@ def aten_ops_index_select(
     )
 
 
-def dropout_inference_validator(node: Node) -> bool:
+def dropout_inference_validator(
+    node: Node, settings: CompilationSettings = None
+) -> bool:
     train_mode = args_bounds_check(node.args, 2, None)
     if train_mode is False:
         return True
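
Taken together, the updated validators mean that when a user requests a refittable engine, refit-sensitive ops such as cumsum and embedding_bag are rejected by their validators, partitioned out of the TRT engine, and executed by PyTorch instead. A rough usage sketch, assuming make_refittable is forwarded to CompilationSettings by the dynamo compile API (the module and input shapes are made up for illustration):

    import torch
    import torch_tensorrt

    class CumsumModule(torch.nn.Module):
        def forward(self, x):
            return torch.cumsum(x, dim=0)

    model = CumsumModule().eval().cuda()
    inputs = [torch.randn(4, 4).cuda()]
    exp_program = torch.export.export(model, tuple(inputs))

    trt_module = torch_tensorrt.dynamo.compile(
        exp_program,
        inputs=inputs,
        make_refittable=True,  # refit-sensitive ops now fall back to PyTorch execution
    )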
