fix: Improve partitioning + lowering systems in torch.compile path #1879

Merged
merged 2 commits on May 19, 2023
16 changes: 10 additions & 6 deletions py/torch_tensorrt/dynamo/backend/__init__.py
@@ -4,7 +4,7 @@
import torch_tensorrt
from functools import partial

from typing import Any
from typing import Any, Sequence
from torch_tensorrt import EngineCapability, Device
from torch_tensorrt.fx.utils import LowerPrecision

@@ -15,7 +15,7 @@
PRECISION,
DEBUG,
MAX_WORKSPACE_SIZE,
MAX_NUM_TRT_ENGINES,
MIN_BLOCK_SIZE,
)


@@ -41,7 +41,7 @@ def compile(
calibrator=None,
truncate_long_and_double=False,
require_full_compilation=False,
min_block_size=3,
min_block_size=MIN_BLOCK_SIZE,
torch_executed_ops=[],
torch_executed_modules=[],
**kwargs,
@@ -50,7 +50,7 @@
logger.warn(
"The Dynamo backend is an experimental feature, for which only the "
+ "following arguments are supported: "
+ "{enabled_precisions, debug, workspace_size, max_num_trt_engines}"
+ "{enabled_precisions, debug, workspace_size, min_block_size, torch_executed_ops}"
)

if not isinstance(inputs, collections.abc.Sequence):
@@ -80,6 +80,8 @@
precision=lower_precision,
debug=debug,
workspace_size=workspace_size,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
**kwargs,
)

@@ -100,7 +102,8 @@ def create_backend(
precision: LowerPrecision = PRECISION,
debug: bool = DEBUG,
workspace_size: int = MAX_WORKSPACE_SIZE,
max_num_trt_engines: int = MAX_NUM_TRT_ENGINES,
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Sequence[str] = set(),
**kwargs,
):
"""Create torch.compile backend given specified arguments
Expand All @@ -117,7 +120,8 @@ def create_backend(
debug=debug,
precision=precision,
workspace_size=workspace_size,
max_num_trt_engines=max_num_trt_engines,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
)

return partial(
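For context on the new user-facing knobs, a minimal usage sketch (not part of this diff): it assumes a CUDA-enabled build, an illustrative toy model, and that the partial returned by create_backend is directly usable as a torch.compile backend; the operator string passed to torch_executed_ops is also an illustrative guess at the qualified aten name.

import torch
from torch_tensorrt.dynamo.backend import create_backend

# Illustrative toy model and input (not part of this PR)
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU()).cuda().eval()
inputs = [torch.randn(8, 32).cuda()]

# Require at least 5 computational ops per TRT block and force one op to run in Torch
backend = create_backend(
    min_block_size=5,
    torch_executed_ops={"torch.ops.aten.relu.default"},  # qualified name is an assumption
)

compiled = torch.compile(model, backend=backend)
out = compiled(*inputs)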
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/_defaults.py
@@ -4,4 +4,4 @@
PRECISION = LowerPrecision.FP32
DEBUG = False
MAX_WORKSPACE_SIZE = 20 << 30
MAX_NUM_TRT_ENGINES = 10
MIN_BLOCK_SIZE = 5
8 changes: 5 additions & 3 deletions py/torch_tensorrt/dynamo/backend/_settings.py
@@ -1,11 +1,12 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Sequence

from torch_tensorrt.fx.utils import LowerPrecision
from torch_tensorrt.dynamo.backend._defaults import (
PRECISION,
DEBUG,
MAX_WORKSPACE_SIZE,
MAX_NUM_TRT_ENGINES,
MIN_BLOCK_SIZE,
)


@@ -14,4 +15,5 @@ class CompilationSettings:
precision: LowerPrecision = PRECISION
debug: bool = DEBUG
workspace_size: int = MAX_WORKSPACE_SIZE
max_num_trt_engines: int = MAX_NUM_TRT_ENGINES
min_block_size: int = MIN_BLOCK_SIZE
torch_executed_ops: Sequence[str] = field(default_factory=set)
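A short sketch of how the updated dataclass is meant to be constructed; the override values are illustrative:

from torch_tensorrt.dynamo.backend._settings import CompilationSettings

# Defaults: min_block_size falls back to MIN_BLOCK_SIZE (5); torch_executed_ops uses
# field(default_factory=set) so each instance gets its own empty set, since a bare
# `= set()` mutable default would be rejected by @dataclass
settings = CompilationSettings()

# Overriding the new fields (operator string is an illustrative assumption)
custom = CompilationSettings(
    min_block_size=3,
    torch_executed_ops={"torch.ops.aten.add.Tensor"},
)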
5 changes: 4 additions & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -100,7 +100,10 @@ def _compile_module(
"""
# Partition module into components that can be TRT-accelerated
partitioned_module = partition(
gm, verbose=settings.debug, max_num_trt_engines=settings.max_num_trt_engines
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)

# Iterate over all components that can be accelerated
15 changes: 15 additions & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py
@@ -41,5 +41,20 @@ def inplace_op(*args, **kwargs):
replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)


@register_decomposition(aten.std, registry=DECOMPOSITIONS)
def std_replacement(*args, **kwargs) -> torch.Tensor:
return torch.sqrt(torch.var(*args, **kwargs))


@register_decomposition(aten.rsqrt, registry=DECOMPOSITIONS)
def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:
return torch.reciprocal(torch.sqrt(*args, **kwargs))


@register_decomposition(aten.alias, registry=DECOMPOSITIONS)
def alias_replacement(x: torch.Tensor) -> torch.Tensor:
return x


def get_decompositions():
return DECOMPOSITIONS
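As a quick numerical sanity check of the new decompositions (a standalone sketch, not part of this diff):

import torch

x = torch.randn(4, 8)

# std_replacement: std(x) == sqrt(var(x)); both default to the unbiased estimator
assert torch.allclose(torch.sqrt(torch.var(x)), torch.std(x))

# rsqrt_replacement: rsqrt(y) == 1 / sqrt(y) on positive inputs
y = torch.rand(4, 8) + 0.1
assert torch.allclose(torch.reciprocal(torch.sqrt(y)), torch.rsqrt(y))

# alias_replacement: aten.alias is a no-op view, so returning the input unchanged is sufficient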
151 changes: 120 additions & 31 deletions py/torch_tensorrt/dynamo/backend/lowering/_partition.py
@@ -1,92 +1,181 @@
from typing import Dict, Optional, Sequence
import logging
from typing import Dict, List, Optional, Sequence

import torch

from torch_tensorrt.dynamo.backend._defaults import MAX_NUM_TRT_ENGINES
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupport

from torch_tensorrt.fx.converter_registry import CONVERTERS


logger = logging.getLogger(__name__)


class TRTPartitioner(CapabilityBasedPartitioner):
"""Partitioner to split an FX graph into subgraphs based on operator support

Args:
graph_module: FX GraphModule to partition
operator_support: OperatorSupport class describing allowed operators
non_compute_ops: Operators which are not considered computational (e.g. getattr)
allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
Generally useful for module-level exclusion ops which are intensive despite being single functions
min_block_size: Minimum number of computational operators per block
Returns:
torch.fx.GraphModule
"""

def __init__(
self,
graph_module: GraphModule,
operator_support: OperatorSupport,
*,
non_compute_ops: Optional[Sequence[str]] = None,
allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
min_block_size=MIN_BLOCK_SIZE,
) -> None:
super().__init__(
graph_module,
operator_support,
allows_single_node_partition=True,
non_compute_ops=non_compute_ops,
allowed_single_node_partition_ops=allowed_single_node_partition_ops,
)

self.min_block_size = min_block_size

def propose_partitions(self) -> List[Partition]:
# Propose partitions using the default, then refine the results
initial_proposed_partitions = super().propose_partitions()
partitions = {i: part for i, part in enumerate(initial_proposed_partitions)}

# For each partition, determine whether or not the number of computational operators
# exceeds the threshold, and if not, remove that partition
partitions_to_remove = {}
for id, partition in partitions.items():
default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
exempted_partition = False

compute_node_count = 0
for node in partition.nodes:
# Partitions are exempted from min_block_size if they contain an allowed single-node op
if (
node.op == "call_function"
and _get_qualified_name(node.target)
in self.allowed_single_node_partition_ops
):
exempted_partition = True
break
elif (
node.op == "call_function"
and _get_qualified_name(node.target) not in non_compute_ops
):
compute_node_count += 1

if compute_node_count < self.min_block_size and not exempted_partition:
partitions_to_remove[id] = compute_node_count

# Remove any nodes violating the criteria specified by the user
for id, count in partitions_to_remove.items():
logger.debug(
f"Removing partition which has {count} < {self.min_block_size} computational operators"
)
del partitions[id]

return [partitions[k] for k in sorted(partitions.keys())]

def partition_and_fuse(self) -> GraphModule:
partitions = self.propose_partitions()
fused_gm = self.fuse_partitions(partitions)
return fused_gm


class TorchTensorRTOperatorSupport(OperatorSupport):
"""Class to determine whether operators within a module are supported"""

def __init__(self, support_dict=None):
def __init__(self, support_dict=None, torch_executed_ops=set()):
super().__init__(support_dict)

# Initialize sets of supported/unsupported operators
self.supported_operators = set()
self.unsupported_operators = set()
self.torch_executed_ops = torch_executed_ops

def is_node_supported(
self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
if node.target in CONVERTERS.keys():
# If node is a proper computational node, store the operator
node_name = (
_get_qualified_name(node.target)
if not isinstance(node.target, str)
else node.target
)

if (
node.target in CONVERTERS.keys()
and node_name not in self.torch_executed_ops
):
# If node is a proper, supported computational node, store the operator
if not node.is_impure():
node_name = node._pretty_print_target(node.target)
self.supported_operators.add(node_name)

return True
else:
if not node.is_impure():
node_name = node._pretty_print_target(node.target)
self.unsupported_operators.add(node_name)

return False

def print_support_overview(self, num_trt_blocks: Optional[int] = None):
if num_trt_blocks is not None:
print(f"\nNumber of TensorRT-Accelerated Subgraphs: {num_trt_blocks}")
logger.debug(
f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
)

print("\nSupported Nodes:")
logger.debug("\nSupported Nodes:")
for node_name in self.supported_operators:
print("-", node_name)
logger.debug("-", node_name)

if len(self.unsupported_operators) != 0:
print("\nUnsupported Nodes:")
logger.debug("\nUnsupported or Excluded Nodes:")
for node_name in self.unsupported_operators:
print("-", node_name)
print("\n")
logger.debug("-", node_name)
logger.debug("\n")
else:
print("\nAll Nodes Supported\n")
logger.debug("\nAll Nodes Supported\n")


def partition(
gm: torch.fx.GraphModule,
verbose: bool = True,
max_num_trt_engines: int = MAX_NUM_TRT_ENGINES,
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Sequence[str] = set(),
) -> torch.fx.GraphModule:
"""Partition an FX GraphModule with aten ops into TRT engines
Partitioning is based on converter operator support

Args:
gm: FX GraphModule to partition
verbose: Bool representing whether to print operator support
max_num_trt_engines: Maximum number of allowed TRT engines in partitioning
min_block_size: Minimum number of operators per TRT-Engine Block
torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
Returns:
torch.fx.GraphModule
"""
supported_ops = TorchTensorRTOperatorSupport()
partitioner = CapabilityBasedPartitioner(gm, supported_ops)
supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops)
partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)

# Determine partitions, and raise error if the degree of partitioning
# exceeds a specified threshold
# Determine partitions based on user specifications and operator support
# Then, fuse partitions and display overview of supported/unsupported operators
partitions = partitioner.propose_partitions()
num_blocks = len(partitions)
if num_blocks > max_num_trt_engines:
raise AssertionError(
f"The graph module has {num_blocks} TRT Engines which is larger than the "
+ f"threshold={max_num_trt_engines}. Falling back to non-TRT module."
)

# Fuse partitions and display overview of supported/unsupported operators
fused_graph = partitioner.fuse_partitions(partitions)
num_blocks = len(partitions)

if verbose:
supported_ops.print_support_overview(num_blocks)
supported_ops.print_support_overview(len(partitions))

return fused_graph
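To illustrate the reworked entry point end to end, a hedged sketch that lowers a toy module to aten IR with make_fx and partitions it; the module, the tracing route, and the torch_executed_ops string are assumptions for illustration, and which blocks survive depends on converter coverage and min_block_size:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch_tensorrt.dynamo.backend.lowering._partition import partition

# Illustrative toy module (not part of this PR)
class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x @ x) + torch.sin(x)

# make_fx produces an aten-level GraphModule, roughly what the backend
# hands to partition() after AOT tracing (an assumption for this sketch)
gm = make_fx(Net())(torch.randn(8, 8))

# Blocks with fewer than min_block_size computational ops are dropped, and any op
# named in torch_executed_ops stays in Torch regardless of converter support
fused = partition(
    gm,
    verbose=True,
    min_block_size=2,
    torch_executed_ops={"torch.ops.aten.sin.default"},  # qualified name is an assumption
)

print(fused.graph)  # fused submodules mark the TRT-eligible blocks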
