feat: Add ATen lowering pass system #2280

Merged: 3 commits, Sep 22, 2023
109 changes: 109 additions & 0 deletions docsrc/contributors/writing_dynamo_aten_lowering_passes.rst
@@ -0,0 +1,109 @@
.. _writing_dynamo_aten_lowering_passes:

Writing Dynamo ATen Lowering Passes
===================================

Basics of a Lowering Pass
-------------------------

ATen lowering passes are Python functions which take as input a graph of ATen operators in the form of a `torch.fx.GraphModule`, apply a desired modification (such as operator coalescing/fusion, operator replacement, subgraph rewriting, or custom operator insertion), and return the modified graph to the caller. These passes generally modify the graph in place and return the same object that was passed in.
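
In its simplest form, a pass is an ordinary function over a `torch.fx.GraphModule`. The sketch below is illustrative only (`remove_detach_nodes` is not a pass shipped by this PR); it removes `aten.detach` nodes, which are inert at inference time, and then restores the graph to a valid state per the requirements in the next section:

.. code-block:: python

    import torch


    def remove_detach_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Illustrative pass: reroute and erase aten.detach nodes"""
        # Collect first, then erase, to avoid mutating the node list mid-iteration
        detach_nodes = [
            node
            for node in gm.graph.nodes
            if node.op == "call_function" and node.target == torch.ops.aten.detach.default
        ]
        for node in detach_nodes:
            # Detach is a no-op for inference; reroute users to its input
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)

        # Leave the graph valid and invocable
        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
        return gm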

Lowering Pass Requirements
--------------------------

An ATen lowering pass function in Torch-TensorRT must satisfy two requirements:
- The function must take as input a single `torch.fx.GraphModule` and return the lowered `torch.fx.GraphModule`
- The function must leave the graph in a valid and invocable state, including performing any necessary linting and recompilation

See `Graph Manipulations <https://pytorch.org/docs/stable/fx.html#graph-manipulation>`_ in the FX documentation for information on modifying graphs. Below is an example of a lowering pass which repairs graphs whose inputs are also direct outputs, a configuration disallowed for TRT engines.

Example Lowering Pass
---------------------

.. code-block:: python

    import logging

    import torch

    logger = logging.getLogger(__name__)


    def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Repair scenarios where inputs are also outputs of the graph

        TRT does not allow such cases, so we insert a clone (identity) layer
        """
        modified_graph = False

        # Extract graph placeholder Tensors
        placeholders = [
            node
            for node in gm.graph.nodes
            if (
                node.op == "placeholder"
                and isinstance(node.type, type)
                and issubclass(node.type, torch.Tensor)
            )
        ]

        for placeholder in placeholders:
            # If any placeholder has any users which are direct graph outputs
            if len(placeholder.users) >= 1 and any(
                user.op == "output" for user in placeholder.users
            ):
                modified_graph = True

                # Get direct graph outputs which are direct uses of placeholders
                direct_outputs = [user for user in placeholder.users if user.op == "output"]

                # Insert clone node for placeholder to ensure
                # placeholder is not a direct output
                with gm.graph.inserting_after(placeholder):
                    cloned_placeholder = gm.graph.call_function(
                        torch.ops.aten.clone.default,
                        args=(placeholder,),
                    )

                # Replace placeholder as output with cloned version
                for output in direct_outputs:
                    output.replace_input_with(placeholder, cloned_placeholder)

        # If the graph was modified, clean up the graph and ensure it is up-to-date
        if modified_graph:
            gm.graph.eliminate_dead_code()
            gm.graph.lint()
            gm.recompile()
            logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")

        return gm


Registering Lowering Passes
---------------------------

Lowering passes are currently registered in `py/torch_tensorrt/dynamo/lowering/passes/__init__.py`, using the `torch.fx.passes.pass_manager.PassManager` utility to assemble the list of passes in the desired order. New passes added directly to that list will be applied to graphs in the Torch-TensorRT `torch.compile` backend. For convenience, an ATen lowering pass registration decorator is also provided; it can be invoked either directly or with the optional `index` keyword argument, which controls where in the pass list the new pass is inserted.

For instance, to insert the pass at the default location (end of the list), the following code can be used:

.. code-block:: python

    @_aten_lowering_pass
    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...

Alternatively, to insert the pass at a custom index in the pass list (such as the front), the following code can be used:

.. code-block:: python

    @_aten_lowering_pass(index=0)
    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...

Utilities are also provided in `torch_tensorrt.dynamo.lowering.passes` for displaying the currently-available lowering pass list, applying those passes to an arbitrary `torch.fx.GraphModule`, and removing the lowering pass at a specific index:

.. code-block:: python

    # Print all lowering passes in the list
    print(dump_lowering_passes())

    # Apply lowering passes to a GraphModule
    apply_lowering_passes(graph_module)

    # Remove the lowering pass at index 1
    _remove_lowering_pass(index=1)

**Note:** The above APIs are subject to change, as the lowering pass system evolves.
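
Putting the pieces together, a rough end-to-end sketch is shown below. The model, input shapes, and pass body are illustrative only, and the `"torch_tensorrt"` backend string assumes the Torch-TensorRT dynamo backend is registered on import:

.. code-block:: python

    import torch
    import torch_tensorrt  # noqa: F401  # registers the Torch-TensorRT backend

    @_aten_lowering_pass(index=0)
    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...  # graph modification, linting, and recompilation as described above
        return gm

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).cuda().eval()
    inputs = [torch.randn(4, 8).cuda()]

    # The registered pass list (including my_custom_pass) runs on the ATen
    # graph inside the backend, before TRT conversion
    optimized = torch.compile(model, backend="torch_tensorrt")
    optimized(*inputs)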
2 changes: 2 additions & 0 deletions docsrc/index.rst
@@ -128,6 +128,7 @@ Contributor Documentation
 --------------------------------
 * :ref:`system_overview`
 * :ref:`writing_converters`
+* :ref:`writing_dynamo_aten_lowering_passes`
 * :ref:`useful_links`
 
 .. toctree::
@@ -137,6 +138,7 @@ Contributor Documentation
 
    contributors/system_overview
    contributors/writing_converters
+   contributors/writing_dynamo_aten_lowering_passes
    contributors/useful_links
 
 Indices
5 changes: 2 additions & 3 deletions py/torch_tensorrt/dynamo/aten_tracer.py
@@ -6,8 +6,7 @@
 
 import torch
 from torch._export import export
-from torch_tensorrt.dynamo.backend.backends import constant_fold
-from torch_tensorrt.dynamo.lowering import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.utils import set_log_level
 
 logger = logging.getLogger(__name__)
@@ -29,6 +28,6 @@ def trace(
         "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
     ):
         graph_module = export(model, tuple(inputs)).module()
-        constant_fold(graph_module)
+        graph_module = apply_lowering_passes(graph_module)
         logger.debug("Post export graph: " + str(graph_module.graph))
         return graph_module
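
With this change, `trace` funnels the exported ATen graph through the full registered lowering pass list rather than calling `constant_fold` directly. A hypothetical invocation (model and inputs are illustrative):

    import torch
    from torch_tensorrt.dynamo.aten_tracer import trace

    model = torch.nn.Linear(4, 4).cuda().eval()
    inputs = (torch.randn(2, 4).cuda(),)

    # Exported ATen graph, with all registered lowering passes applied
    gm = trace(model, inputs)
    print(gm.graph)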
49 changes: 4 additions & 45 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -10,24 +10,12 @@
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import _aot_export_function
 from torch._ops import OpOverload
-from torch_tensorrt._utils import sanitized_torch_version
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
-from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs, set_log_level
 
-from packaging import version
-
-# Modify import location of utilities based on Torch version
-if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
-    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
-else:
-    from torch._inductor.constant_folding import (
-        ConstantFolder,
-        replace_node_with_constant,
-    )
-
 logger = logging.getLogger(__name__)
 
 
@@ -84,7 +72,7 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             # Invoke AOTAutograd to translate operators to aten
-            graph_module = aot_export_for_compile(
+            gm = aot_export_for_compile(
                 gm,
                 sample_inputs,
                 decompositions=get_decompositions(
@@ -94,10 +82,10 @@
 
             logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-            constant_fold(graph_module)
+            gm = apply_lowering_passes(gm)
 
             trt_compiled = compile_module(
-                graph_module,
+                gm,
                 sample_inputs,
                 settings=settings,
             )
@@ -121,35 +109,6 @@
         raise
 
 
-@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
-def constant_fold(gm: torch.fx.GraphModule) -> Any:
-    """Adapted from:
-    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
-
-    Folds constants in the graph module, not skipping constructors
-
-    Modifies the graph in-place and replaces node with constants
-    """
-    cf = ConstantFolder(gm, skip_constructors=False)
-    cf.run()
-
-    for node, constant in cf.node_replacements.items():
-        replace_node_with_constant(gm, node, constant)
-
-    erased_params = []
-    for node in gm.graph.nodes:
-        if node.op == "get_attr" and len(node.users) == 0:
-            delattr(gm, node.target)
-            erased_params.append(node)
-
-    for node in erased_params:
-        gm.graph.erase_node(node)
-
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
-
-
 def aot_export_for_compile(
     func: torch.fx.GraphModule,
     args: Sequence[torch.Tensor],
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -2,4 +2,5 @@
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401
+from .passes import apply_lowering_passes
 from .substitutions import *  # noqa: F401
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -0,0 +1 @@
from ._aten_lowering_pass import *
76 changes: 76 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -0,0 +1,76 @@
import logging
from typing import Callable, Optional

import torch

from .constant_folding import constant_fold
from .pass_manager import DynamoPassManager
from .repair_input_as_output import repair_input_as_output

ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
    [
        constant_fold,
        repair_input_as_output,
    ]
)

logger = logging.getLogger(__name__)


LoweringPassSignature = Callable[[torch.fx.GraphModule], torch.fx.GraphModule]


def _aten_lowering_pass(
    *args: LoweringPassSignature,
    index: Optional[int] = None,
) -> LoweringPassSignature:
    """Adds a lowering pass to the registry, at a specified index if desired

    If no index is specified, the lowering pass is inserted at the end of the list
    """

    def add_lowering_pass(
        lowering_pass: LoweringPassSignature,
    ) -> LoweringPassSignature:
        ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
        logger.debug(
            f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
        )
        return lowering_pass

    # If positional arguments were given, the decorator was applied as-is (@_aten_lowering_pass)
    if args:
        # The decorator may only be called with the lowering pass itself;
        # the index must be specified as a keyword argument
        if len(args) == 1 and callable(args[0]):
            return add_lowering_pass(args[0])
        else:
            raise AssertionError(
                f"aten_lowering_pass decorator called with invalid arguments {args} "
                "To specify an index to insert the pass, use the keyword 'index='"
            )
    # Otherwise, the decorator was called with keyword arguments (e.g. index=...),
    # so return the registration function to be applied to the pass
    else:
        return add_lowering_pass


def _remove_lowering_pass(*, index: int) -> None:
    """Removes a lowering pass at a specific index from the registry"""
    ATEN_LOWERING_PASSES.remove_pass_with_index(index)
    logger.debug(
        f"Removed lowering pass at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
    )
    return


def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Applies the lowering passes to a graph module, returns the modified GraphModule"""
    logger.debug(
        f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}"
    )
    return ATEN_LOWERING_PASSES(gm)


def dump_lowering_passes() -> str:
    """Returns a string containing the lowering passes"""
    return str(ATEN_LOWERING_PASSES)
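
Note: `pass_manager.py` (the source of the `DynamoPassManager` import above) is not shown in this excerpt. A minimal sketch consistent with the call sites in this file, assuming it extends `torch.fx.passes.pass_manager.PassManager` as the new documentation suggests (a reconstruction, not the PR's actual implementation):

    from typing import Any, Callable, List, Optional

    import torch
    from torch.fx.passes.pass_manager import PassManager


    class DynamoPassManager(PassManager):  # type: ignore[misc]
        """Pass manager supporting index-based insertion and removal of passes"""

        def __init__(self, passes: Optional[List[Callable]] = None):
            super().__init__(passes)

        @classmethod
        def build_from_passlist(cls, passes: Optional[List[Callable]]) -> "DynamoPassManager":
            return cls(passes)

        def add_pass_with_index(
            self, lowering_pass: Callable, index: Optional[int] = None
        ) -> None:
            if index is None:
                # Default: append to the end of the pass list
                self.passes.append(lowering_pass)
            else:
                self.passes.insert(index, lowering_pass)

        def remove_pass_with_index(self, index: int) -> None:
            del self.passes[index]

        def __call__(self, gm: Any) -> Any:
            # Run all passes in order via the parent PassManager
            return super().__call__(gm)

        def __str__(self) -> str:
            return str(self.passes)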
56 changes: 56 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -0,0 +1,56 @@
import logging

import torch
from torch_tensorrt._utils import sanitized_torch_version

from packaging import version

# Modify import location of utilities based on Torch version
if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
else:
    from torch._inductor.constant_folding import (
        ConstantFolder,
        replace_node_with_constant,
    )

logger = logging.getLogger(__name__)


@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Adapted from:
    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197

    Folds constants in the graph module, not skipping constructors

    Modifies the graph in-place and replaces nodes with constants
    """
    cf = ConstantFolder(gm, skip_constructors=False)
    cf.run()

    for node, constant in cf.node_replacements.items():
        replace_node_with_constant(gm, node, constant)

    erased_params = []
    for node in gm.graph.nodes:
        # If a get_attr node has no users, mark it for deletion
        if node.op == "get_attr" and len(node.users) == 0:
            # If the node's parameter is not a parameter of any other node, remove it
            if not any(
                other.target == node.target for other in gm.graph.nodes if other != node
            ):
                delattr(gm, node.target)
            erased_params.append(node)

    # Remove unused nodes from the graph
    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()

    logger.debug(f"Graph after constant folding:\n{gm.graph}")

    return gm
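
As a quick illustration of the effect of this pass (a hypothetical example, not part of the PR): a subexpression that depends only on module attributes is collapsed into a single frozen constant.

    import torch


    class AddConst(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = torch.nn.Parameter(torch.ones(3))
            self.b = torch.nn.Parameter(torch.full((3,), 2.0))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # (self.a + self.b) uses no graph inputs, so it is foldable
            return x + (self.a + self.b)


    gm = torch.fx.symbolic_trace(AddConst())
    gm = constant_fold(gm)
    # The a + b addition should now appear as a single folded constant
    print(gm.graph)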