
Commit fa320e2

feat: Add ATen lowering pass system
- Add documentation, testing, and lowering pass management systems for ATen lowering passes
1 parent e0a7525 commit fa320e2

File tree: 11 files changed (+306, −50 lines)

docsrc/index.rst (+1)

@@ -73,6 +73,7 @@ Tutorials
    tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
    tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
    tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
+   tutorials/_rendered_examples/dynamo/dynamo_aten_lowering_passes

 Python API Documenation
 ------------------------

examples/dynamo/README.rst (+1)

@@ -9,3 +9,4 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
 * :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
 * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
+* :ref:`dynamo_aten_lowering_passes`: Custom modifications of a graph of ATen operators via lowering passes
New file (+97 lines; path not shown in this view): the dynamo_aten_lowering_passes example script

"""
.. _dynamo_aten_lowering_passes:

Dynamo ATen Lowering Passes
======================================================

This interactive script is intended as an overview of the process by which ATen lowering passes are written and used."""

# %%
# 1. Lowering Pass Function
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# An ATen lowering pass function in Torch-TRT must satisfy two requirements:
# - The function must take as input a single `torch.fx.GraphModule` and return the lowered
# `torch.fx.GraphModule`
# - The function must leave the graph in a valid and invoke-able state, including performing any
# necessary linting and recompilation
#
# See below for an example of a lowering pass which repairs graphs that have inputs which are
# also outputs, a disallowed configuration for TRT Engines.

# %%
import logging

import torch

logger = logging.getLogger(__name__)


def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Repair scenarios where inputs are also outputs of the graph

    TRT does not allow such cases, so we insert a clone (identity) layer
    """
    modified_graph = False

    # Extract graph placeholder Tensors
    placeholders = [
        node
        for node in gm.graph.nodes
        if (
            node.op == "placeholder"
            and isinstance(node.type, type)
            and issubclass(node.type, torch.Tensor)
        )
    ]

    for placeholder in placeholders:
        # If any placeholder has any users which are direct graph outputs
        if len(placeholder.users) >= 1 and any(
            user.op == "output" for user in placeholder.users
        ):
            modified_graph = True

            # Get direct graph outputs which are direct uses of placeholders
            direct_outputs = [user for user in placeholder.users if user.op == "output"]

            # Insert clone node for placeholder to ensure placeholder is not a direct output
            with gm.graph.inserting_after(placeholder):
                cloned_placeholder = gm.graph.call_function(
                    torch.ops.aten.clone.default,
                    args=(placeholder,),
                )

            # Replace placeholder as output with cloned version
            for output in direct_outputs:
                output.replace_input_with(placeholder, cloned_placeholder)

    # If the graph was modified, clean up the graph and ensure it is up-to-date
    if modified_graph:
        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
        logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")

    return gm


# %%
# 2. Lowering Pass Registration
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# To add a lowering pass, use the convenience function `add_lowering_pass` in the module
# `torch_tensorrt.dynamo.lowering.passes`. See below for an example:

# %%
from torch_tensorrt.dynamo.lowering.passes import add_lowering_pass

add_lowering_pass(repair_input_as_output)

# %%
# 3. Apply Available Lowering Passes
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# To apply all lowering passes to a graph, the convenience function `apply_lowering_passes` in the module
# `torch_tensorrt.dynamo.lowering.passes` can be used. This function is automatically invoked in the Torch-TRT Dynamo
# paths. Additionally, the graph after each modifying pass is logged in the debug logs for Torch-TRT runs.
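
Note: step 3 of the example ends without a code block, so a minimal end-to-end sketch may help show how the registered pass takes effect during compilation. The `InputAsOutput` module, tensor shapes, and compile options below are illustrative assumptions, not part of this commit; the only requirement is that `add_lowering_pass(repair_input_as_output)` has already run, as above.

# Minimal usage sketch (assumes a CUDA-capable GPU with TensorRT installed)
import torch


class InputAsOutput(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # Returning the input directly is exactly the pattern repair_input_as_output fixes
        return x, x + 1


model = InputAsOutput().eval().cuda()
inputs = [torch.randn(4, 4).cuda()]

# apply_lowering_passes (and therefore the registered pass) is invoked automatically
# inside the Torch-TRT Dynamo backend; the graph after each modifying pass appears
# in the debug logs.
optimized = torch.compile(model, backend="torch_tensorrt", options={"debug": True})
optimized(*inputs)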

py/torch_tensorrt/dynamo/backend/backends.py (+4, −45)

@@ -7,27 +7,15 @@
 import torch
 import torch._dynamo as td
 import torch.utils._pytree as pytree
-import torch_tensorrt
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import _aot_export_function
 from torch._ops import OpOverload
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
-from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
 
-from packaging import version
-
-# Modify import location of utilities based on Torch version
-if version.parse(torch_tensorrt.sanitized_torch_version()) <= version.parse("2.1.0"):
-    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
-else:
-    from torch._inductor.constant_folding import (
-        ConstantFolder,
-        replace_node_with_constant,
-    )
-
 logger = logging.getLogger(__name__)
 
@@ -86,7 +74,7 @@ def _pretraced_backend(
         fake_mode, "allow_non_fake_inputs", True
     ), fake_mode:
         # Invoke AOTAutograd to translate operators to aten
-        graph_module = aot_export_for_compile(
+        gm = aot_export_for_compile(
             gm,
             sample_inputs,
             decompositions=get_decompositions(
@@ -96,10 +84,10 @@ def _pretraced_backend(
 
         logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-        constant_fold(graph_module)
+        gm = apply_lowering_passes(gm)
 
         trt_compiled = compile_module(
-            graph_module,
+            gm,
             sample_inputs,
             settings=settings,
         )
@@ -123,35 +111,6 @@ def _pretraced_backend(
         raise
 
 
-@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
-def constant_fold(gm: torch.fx.GraphModule) -> Any:
-    """Adapted from:
-    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
-
-    Folds constants in the graph module, not skipping constructors
-
-    Modifies the graph in-place and replaces node with constants
-    """
-    cf = ConstantFolder(gm, skip_constructors=False)
-    cf.run()
-
-    for node, constant in cf.node_replacements.items():
-        replace_node_with_constant(gm, node, constant)
-
-    erased_params = []
-    for node in gm.graph.nodes:
-        if node.op == "get_attr" and len(node.users) == 0:
-            delattr(gm, node.target)
-            erased_params.append(node)
-
-    for node in erased_params:
-        gm.graph.erase_node(node)
-
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
-
-
 def aot_export_for_compile(
     func: torch.fx.GraphModule,
     args: Sequence[torch.Tensor],

py/torch_tensorrt/dynamo/lowering/__init__.py (+1)

@@ -2,4 +2,5 @@
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401
+from .passes import add_lowering_pass, apply_lowering_passes
 from .substitutions import *  # noqa: F401
New file (+27 lines; path not shown in this view): the lowering pass registry

from typing import Callable

import torch
from torch.fx.passes.pass_manager import PassManager

from .constant_folding import constant_fold
from .repair_input_as_output import repair_input_as_output

ATEN_LOWERING_PASSES = PassManager.build_from_passlist(
    [
        constant_fold,
        repair_input_as_output,
    ]
)


def add_lowering_pass(
    lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
) -> None:
    """Adds a lowering pass to the registry"""
    ATEN_LOWERING_PASSES.add_pass(lowering_pass)
    return


def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Applies the lowering passes to a graph module, returns the modified GraphModule"""
    return ATEN_LOWERING_PASSES(gm)
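
For illustration, a short sketch of how this registry behaves. It assumes `PassManager.passes` exposes the ordered pass list (as in `torch.fx.passes.pass_manager`); the `no_op_pass` below is a hypothetical placeholder, not part of this commit.

import torch

from torch_tensorrt.dynamo.lowering import add_lowering_pass, apply_lowering_passes
from torch_tensorrt.dynamo.lowering.passes import ATEN_LOWERING_PASSES


def no_op_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Hypothetical pass: returns the graph unchanged."""
    return gm


# New passes are appended, so they run after constant_fold and repair_input_as_output
add_lowering_pass(no_op_pass)
print([p.__name__ for p in ATEN_LOWERING_PASSES.passes])

# apply_lowering_passes(gm) then calls each registered pass, in order, on a GraphModule

Using a PassManager keeps the pass ordering explicit and lets user-registered passes compose with the built-in constant folding and input/output repair passes.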
New file (+56 lines; path not shown in this view): the constant folding pass

import logging

import torch
import torch_tensorrt

from packaging import version

# Modify import location of utilities based on Torch version
if version.parse(torch_tensorrt.sanitized_torch_version()) <= version.parse("2.1.0"):
    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
else:
    from torch._inductor.constant_folding import (
        ConstantFolder,
        replace_node_with_constant,
    )

logger = logging.getLogger(__name__)


@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Adapted from:
    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197

    Folds constants in the graph module, not skipping constructors

    Modifies the graph in-place and replaces node with constants
    """
    cf = ConstantFolder(gm, skip_constructors=False)
    cf.run()

    for node, constant in cf.node_replacements.items():
        replace_node_with_constant(gm, node, constant)

    erased_params = []
    for node in gm.graph.nodes:
        # If get_attr node has no users, mark it for deletion
        if node.op == "get_attr" and len(node.users) == 0:
            # If the node's parameter is not a parameter of any other node, remove it
            if not any(
                other.target == node.target for other in gm.graph.nodes if other != node
            ):
                delattr(gm, node.target)
            erased_params.append(node)

    # Remove unused nodes from the graph
    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()

    logger.debug(f"Graph after constant folding:\n{gm.graph}")

    return gm
New file (+53 lines; path not shown in this view): the input-as-output repair pass

import logging

import torch

logger = logging.getLogger(__name__)


def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Repair scenarios where inputs are also outputs of the graph

    TRT does not allow such cases, so we insert a clone (identity) layer
    """
    modified_graph = False

    # Extract graph placeholder Tensors
    placeholders = [
        node
        for node in gm.graph.nodes
        if (
            node.op == "placeholder"
            and isinstance(node.type, type)
            and issubclass(node.type, torch.Tensor)
        )
    ]

    for placeholder in placeholders:
        # If any placeholder has any users which are direct graph outputs
        if len(placeholder.users) >= 1 and any(
            user.op == "output" for user in placeholder.users
        ):
            modified_graph = True

            # Get direct graph outputs which are direct uses of placeholders
            direct_outputs = [user for user in placeholder.users if user.op == "output"]

            # Insert clone node for placeholder to ensure placeholder is not a direct output
            with gm.graph.inserting_after(placeholder):
                cloned_placeholder = gm.graph.call_function(
                    torch.ops.aten.clone.default,
                    args=(placeholder,),
                )

            # Replace placeholder as output with cloned version
            for output in direct_outputs:
                output.replace_input_with(placeholder, cloned_placeholder)

    if modified_graph:
        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
        logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")

    return gm
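
The commit message mentions a testing system, but the test files are not included in this view. As a minimal sketch, the pass can be exercised directly on a symbolically traced graph; the `InputAsOutput` module and tensor shapes here are illustrative, and the import relies on the re-export in the passes package `__init__` shown above.

import torch

from torch_tensorrt.dynamo.lowering.passes import repair_input_as_output


class InputAsOutput(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # The first return value is the input itself, which TRT engines disallow
        return x, x + 1


gm = torch.fx.symbolic_trace(InputAsOutput())
gm = repair_input_as_output(gm)

# After the pass, no graph output is a bare placeholder (a clone was inserted)
output_node = [n for n in gm.graph.nodes if n.op == "output"][0]
assert all(arg.op != "placeholder" for arg in output_node.args[0])

# Behavior is unchanged: the first output still equals the input
x = torch.ones(2, 3)
assert torch.equal(gm(x)[0], x)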

setup.py (+2)

@@ -392,6 +392,7 @@ def run(self):
     "torch_tensorrt.dynamo.conversion.impl.unary",
     "torch_tensorrt.dynamo.lowering",
     "torch_tensorrt.dynamo.lowering.substitutions",
+    "torch_tensorrt.dynamo.lowering.passes",
     "torch_tensorrt.dynamo.partitioning",
     "torch_tensorrt.dynamo.runtime",
     "torch_tensorrt.dynamo.tools",
@@ -419,6 +420,7 @@ def run(self):
     "torch_tensorrt.dynamo.conversion.impl.unary": "py/torch_tensorrt/dynamo/conversion/impl/unary",
     "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering",
     "torch_tensorrt.dynamo.lowering.substitutions": "py/torch_tensorrt/dynamo/lowering/substitutions",
+    "torch_tensorrt.dynamo.lowering.passes": "py/torch_tensorrt/dynamo/lowering/passes",
     "torch_tensorrt.dynamo.partitioning": "py/torch_tensorrt/dynamo/partitioning",
     "torch_tensorrt.dynamo.runtime": "py/torch_tensorrt/dynamo/runtime",
     "torch_tensorrt.dynamo.tools": "py/torch_tensorrt/dynamo/tools",
