feat: Add maxpool lowering passes and experimental folder in Dynamo #2358

Merged · 1 commit · Oct 3, 2023
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -40,7 +40,7 @@ repos:
rev: 'v1.4.1'
hooks:
- id: mypy
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools|^docs|noxfile.py|setup.py|versions.py"
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.278
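As context for the config change above: pre-commit treats `exclude` as a Python regex matched against each file path, so the added alternative skips the new experimental folder. A rough sanity check, not part of the PR, using made-up paths:

```python
# Illustrative only: verify the updated exclude regex skips the new
# experimental folder; the sample paths below are assumptions.
import re

exclude = re.compile(
    "^py/torch_tensorrt/fx|^examples|^tests"
    "|^py/torch_tensorrt/dynamo/_experimental"
    "|^tools|^docs|noxfile.py|setup.py|versions.py"
)

assert exclude.search("py/torch_tensorrt/dynamo/_experimental/foo.py")
assert exclude.search("tests/core/test_something.py")
assert exclude.search("py/torch_tensorrt/dynamo/backend/backends.py") is None
```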
4 changes: 0 additions & 4 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -15,7 +15,6 @@
get_decompositions,
repair_input_aliasing,
)
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
from torch_tensorrt.dynamo.utils import (
parse_dynamo_kwargs,
prepare_inputs,
@@ -68,9 +67,6 @@ def _pretraced_backend(
try:
logger.debug("Pre-AOT Autograd graph:\n" + str(gm.graph))

-# Perform Pre-AOT Lowering for Module-Level Replacement
-gm = pre_aot_substitutions(gm)
-
fake_mode = detect_fake_mode(sample_inputs)

# Place backend tracing within FakeTensor context allowing nonfake Tensors
3 changes: 0 additions & 3 deletions py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -1,7 +1,4 @@
from ._decompositions import get_decompositions # noqa: F401
from ._fusers import * # noqa: F401
-from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
-from ._pre_aot_lowering import register_substitution # noqa: F401
from ._repair_input_aliasing import repair_input_aliasing
from .passes import apply_lowering_passes
-from .substitutions import * # noqa: F401
145 changes: 0 additions & 145 deletions py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py

This file was deleted.

2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -9,6 +9,7 @@
from .pass_manager import DynamoPassManager
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
+from .replace_max_pool_with_indices import replace_max_pool_with_indices

ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
[
@@ -17,6 +18,7 @@
repair_input_as_output,
lower_efficient_attention,
fuse_prims_broadcast,
+replace_max_pool_with_indices,
]
)

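Since the new pass is appended to ATEN_LOWERING_PASSES, it runs automatically during ATen lowering. A hedged sketch, not from this PR, of how a downstream pass list could be composed with the same builder; `log_maxpool_nodes` is a hypothetical example pass:

```python
# Hypothetical sketch: only DynamoPassManager.build_from_passlist and the
# (GraphModule, sample_inputs) -> GraphModule pass signature are taken
# from the hunk above; log_maxpool_nodes is invented for illustration.
from typing import Sequence

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_manager import DynamoPassManager
from torch_tensorrt.dynamo.lowering.passes.replace_max_pool_with_indices import (
    replace_max_pool_with_indices,
)


def log_maxpool_nodes(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    # No-op diagnostic pass: report any maxpool variants left in the graph
    for node in gm.graph.nodes:
        if "max_pool" in str(node.target):
            print(f"Found maxpool node: {node.format_node()}")
    return gm


CUSTOM_PASSES = DynamoPassManager.build_from_passlist(
    [replace_max_pool_with_indices, log_maxpool_nodes]
)
```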
60 changes: 60 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py
@@ -0,0 +1,60 @@
import logging
import operator
from typing import Sequence

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def replace_max_pool_with_indices(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Replace MaxPool nodes which return unused indices"""
replacement_dict = {
torch.ops.aten.max_pool1d_with_indices.default: torch.ops.aten.max_pool1d.default,
torch.ops.aten.max_pool2d_with_indices.default: torch.ops.aten.max_pool2d.default,
torch.ops.aten.max_pool3d_with_indices.default: torch.ops.aten.max_pool3d.default,
}

modified_graph = False

for node in gm.graph.nodes:
# If the node is a maxpool op that also returns indices, and its only
# user is a getitem node extracting the values output (index 0), the
# indices are unused and the op can be replaced with the plain variant
if (
node.target in replacement_dict
and len(node.users) == 1
and list(node.users)[0].target == operator.getitem
and list(node.users)[0].args[1] == 0
):
modified_graph = True

# Replace the getitem node with an indices-free maxpool call, then delete the getitem and original maxpool nodes
getitem_node = list(node.users)[0]

with gm.graph.inserting_after(getitem_node):
maxpool_fused = gm.graph.call_function(
replacement_dict[node.target],
args=node.args,
kwargs=node.kwargs,
)

logger.debug(
f"Replacing nodes {node}, {getitem_node} with fused maxpool node {maxpool_fused}, "
f"since {getitem_node} is the only user of {node} and extracts only the values output."
)

getitem_node.replace_all_uses_with(maxpool_fused)
gm.graph.erase_node(getitem_node)
gm.graph.erase_node(node)

if modified_graph:
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after fusing maxpool operators with indices:\n{gm.graph}")

return gm
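To see the pass fire end to end, here is a minimal, self-contained sketch (not part of the PR). It hand-builds an FX graph in which the indices output of `max_pool2d_with_indices` is unused, then applies the pass; shapes and pooling parameters are arbitrary:

```python
# Minimal sketch exercising replace_max_pool_with_indices on a hand-built
# graph; assumes torch_tensorrt is importable.
import operator

import torch
from torch_tensorrt.dynamo.lowering.passes.replace_max_pool_with_indices import (
    replace_max_pool_with_indices,
)

graph = torch.fx.Graph()
x = graph.placeholder("x")
# max_pool2d_with_indices returns a (values, indices) tuple
pooled = graph.call_function(
    torch.ops.aten.max_pool2d_with_indices.default,
    args=(x, [3, 3], [2, 2]),
)
# Only the values output (getitem index 0) is consumed, so the pass applies
values = graph.call_function(operator.getitem, args=(pooled, 0))
graph.output(values)
gm = torch.fx.GraphModule(torch.nn.Module(), graph)

gm = replace_max_pool_with_indices(gm, [torch.randn(1, 3, 32, 32)])
print(gm.graph)  # the with-indices op and getitem are gone; aten.max_pool2d remains
```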
2 changes: 0 additions & 2 deletions py/torch_tensorrt/dynamo/lowering/substitutions/__init__.py

This file was deleted.

76 changes: 0 additions & 76 deletions py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py

This file was deleted.
