
Commit b57d83e

1 parent 32d905b

File tree: 15 files changed (+634, -146 lines)


py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
@@ -10,3 +10,4 @@
 OPTIMIZATION_LEVEL = None
 TRUNCATE_LONG_AND_DOUBLE = False
 USE_PYTHON_RUNTIME = False
+USE_FAST_PARTITIONER = True

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     TRUNCATE_LONG_AND_DOUBLE,
+    USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
@@ -29,3 +30,4 @@ class CompilationSettings:
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
+    use_fast_partitioner: bool = USE_FAST_PARTITIONER
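
For context, the new flag is just another field on the settings dataclass. A minimal sketch of constructing it directly (the field names come from the diff above; the min_block_size value is only an example):

from torch_tensorrt.dynamo import CompilationSettings

# All unspecified fields fall back to the defaults in _defaults.py.
settings = CompilationSettings(
    min_block_size=5,               # example value
    use_fast_partitioner=True,      # new field introduced in this commit
)
assert settings.use_fast_partitioner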

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 56 additions & 9 deletions
@@ -7,13 +7,12 @@
 import torch
 import torch._dynamo as td
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
-from torch_tensorrt.dynamo import CompilationSettings
+from torch_tensorrt.dynamo import CompilationSettings, partitioning
 from torch_tensorrt.dynamo.conversion import (
     convert_module,
     repair_long_or_double_inputs,
 )
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
-from torch_tensorrt.dynamo.lowering._partition import get_submod_inputs, partition
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs

@@ -111,24 +110,68 @@ def _compile_module(
     Returns:
         Compiled FX GraphModule
     """
-    # Partition module into components that can be TRT-accelerated
-    partitioned_module = partition(
-        gm,
-        verbose=settings.debug,
-        min_block_size=settings.min_block_size,
-        torch_executed_ops=settings.torch_executed_ops,
+    # Check the number of supported operations in the graph
+    num_supported_ops, total_ops = partitioning.get_graph_converter_support(
+        gm, settings.debug, settings.torch_executed_ops
     )

+    # If the number of supported operations is 0 or less than the block size, skip the subgraph
+    # TODO: Add condition to second expression below when require_full_compilation is added
+    if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size):
+        logger.warning(
+            f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. "
+            f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}"
+        )
+        return gm
+    else:
+        logger.debug(
+            f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
+        )
+
+    # Partition module into components that can be TRT-accelerated
+    fast_partitioner_failed = False
+
+    # If specified, try using the fast partitioner and fall back to the global one on failure
+    if settings.use_fast_partitioner:
+        try:
+            partitioned_module = partitioning.fast_partition(
+                gm,
+                verbose=settings.debug,
+                min_block_size=settings.min_block_size,
+                torch_executed_ops=settings.torch_executed_ops,
+            )
+        except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
+            logger.error(
+                "Partitioning failed on the subgraph with fast partition. See trace above. "
+                + "Retrying with global partition.",
+                exc_info=True,
+            )
+
+            fast_partitioner_failed = True
+            settings.use_fast_partitioner = False
+
+    if not settings.use_fast_partitioner:
+        partitioned_module = partitioning.global_partition(
+            gm,
+            verbose=settings.debug,
+            min_block_size=settings.min_block_size,
+            torch_executed_ops=settings.torch_executed_ops,
+        )
+
     # Store TRT replicas of Torch subgraphs
     trt_modules = {}

     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
+        # Criteria for a module to be convertible to TRT
+        if settings.use_fast_partitioner and "_run_on_acc" not in name:
+            continue
+
         submodule = getattr(partitioned_module, name)

         # Get submodule inputs
-        submodule_inputs = get_submod_inputs(
+        submodule_inputs = partitioning.get_submod_inputs(
             partitioned_module, submodule, sample_inputs
         )

@@ -153,4 +196,8 @@ def _compile_module(
     for name, trt_mod in trt_modules.items():
         setattr(partitioned_module, name, trt_mod)

+    # Reset settings object to user specification after fallback to global partitioning mode
+    if fast_partitioner_failed:
+        settings.use_fast_partitioner = True
+
     return partitioned_module
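
End to end, the partitioner choice can be exercised through the Dynamo backend. A hedged sketch, assuming the backend is registered with torch._dynamo under the name "torch_tensorrt" and that these options are forwarded into CompilationSettings via parse_dynamo_kwargs; the model and input shapes are placeholders:

import torch
import torch_tensorrt  # noqa: F401  (importing registers the Dynamo backend)

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).cuda().eval()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# Options are assumed to flow into CompilationSettings, including the new flag.
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"min_block_size": 1, "use_fast_partitioner": True},
)
compiled(*inputs)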

py/torch_tensorrt/dynamo/compile.py

Lines changed: 5 additions & 49 deletions
@@ -7,7 +7,6 @@
 import torch
 import torch_tensorrt
 from torch.fx.passes.pass_manager import PassManager
-from torch.fx.passes.splitter_base import SplitResult
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import (  # TODO: Should probabably be the TRT EngineCapability Enum
     EngineCapability,
@@ -21,18 +20,17 @@
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     TRUNCATE_LONG_AND_DOUBLE,
+    USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
 )
 from torch_tensorrt.dynamo.backend.backends import _compile_module
-from torch_tensorrt.dynamo.conversion import convert_module
 from torch_tensorrt.dynamo.lowering._fusers import (
     fuse_permute_linear,
     fuse_permute_matmul,
 )
 from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
-from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting

 logger = logging.getLogger(__name__)

@@ -64,6 +62,7 @@ def compile(
     version_compatible: bool = VERSION_COMPATIBLE,
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
     use_python_runtime: bool = USE_PYTHON_RUNTIME,
+    use_fast_partitioner: bool = USE_FAST_PARTITIONER,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     if debug:
@@ -75,7 +74,7 @@ def compile(
         "The Dynamo backend is an experimental feature, for which only the "
         + "following arguments are supported: "
         + "{enabled_precisions, debug, workspace_size, min_block_size, "
-        + "torch_executed_ops, pass_through_build_failures}"
+        + "torch_executed_ops, pass_through_build_failures, use_fast_partitioner}"
     )

     if not isinstance(inputs, collections.abc.Sequence):
@@ -115,55 +114,12 @@ def compile(
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
         "truncate_long_and_double": truncate_long_and_double,
+        "use_fast_partitioner": use_fast_partitioner,
     }

     settings = CompilationSettings(**compilation_options)
-    if kwargs.get("use_capability_partitioner", None):
-        model = lower_model(gm, torch_inputs)
-        return _compile_module(model, torch_inputs, settings)
-    else:
-        split_result = lower_model_using_trt_splitter(gm, torch_inputs)
-        trt_module = _compile_graph(split_result, torch_inputs, settings)
-
-        return trt_module

-
-def _compile_graph(
-    split_result: SplitResult,
-    inputs: Any,
-    settings: CompilationSettings = CompilationSettings(),
-    **kwargs: Any,
-) -> torch.fx.GraphModule:
-    for submod_name, submod_inputs in split_result.submodule_inputs.items():
-        submod = getattr(split_result.split_module, submod_name)
-        # Only acc submodules will be lowered.
-        if not submod_name.startswith(split_result.non_acc_submodule_prefix):
-            # Create TRT Module from submodule
-            trt_mod = convert_module(
-                submod,
-                submod_inputs,
-                settings=settings,
-                name=submod_name,
-            )
-            setattr(split_result.split_module, submod_name, trt_mod)
-
-    return split_result.split_module
-
-
-def lower_model_using_trt_splitter(
-    model: torch.nn.Module, inputs: Any, **kwargs: Any
-) -> SplitResult:
-    # Perform basic lowering
-    model = lower_model(model, inputs)
-    splitter_setting = TRTSplitterSetting()
-    splitter_setting.use_implicit_batch_dim = False
-    splitter_setting.min_acc_module_size = 1
-    splitter_setting.use_experimental_rt = False
-    splitter = TRTSplitter(model, inputs, settings=splitter_setting)
-    splitter.node_support_preview()
-    split_result = splitter.generate_split_results()
-
-    return split_result
+    return _compile_module(gm, torch_inputs, settings)


 def lower_model(
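
With the TRTSplitter path removed, compile() funnels everything through _compile_module. A minimal sketch of calling it with the new argument, assuming compile is re-exported at torch_tensorrt.dynamo; gm and example_inputs are placeholders, and the argument names follow the signature and options dict shown above:

import torch
import torch_tensorrt

# gm: a torch.fx.GraphModule produced by the Dynamo front end (placeholder)
# example_inputs: sample tensors matching the graph's inputs (placeholder)
trt_gm = torch_tensorrt.dynamo.compile(
    gm,
    inputs=example_inputs,
    enabled_precisions={torch.float},
    min_block_size=5,
    use_fast_partitioner=True,  # set False to force the global partitioner
)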

py/torch_tensorrt/dynamo/conversion/converter_registry.py

Lines changed: 2 additions & 2 deletions
@@ -349,8 +349,8 @@ def unique_targets(self) -> Set[Target]:
         """Returns the set of unique converter targets stored across all registries"""
         return set.union(*[set(registry.keys()) for registry in self.registries])

-    # TODO: Make this a static method since it does not need state
-    def qualified_name_or_str(self, target: Target) -> str:
+    @staticmethod
+    def qualified_name_or_str(target: Target) -> str:
         """Returns string representation of an FX Node target"""
         if isinstance(target, str):
             return target
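
Since the method no longer needs instance state, it can be called directly on the class. A small sketch, assuming the enclosing class in converter_registry.py is named ConverterRegistry:

import torch
from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry

# OpOverload targets resolve to their qualified name; plain strings pass through.
print(ConverterRegistry.qualified_name_or_str(torch.ops.aten.relu.default))
print(ConverterRegistry.qualified_name_or_str("output"))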
py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 2 additions & 7 deletions
@@ -1,10 +1,5 @@
 from ._decompositions import get_decompositions  # noqa: F401
-from ._fusers import *  # noqa: F403
-from ._partition import (  # noqa: F401
-    DEFAULT_SINGLE_NODE_PARTITIONS,
-    get_submod_inputs,
-    partition,
-)
+from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401
-from .substitutions import *  # noqa: F403
+from .substitutions import *  # noqa: F401
py/torch_tensorrt/dynamo/partitioning/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from ._adjacency_partitioner import partition as fast_partition
+from ._global_partitioner import partition as global_partition
+from .common import get_graph_converter_support, get_submod_inputs
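
The new partitioning package gives both partitioners and the shared helpers a single import path. A short sketch mirroring the calls made in backends.py above; gm and the settings values are placeholders:

import torch
from torch_tensorrt.dynamo import partitioning

gm: torch.fx.GraphModule = ...   # graph handed to the backend (placeholder)
debug, min_block_size, torch_executed_ops = True, 5, set()

num_supported_ops, total_ops = partitioning.get_graph_converter_support(
    gm, debug, torch_executed_ops
)

# fast_partition and global_partition share the same keyword arguments;
# global_partition is the fallback when fast partitioning raises
# torch.fx.passes.splitter_base.FxNetSplitterInternalError.
partitioned = partitioning.fast_partition(
    gm,
    verbose=debug,
    min_block_size=min_block_size,
    torch_executed_ops=torch_executed_ops,
)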
