
fix: Remove input aliasing of builtin ops #2276

Merged · 2 commits · Sep 26, 2023
67 changes: 11 additions & 56 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -2,17 +2,19 @@

import logging
import unittest
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, Sequence

import torch
import torch._dynamo as td
import torch.utils._pytree as pytree
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import _aot_export_function
from torch._ops import OpOverload
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo.compile import compile_module
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
from torch_tensorrt.dynamo.lowering import (
apply_lowering_passes,
get_decompositions,
repair_input_aliasing,
)
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs, set_log_level

@@ -71,10 +73,13 @@ def _pretraced_backend(
with unittest.mock.patch.object(
fake_mode, "allow_non_fake_inputs", True
), fake_mode:
repair_input_aliasing(gm)

# Invoke AOTAutograd to translate operators to aten
gm = aot_export_for_compile(
gm = aot_export_joint_simple(
gm,
sample_inputs,
trace_joint=False,
decompositions=get_decompositions(
settings.enable_experimental_decompositions
),
@@ -107,53 +112,3 @@ def _pretraced_backend(
+ "specify pass_through_build_failures=False."
)
raise


def aot_export_for_compile(
func: torch.fx.GraphModule,
args: Sequence[torch.Tensor],
*,
decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
) -> torch.fx.GraphModule:
"""Adapted from:
https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158

Removed check for input aliasing in resultant subgraph - TRT is functional-only

Exports the function to ATen for torch compile
"""
# Trace function with input arguments and decompositions
with torch.no_grad():
fx_g, metadata, in_spec, out_spec = _aot_export_function(
func,
args,
decompositions=decompositions,
)

# No input mutations
if (
len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
!= 0
):
raise RuntimeError(
f"aot_export_joint_simple does not support input mutations. {str(metadata)}"
)
# No pytrees
if type(in_spec) == pytree.LeafSpec:
raise RuntimeError(
f"aot_export_for_compile requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
)
if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
raise RuntimeError(
f"aot_export_for_compile requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
)
if type(out_spec) == pytree.LeafSpec:
raise RuntimeError(
f"aot_export_for_compile requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
)
if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
raise RuntimeError(
f"aot_export_for_compile requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
)

return fx_g
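
Note: below is a minimal sketch (not part of the diff) of the new export flow that replaces the deleted aot_export_for_compile helper, composed only from functions shown in this PR; the export_to_aten wrapper name and the hard-coded False for experimental decompositions are illustrative assumptions.

import torch
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch_tensorrt.dynamo.lowering import get_decompositions, repair_input_aliasing
from torch_tensorrt.dynamo.lowering.passes import remove_input_alias_fixing_clones


def export_to_aten(gm: torch.fx.GraphModule, sample_inputs):  # illustrative wrapper
    # 1. Insert aten.clone after each Tensor placeholder so the upstream
    #    exporter does not reject graphs whose inputs alias their outputs
    #    (the real backend runs this under the detected fake tensor mode).
    repair_input_aliasing(gm)

    # 2. Trace to ATen with the stock PyTorch exporter (inference graph only).
    gm = aot_export_joint_simple(
        gm,
        sample_inputs,
        trace_joint=False,
        decompositions=get_decompositions(False),
    )

    # 3. The first ATen lowering pass strips the temporary clones again.
    gm = remove_input_alias_fixing_clones(gm)
    return gm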
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -2,5 +2,6 @@
from ._fusers import * # noqa: F401
from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
from ._pre_aot_lowering import register_substitution # noqa: F401
from ._repair_input_aliasing import repair_input_aliasing
from .passes import apply_lowering_passes
from .substitutions import * # noqa: F401
38 changes: 38 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_repair_input_aliasing.py
@@ -0,0 +1,38 @@
import logging

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import get_tensor_placeholders

logger = logging.getLogger(__name__)


def repair_input_aliasing(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""Inserts clone operators temporarily ahead of every placeholder

See: https://github.com/pytorch/pytorch/issues/108079
Undone by `remove_input_alias_fixing_clones` after tracing
"""
# Extract graph placeholder Tensors
placeholders = get_tensor_placeholders(gm)

for node in placeholders:
# Insert clones for placeholder nodes to avoid
# input aliasing or mutation
with gm.graph.inserting_after(placeholders[-1]):
cloned_input = gm.graph.call_function(
torch.ops.aten.clone.default,
args=(node,),
)

# Replace all uses of the placeholder except the cloned node
# with the cloned placeholder
node.replace_all_uses_with(
cloned_input,
delete_user_cb=lambda node: node != cloned_input,
)

gm.graph.lint()
gm.recompile()
logger.debug(f"Inserted auxiliary clone nodes for placeholders:\n{gm.graph}")

return gm
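
For context, a self-contained example (not from the PR) of what this pass does to a toy graph whose output aliases its input; the Alias module is invented for illustration, and the type annotation matters because only Tensor-typed placeholders are rewritten.

import torch
from torch_tensorrt.dynamo.lowering import repair_input_aliasing


class Alias(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The output is a view of the input, i.e. it aliases input storage.
        return x.reshape(-1)


gm = torch.fx.symbolic_trace(Alias())
repair_input_aliasing(gm)
print(gm.graph)
# The placeholder x now feeds torch.ops.aten.clone.default, and reshape
# consumes the clone instead of the raw placeholder.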
py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -5,10 +5,12 @@

from .constant_folding import constant_fold
from .pass_manager import DynamoPassManager
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output

ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
[
remove_input_alias_fixing_clones,
constant_fold,
repair_input_as_output,
]
7 changes: 4 additions & 3 deletions py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -2,6 +2,9 @@

import torch
from torch_tensorrt._utils import sanitized_torch_version
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

from packaging import version

@@ -47,9 +50,7 @@ def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in erased_params:
gm.graph.erase_node(node)

gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
gm = clean_up_graph_after_modifications(gm)

logger.debug(f"Graph after constant folding:\n{gm.graph}")

31 changes: 31 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -0,0 +1,31 @@
from typing import List

import torch


def clean_up_graph_after_modifications(
gm: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
"""Runs dead-code elimination, linting, and recompilation for graph, in-place"""
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
return gm


def get_tensor_placeholders(
gm: torch.fx.GraphModule,
) -> List[torch.fx.Node]:
"""Returns placeholder nodes of GraphModule which are torch.Tensor types"""
# Tensor placeholders must be subclasses of torch.Tensor
placeholders = [
node
for node in gm.graph.nodes
if (
node.op == "placeholder"
and isinstance(node.type, type)
and issubclass(node.type, torch.Tensor)
)
]

return placeholders
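
As a usage sketch, here is how a hypothetical custom lowering pass could reuse these two helpers; the drop_detach_nodes pass name and the detach-removal rewrite are illustrative assumptions, not code from this PR.

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
    get_tensor_placeholders,
)


def drop_detach_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Enumerate the Tensor-typed graph inputs via the shared helper.
    for placeholder in get_tensor_placeholders(gm):
        print(f"graph input: {placeholder.name}")

    # Example rewrite: bypass aten.detach nodes entirely.
    modified = False
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == torch.ops.aten.detach.default:
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)
            modified = True

    # Dead-code elimination, lint, and recompile in a single call.
    if modified:
        gm = clean_up_graph_after_modifications(gm)
    return gm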
py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py
@@ -0,0 +1,43 @@
import logging

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


# TODO: Delete this lowering pass once aot_export_joint_simple is patched
def remove_input_alias_fixing_clones(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""Remove the auxiliary clone nodes inserted to fix input aliasing

See: https://github.com/pytorch/pytorch/issues/108079
"""
modified_graph = False

for node in gm.graph.nodes:
# If the node is a placeholder and its only user is a clone node
# it was modified by the input alias-fixing pass, and the change
# needs to be undone
if (
node.op == "placeholder"
and len(node.users) == 1
and list(node.users)[0].target == torch.ops.aten.clone.default
):
modified_graph = True

# Replace all uses of the clone with the placeholder, delete the clone
clone_node = list(node.users)[0]
logger.debug(
f"Removing node {clone_node} from graph, since it is a clone node which "
f"is the only user of placeholder {node} and was inserted by the compiler."
)
clone_node.replace_all_uses_with(node)
gm.graph.erase_node(clone_node)

if modified_graph:
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Removed auxiliary clone nodes for placeholders:\n{gm.graph}")

return gm
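
An illustrative round trip (the toy Model and the use of symbolic_trace are assumptions for the example) showing that this pass is the inverse of repair_input_aliasing once the export step no longer needs the clones.

import torch
from torch_tensorrt.dynamo.lowering import repair_input_aliasing
from torch_tensorrt.dynamo.lowering.passes import remove_input_alias_fixing_clones


class Model(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


gm = torch.fx.symbolic_trace(Model())
repair_input_aliasing(gm)             # placeholder -> aten.clone -> mul
remove_input_alias_fixing_clones(gm)  # clone removed: placeholder -> mul
print(gm.graph)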
20 changes: 7 additions & 13 deletions py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py
@@ -1,6 +1,10 @@
import logging

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
get_tensor_placeholders,
)

logger = logging.getLogger(__name__)

@@ -13,15 +17,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
modified_graph = False

# Extract graph placeholder Tensors
placeholders = [
node
for node in gm.graph.nodes
if (
node.op == "placeholder"
and isinstance(node.type, type)
and issubclass(node.type, torch.Tensor)
)
]
placeholders = get_tensor_placeholders(gm)

for placeholder in placeholders:
# If any placeholder has any users which are direct graph outputs
@@ -34,7 +30,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
direct_outputs = [user for user in placeholder.users if user.op == "output"]

# Insert clone node for placeholder to ensure placeholder is not a direct output
with gm.graph.inserting_after(placeholder):
with gm.graph.inserting_after(placeholders[-1]):
cloned_placeholder = gm.graph.call_function(
torch.ops.aten.clone.default,
args=(placeholder,),
@@ -45,9 +41,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
output.replace_input_with(placeholder, cloned_placeholder)

if modified_graph:
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")

return gm
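
A hedged example (the PassThrough module is invented for illustration) of the situation this pass repairs: a graph that returns one of its placeholders directly, which a functional-only TRT graph cannot express without a copy.

import torch
from torch_tensorrt.dynamo.lowering.passes import repair_input_as_output


class PassThrough(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # The first output is the raw input tensor itself.
        return x, x + 1


gm = torch.fx.symbolic_trace(PassThrough())
repair_input_as_output(gm)
print(gm.graph)
# The output node now returns (clone(x), add(x, 1)) rather than (x, add(x, 1)).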