feat: Add support for general-purpose function acceleration in Dynamo [6 / x] #1980

Merged
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -9,7 +9,7 @@
get_decompositions,
)
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
pre_aot_module_replacement,
pre_aot_substitutions,
)
from torch_tensorrt.dynamo.backend.lowering._partition import (
partition,
@@ -49,7 +49,7 @@ def aot_torch_tensorrt_aten_backend(
)

# Perform Pre-AOT Lowering for Module-Level Replacement
gm = pre_aot_module_replacement(gm)
gm = pre_aot_substitutions(gm)

# Invoke AOTAutograd to translate operators to aten
return aot_module_simplified(
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/backend/lowering/__init__.py
@@ -2,8 +2,8 @@
get_decompositions,
)
from ._pre_aot_lowering import (
MODULE_SUBSTITUTION_REGISTRY,
module_substitution,
SUBSTITUTION_REGISTRY,
register_substitution,
)
from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
from .module_substitutions import *
from .substitutions import *
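
For context, a minimal sketch of how the renamed exports behave once the bundled substitutions are imported; the registry keys follow from the maxpool1d and einsum registrations in this PR, and the snippet is illustrative rather than part of the change:

import torch
from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY

# Module-based substitution: keyed on the nn.Module subclass
print(torch.nn.MaxPool1d in SUBSTITUTION_REGISTRY)  # True

# Function-based substitution (new in this PR): keyed on the callable itself
print(torch.einsum in SUBSTITUTION_REGISTRY)  # True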
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/backend/lowering/_partition.py
@@ -4,7 +4,7 @@
import torch

from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY
from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name
@@ -16,8 +16,8 @@
logger = logging.getLogger(__name__)

DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
_get_qualified_name(module.new_operator)
for module in MODULE_SUBSTITUTION_REGISTRY.values()
_get_qualified_name(to_replace.new_operator)
for to_replace in SUBSTITUTION_REGISTRY.values()
)


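
As a sketch of the downstream effect (assuming both bundled substitutions have been registered on import), the set now collects the qualified name of every replacement operator so that each substituted op is placed in its own single-node partition:

from torch_tensorrt.dynamo.backend.lowering import DEFAULT_SINGLE_NODE_PARTITIONS

# Qualified names of the replacement operators held in SUBSTITUTION_REGISTRY,
# e.g. the tensorrt maxpool1d and einsum custom ops registered in this PR
print(DEFAULT_SINGLE_NODE_PARTITIONS)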
121 changes: 67 additions & 54 deletions py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type
from typing import Any, Callable, Dict, Optional, Type, Union
import torch
import logging

@@ -8,59 +8,62 @@


@dataclass(frozen=True)
class ModuleReplacement:
class Substitution:
"""Class to store key functionality for module replacement"""

# torch.ops.___ name for replacement function for module
new_operator: torch._ops.OpOverload

# Function taking a containing graph, a submodule, and a 'call_module' node and returning
# a replacement node, with type 'call_function', or raising an Error if incompatibility is detected
# Function taking a containing graph, a node, and optionally a submodule (if replacing a module)
# and returning a replacement node, with type 'call_function', or raising an Error if
# incompatibility is detected
# Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
subgraph_insertion_fn: Callable[
[torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node
[torch.fx.GraphModule, torch.fx.Node, Optional[torch.nn.Module]], torch.fx.Node
]


# Dictionary mapping module to ModuleReplacement instance
MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = dict()
# Dictionary mapping module to Substitution instance
SUBSTITUTION_REGISTRY: Dict[
Union[Type[torch.nn.Module], Callable], Substitution
] = dict()


def module_substitution(
module_to_replace: Type[torch.nn.Module],
def register_substitution(
module_or_function_to_replace: Union[Type[torch.nn.Module], Callable],
new_operator: torch._ops.OpOverload,
enabled: bool = True,
) -> Callable[[Any], Any]:
"""Decorator to register subgraph insertion functions

Args:
module_to_replace: nn.Module to replace
module_or_function_to_replace: nn.Module or node target Callable to replace
new_operator: Custom torch operator to replace with
enabled: Whether the substitution is enabled or disabled
Returns:
torch.fx.GraphModule
"""

def register_substitution(subgraph_insertion_fn):
def enable_substitution(subgraph_insertion_fn):
"""Function for use if substitution is enabled"""
module_replacement = ModuleReplacement(
replacement = Substitution(
new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
)
MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement
SUBSTITUTION_REGISTRY[module_or_function_to_replace] = replacement
return subgraph_insertion_fn

def disable_substitution(subgraph_insertion_fn):
"""Function for use if substitution is disabled"""
return subgraph_insertion_fn

return register_substitution if enabled else disable_substitution
return enable_substitution if enabled else disable_substitution


def pre_aot_module_replacement(gm: torch.fx.GraphModule):
"""Perform module-level graph replacement prior to AOT tracing
def pre_aot_substitutions(gm: torch.fx.GraphModule):
"""Perform graph substitutions prior to AOT tracing

Args:
gm: FX GraphModule to perform module replacement on
gm: FX GraphModule to perform substitution on
Returns:
torch.fx.GraphModule

@@ -73,48 +76,58 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):

# Iterate over graph nodes, extracting module calls, to check for interceptions
for n in gm.graph.nodes:
exists_in_registry = False
to_replace = None

if n.op == "call_module":
# Extract submodule from graph
# Extract submodule from graph, validate in registry
submodule = gm.get_submodule(n.target)

# If submodule is a member of the substitution registry, replace it
if type(submodule) in MODULE_SUBSTITUTION_REGISTRY:

try:
replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)]
op, insertion_fn = (
replacement.new_operator,
replacement.subgraph_insertion_fn,
)
logger.debug(
f"Replacing module of type {type(submodule)} with {op}"
to_replace = type(submodule)
exists_in_registry = to_replace in SUBSTITUTION_REGISTRY
elif n.op == "call_function":
# Extract function from graph, validate in registry
to_replace = n.target
exists_in_registry = n.target in SUBSTITUTION_REGISTRY

# If submodule/function is a member of the substitution registry, replace it
if exists_in_registry:
try:
replacement = SUBSTITUTION_REGISTRY[to_replace]
op, insertion_fn = (
replacement.new_operator,
replacement.subgraph_insertion_fn,
)
logger.debug(f"Replacing node of type {to_replace} with {op}")

# Insert new node prior to older node
with gm.graph.inserting_before(n):
new_node = insertion_fn(
gm, n, submodule if n.op == "call_module" else None
)

# Insert new node prior to older node
with gm.graph.inserting_before(n):
new_node = insertion_fn(gm, submodule, n)

# If submodule is not a native torch.nn module, it must be manually excluded
# from Dynamo tracing
if not type(submodule).__module__.startswith("torch.nn"):
torch._dynamo.allowed_functions._allowed_function_ids.add(
id(type(submodule))
)

# Replace all original node uses and clean up graph
n.replace_all_uses_with(new_node)
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()

# A module replacement can fail in the event that the specific instance of the submodule cannot
# be replaced
except Exception:
logger.debug(
f"Encountered error while replacing {type(submodule)}",
exc_info=True,
# If submodule is not a native torch.nn module, it must be manually excluded
# from Dynamo tracing
if n.op == "call_module" and not type(submodule).__module__.startswith(
"torch.nn"
):
torch._dynamo.allowed_functions._allowed_function_ids.add(
id(to_replace)
)
continue

# Replace all original node uses and clean up graph
n.replace_all_uses_with(new_node)
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()

# A replacement can fail in the event that the specific instance of the submodule/function
# cannot be replaced
except Exception:
logger.debug(
f"Encountered error while replacing {to_replace}",
exc_info=True,
)
continue

# Perform cleanup and recompilation before returning module
gm.graph.eliminate_dead_code()
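
For reference, a minimal end-to-end sketch of the new function-substitution path, assuming the einsum substitution shipped in this PR is registered when the lowering package is imported:

import torch
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
    pre_aot_substitutions,
)


class EinsumModule(torch.nn.Module):
    def forward(self, x, y):
        # Traced as a call_function node targeting torch.einsum,
        # which is now a key in SUBSTITUTION_REGISTRY
        return torch.einsum("ij,ji->ij", x, y)


gm = torch.fx.symbolic_trace(EinsumModule())
gm = pre_aot_substitutions(gm)

# The printed graph should now call torch.ops.tensorrt.einsum in place of
# the original torch.einsum node
print(gm.graph)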
@@ -1 +1,2 @@
from .maxpool1d import *
from .einsum import *
80 changes: 80 additions & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py
@@ -0,0 +1,80 @@
from typing import Dict, Tuple
import torch
from torch._custom_op.impl import custom_op
from torch.fx.node import Argument, Target

from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from torch_tensorrt.dynamo.backend.lowering import register_substitution


@custom_op(
qualname="tensorrt::einsum",
manual_schema="(str equation, Tensor[] tensors) -> Tensor",
)
def einsum(equation, tensors):
# Defines operator schema, name, namespace, and function header
...


@einsum.impl("cpu")
@einsum.impl("cuda")
@einsum.impl_abstract()
def einsum_generic(
*args,
**kwargs,
):
# Defines a converter implementation for AOT Autograd to use for shape analysis/propagation
return torch.einsum(
*args,
**kwargs,
)


@tensorrt_converter(torch.ops.tensorrt.einsum.default)
def aten_ops_einsum(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> TRTTensor:
# Defines converter replacing the default operator for this function
for input_trt in args[1]:
if not isinstance(input_trt, TRTTensor):
raise RuntimeError(f"Einsum received non-TRTTensor input: {input_trt}")

einsum_layer = network.add_einsum(inputs=args[1], equation=args[0])

set_layer_name(einsum_layer, target, name)
return einsum_layer.get_output(0)


@register_substitution(torch.einsum, torch.ops.tensorrt.einsum)
def einsum_insertion_fn(
gm: torch.fx.GraphModule,
node: torch.fx.Node,
_unused: None = None,
) -> torch.fx.Node:
equation = node.args[0]

# Ensure inputs is a list of (Tensor) arguments
if isinstance(node.args[1], (tuple, list)):
inputs = node.args[1]
else:
inputs = node.args[1:]

assert (
1 <= len(inputs) <= 2
), f"TRT Einsum currently only supports 1 or 2 Tensors, got {len(inputs)} Tensors"

# Insert the replacement node, passing the equation and list of Tensor inputs
new_node = gm.graph.call_function(
torch.ops.tensorrt.einsum,
args=(equation, inputs),
kwargs=node.kwargs,
)

return new_node
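
Given the manual_schema above, the replacement node packs the equation string and a list of Tensors into the custom op; a quick sketch of the equivalent eager call, assuming this module has been imported so the op is registered, and using CUDA tensors since the converter targets TensorRT:

import torch

x = torch.rand(16, 16).cuda()
y = torch.rand(16, 16).cuda()

# Matches the schema "(str equation, Tensor[] tensors) -> Tensor"
out = torch.ops.tensorrt.einsum("ij,ji->ij", [x, y])

# The generic implementation above simply dispatches to torch.einsum
assert torch.allclose(out, torch.einsum("ij,ji->ij", x, y))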
@@ -7,7 +7,7 @@
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from torch_tensorrt.dynamo.backend.lowering import module_substitution
from torch_tensorrt.dynamo.backend.lowering import register_substitution


# This file serves as an example and a tutorial for excluding custom modules from
@@ -71,9 +71,11 @@ def maxpool1d_generic(
# "bias": bias,
# ...
#
@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
gm: torch.fx.GraphModule,
node: torch.fx.Node,
submodule: torch.nn.Module,
) -> torch.fx.Node:
# Defines insertion function for new node
new_node = gm.graph.call_function(
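
The module path still works with the reordered arguments; a short illustrative sketch (again assuming the substitution is registered on import) showing a call_module node being intercepted, with the submodule forwarded as the third argument:

import torch
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
    pre_aot_substitutions,
)


class Pool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.MaxPool1d(kernel_size=3)

    def forward(self, x):
        # Traced as a call_module node; type(self.pool) is in SUBSTITUTION_REGISTRY
        return self.pool(x)


gm = torch.fx.symbolic_trace(Pool())
gm = pre_aot_substitutions(gm)

# The insertion function receives (gm, node, submodule) in the new order and
# the rewritten graph calls torch.ops.tensorrt.maxpool1d
print(gm.graph)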
46 changes: 46 additions & 0 deletions py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py
@@ -51,5 +51,51 @@ def forward(self, x):
)


class TestEinsum(TestCase):
def test_pre_aot_lowering_einsum(self):
class Einsum(torch.nn.Module):
def forward(self, x, y):
return torch.einsum("ij,ji->ij", x, y)

# Operations expected to be included in the traced graph after decompositions
expected_ops = {torch.ops.tensorrt.einsum.default}

inputs = [
torch.rand(
16,
16,
).cuda(),
torch.rand(
16,
16,
).cuda(),
]

fx_graph = torch.fx.symbolic_trace(Einsum())
_, expected_ops_unseen = lower_graph_testing(
fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
)

self.assertEquals(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = compile(
fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
self.assertAlmostEqual(
max_diff, 0, f"Einsum TRT outputs don't match with the original model."
)


if __name__ == "__main__":
run_tests()