feat: Add support for require_full_compilation in Dynamo #2138

Merged · 1 commit · Aug 29, 2023
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -14,6 +14,7 @@
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
REQUIRE_FULL_COMPILATION = False


def default_device() -> Device:
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -11,6 +11,7 @@
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
REQUIRE_FULL_COMPILATION,
TRUNCATE_LONG_AND_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
@@ -57,3 +58,4 @@ class CompilationSettings:
use_fast_partitioner: bool = USE_FAST_PARTITIONER
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
device: Device = field(default_factory=default_device)
require_full_compilation: bool = REQUIRE_FULL_COMPILATION
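
Because the fields shown here all carry defaults, the new flag can also be set by constructing the settings object directly. A minimal sketch, assuming CompilationSettings is importable from the module above:

from torch_tensorrt.dynamo._settings import CompilationSettings

# require_full_compilation defaults to REQUIRE_FULL_COMPILATION (False);
# overriding it requests that no operators be left to Torch execution.
settings = CompilationSettings(require_full_compilation=True)
assert settings.require_full_compilation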
10 changes: 7 additions & 3 deletions py/torch_tensorrt/dynamo/compile.py
@@ -20,6 +20,7 @@
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
REQUIRE_FULL_COMPILATION,
TRUNCATE_LONG_AND_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
@@ -57,7 +58,7 @@ def compile(
dla_global_dram_size: int = 536870912,
calibrator: object = None,
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
require_full_compilation: bool = False,
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Optional[List[str]] = None,
torch_executed_modules: Optional[List[str]] = None,
@@ -80,8 +81,10 @@
"The Dynamo backend is an experimental feature, for which only the "
"following arguments are supported: "
"{enabled_precisions, debug, workspace_size, min_block_size, "
"torch_executed_ops, pass_through_build_failures, use_fast_partitioner, "
"enable_experimental_decompositions}"
"max_aux_streams, version_compatible, optimization_level, "
"torch_executed_ops, pass_through_build_failures, "
"use_fast_partitioner, enable_experimental_decompositions, "
"require_full_compilation}"
)

if not isinstance(inputs, collections.abc.Sequence):
@@ -126,6 +129,7 @@
"truncate_long_and_double": truncate_long_and_double,
"use_fast_partitioner": use_fast_partitioner,
"enable_experimental_decompositions": enable_experimental_decompositions,
"require_full_compilation": require_full_compilation,
}

settings = CompilationSettings(**compilation_options)
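
With the argument added to compilation_options, the flag flows from the user-facing compile call into CompilationSettings. A hedged usage sketch; the ir="dynamo" entry point and the toy model are assumptions for illustration, not part of this diff:

import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(2, 8).cuda()]

# If any lowered operator lacks a TRT converter, compilation should now fail
# rather than silently falling back to partial Torch execution.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.float32},
    require_full_compilation=True,
)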
55 changes: 47 additions & 8 deletions py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
@@ -12,7 +12,11 @@
_SplitterSettingBase,
)
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeSet
from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE
from torch_tensorrt.dynamo._defaults import (
DEBUG,
MIN_BLOCK_SIZE,
REQUIRE_FULL_COMPILATION,
)
from torch_tensorrt.dynamo.conversion.converter_registry import (
DYNAMO_CONVERTERS as CONVERTERS,
)
@@ -92,6 +96,7 @@ class TRTPartitioner(_SplitterBase): # type: ignore
allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
Generally useful for module-level exclusion ops which are intensive despite being single functions
min_block_size: Minimum number of computational operators per block
require_full_compilation: Require that all computational operators be run in TRT
Returns:
torch.fx.GraphModule
"""
@@ -104,6 +109,7 @@ def __init__(
Collection[str]
] = DEFAULT_SINGLE_NODE_PARTITIONS,
min_block_size: int = MIN_BLOCK_SIZE,
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
):
"""
Preprocesses graph before splitting:
@@ -142,6 +148,7 @@ def __init__(

self.num_trt_accelerated_subgraphs: Optional[int] = None
self.allowed_single_node_partition_ops = allowed_single_node_partition_ops
self.require_full_compilation = require_full_compilation

def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
"""
@@ -151,12 +158,16 @@ def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
result: List[Subgraph] = []
for subgraph in subgraphs:
if subgraph.is_acc:
if len(subgraph.nodes) >= self.settings.min_acc_module_size or (
self.allowed_single_node_partition_ops is not None
and any(
ConverterRegistry.qualified_name_or_str(node.target)
in self.allowed_single_node_partition_ops
for node in subgraph.nodes
if (
len(subgraph.nodes) >= self.settings.min_acc_module_size
or self.require_full_compilation
or (
self.allowed_single_node_partition_ops is not None
and any(
ConverterRegistry.qualified_name_or_str(node.target)
in self.allowed_single_node_partition_ops
for node in subgraph.nodes
)
)
):
result.append(subgraph)
@@ -185,6 +196,27 @@ def partition_graph(self) -> torch.fx.GraphModule:
# Delegate nodes based on operator coverage
subgraphs = self.put_nodes_into_subgraphs()

# A graph is fully supported if there is a single partition and all operators are supported/convertible
full_support = len([s for s in subgraphs if s.is_acc]) == 1 and not getattr(
self.operator_support, "unsupported_operators", True
)

if not full_support and self.require_full_compilation:
raise AssertionError(
"require_full_compilation=True was specified, but model is not fully supported"
)

if (
full_support
and self.require_full_compilation
and self.settings.min_acc_module_size != MIN_BLOCK_SIZE
):
logger.warning(
"Detected both require_full_compilation and min_block_size compilation "
"arguments were specified. Disregarding min_block_size argument for "
"fully supported model."
)

# Remove segments smaller than the block size (with exceptions)
subgraphs = self.remove_small_acc_subgraphs(subgraphs)

@@ -217,6 +249,7 @@ def partition(
verbose: bool = DEBUG,
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Collection[Target] = set(),
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
) -> torch.fx.GraphModule:
"""Partition an FX GraphModule with aten ops into TRT engines
Partitioning is based on converter operator support
@@ -226,6 +259,7 @@
verbose: Bool representing whether to print operator support
min_block_size: Minimum number of operators per TRT-Engine Block
torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
require_full_compilation: Require that all computational operators be run in TRT
Returns:
torch.fx.GraphModule
"""
@@ -236,7 +270,12 @@

# Construct
supported_ops = OpSupportTester(torch_executed_ops=torch_executed_ops)
partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
partitioner = TRTPartitioner(
gm,
supported_ops,
min_block_size=min_block_size,
require_full_compilation=require_full_compilation,
)

partitioned_graph = partitioner.partition_graph()

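
The gating added to partition_graph combines three rules: a graph counts as fully supported only if partitioning produced a single accelerated subgraph and no unsupported operators were recorded; requiring full compilation on anything less raises immediately; and for a fully supported graph a non-default min_block_size is ignored with a warning, with remove_small_acc_subgraphs keeping every accelerated subgraph. A condensed, hypothetical restatement of that rule (the helper name and default block size are illustrative, not from this diff):

from typing import Dict

ASSUMED_DEFAULT_MIN_BLOCK_SIZE = 5  # stand-in for MIN_BLOCK_SIZE in _defaults.py

def gate_full_compilation(
    num_acc_subgraphs: int,
    unsupported_ops: Dict[str, int],
    require_full_compilation: bool,
    min_block_size: int,
) -> bool:
    # A graph is fully supported when exactly one accelerated subgraph exists
    # and no operator was flagged as unsupported during support checking.
    full_support = num_acc_subgraphs == 1 and not unsupported_ops
    if require_full_compilation and not full_support:
        raise AssertionError(
            "require_full_compilation=True was specified, but model is not fully supported"
        )
    if require_full_compilation and full_support and min_block_size != ASSUMED_DEFAULT_MIN_BLOCK_SIZE:
        print("Ignoring min_block_size for a fully supported model")
    return full_support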
45 changes: 42 additions & 3 deletions py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -5,7 +5,11 @@
from torch.fx.graph_module import GraphModule
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.passes.operator_support import OperatorSupport, SupportDict
from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE
from torch_tensorrt.dynamo._defaults import (
DEBUG,
MIN_BLOCK_SIZE,
REQUIRE_FULL_COMPILATION,
)
from torch_tensorrt.dynamo.conversion.converter_registry import (
DYNAMO_CONVERTERS as CONVERTERS,
)
Expand All @@ -26,6 +30,7 @@ class TRTPartitioner(CapabilityBasedPartitioner): # type: ignore[misc]
allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
Generally useful for module-level exclusion ops which are intensive despite being single functions
min_block_size: Minimum number of computational operators per block
require_full_compilation: Require that all computational operators be run in TRT
Returns:
torch.fx.GraphModule
"""
@@ -40,6 +45,7 @@ def __init__(
Collection[str]
] = DEFAULT_SINGLE_NODE_PARTITIONS,
min_block_size: int = MIN_BLOCK_SIZE,
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
) -> None:
super().__init__(
graph_module,
@@ -50,12 +56,34 @@
)

self.min_block_size = min_block_size
self.require_full_compilation = require_full_compilation

def propose_partitions(self) -> List[Partition]:
# Propose partitions using the default, then refine the results
initial_proposed_partitions = super().propose_partitions()
partitions = dict(enumerate(initial_proposed_partitions))

# A graph is fully supported if there is a single partition and all operators are supported/convertible
full_support = len(partitions) == 1 and not getattr(
self.operator_support, "unsupported_operators", True
)

if not full_support and self.require_full_compilation:
raise AssertionError(
"require_full_compilation=True was specified, but model is not fully supported"
)

if (
full_support
and self.require_full_compilation
and self.min_block_size != MIN_BLOCK_SIZE
):
logger.warning(
"Detected both require_full_compilation and min_block_size compilation "
"arguments were specified. Disregarding min_block_size argument for "
"fully supported model."
)

# For each partition, determine whether or not the number of computational operators
# exceeds the threshold, and if not, remove that partition
partitions_to_remove = {}
@@ -81,7 +109,11 @@ def propose_partitions(self) -> List[Partition]:
):
compute_node_count += 1

if compute_node_count < self.min_block_size and not exempted_partition:
if (
compute_node_count < self.min_block_size
and not exempted_partition
and not (full_support and self.require_full_compilation)
):
partitions_to_remove[id] = compute_node_count

# Remove any nodes violating the criteria specified by the user
@@ -172,6 +204,7 @@ def partition(
verbose: bool = DEBUG,
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Optional[Set[str]] = None,
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
) -> torch.fx.GraphModule:
"""Partition an FX GraphModule with aten ops into TRT engines
Partitioning is based on converter operator support
@@ -181,6 +214,7 @@
verbose: Bool representing whether to print operator support
min_block_size: Minimum number of operators per TRT-Engine Block
torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
require_full_compilation: Whether to require that all operators be run in TRT
Returns:
torch.fx.GraphModule
"""
@@ -189,7 +223,12 @@
if torch_executed_ops is not None
else set()
)
partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
partitioner = TRTPartitioner(
gm,
supported_ops,
min_block_size=min_block_size,
require_full_compilation=require_full_compilation,
)

# Determine partitions based on user specifications and operator support
# Then, fuse partitions and display overview of supported/unsupported operators
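
The same flag reaches the global partitioner through partition(), which the test suite below calls as partitioning.global_partition. A minimal usage sketch mirroring the new test; the import path for partitioning is assumed from the test modules:

import torch
from torch_tensorrt.dynamo import partitioning

class OneOp(torch.nn.Module):
    def forward(self, x, y):
        return torch.ops.aten.add.Tensor(x, y)

gm = torch.fx.symbolic_trace(OneOp())

# A single supported op falls below the default min_block_size, but with
# require_full_compilation=True it is still placed into a TRT partition.
partitioned = partitioning.global_partition(gm, require_full_compilation=True)
print(len(list(partitioned.named_children())))  # expected: 1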
10 changes: 9 additions & 1 deletion py/torch_tensorrt/dynamo/utils.py
@@ -216,7 +216,15 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
"If this is incorrect, please specify an input device, via the device keyword."
)

logger.info(f"Compiling with Settings:\n{settings}")
# Ignore and warn about require_full_compilation flag
if settings.require_full_compilation:
logger.warning(
"Detected require_full_compilation=True for a torch.compile run. "
"This option has no effect in torch.compile."
)
settings.require_full_compilation = False

logger.info("Compilation Settings: %s\n", settings)

return settings

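
parse_dynamo_kwargs backs the torch.compile backend path, where graph breaks make full compilation impossible to guarantee, so the flag is warned about and reset to False there. A hedged sketch of a call that would hit this branch; the backend name "torch_tensorrt" and the options dict are assumptions about the torch.compile integration, not shown in this diff:

import torch
import torch_tensorrt  # noqa: F401  (assumed to register the TensorRT backend)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval().cuda()

# require_full_compilation is accepted here but has no effect: the code above
# logs a warning and forces it back to False for torch.compile runs.
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"require_full_compilation": True, "min_block_size": 1},
)
out = compiled(torch.randn(2, 8).cuda())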
24 changes: 24 additions & 0 deletions tests/py/dynamo/partitioning/test_fast_partitioning.py
@@ -31,6 +31,30 @@ def forward(self, x, y):
"Single operators should not be segmented",
)

def test_partition_fully_supported_one_op_require_full_compilation(self):
class FullySupportedOneOp(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, y):
return torch.ops.aten.add.Tensor(x, y)

fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
partitioned_graph = partitioning.fast_partition(
deepcopy(fx_graph), require_full_compilation=True
)
self.assertEquals(
len(
[
1
for submod in list(partitioned_graph.named_children())
if "_run_on_acc" in submod[0]
]
),
1,
"Single operators can be segmented if full compilation is required",
)

def test_partition_fully_supported_multi_op(self):
class FullySupportedMultiOp(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
18 changes: 18 additions & 0 deletions tests/py/dynamo/partitioning/test_global_partitioning.py
@@ -25,6 +25,24 @@ def forward(self, x, y):
"Single operators should not be segmented",
)

def test_partition_fully_supported_one_op_require_full_compilation(self):
class FullySupportedOneOp(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, y):
return torch.ops.aten.add.Tensor(x, y)

fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
partitioned_graph = partitioning.global_partition(
deepcopy(fx_graph), require_full_compilation=True
)
self.assertEquals(
len(list(partitioned_graph.named_children())),
1,
"Single operators can be segmented if full compilation is required",
)

def test_partition_fully_supported_multi_op(self):
class FullySupportedMultiOp(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None: