forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_remove_auto_functionalized_pass.py
More file actions
55 lines (47 loc) · 2.06 KB
/
_remove_auto_functionalized_pass.py
File metadata and controls
55 lines (47 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch._higher_order_ops.auto_functionalize import (
auto_functionalized,
auto_functionalized_v2,
)
from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized
from torch.export import ExportedProgram
from torch.fx import Graph
def remove_self_clone(graph: Graph) -> None:
    """Erase ``aten.copy_`` nodes that copy a value onto itself.

    A node whose target is ``torch.ops.aten.copy_.default`` and whose first
    two positional arguments are the same value is a no-op. Every use of
    such a node is redirected to its first argument and the node is then
    erased from ``graph``.

    Args:
        graph: The FX graph to clean up. It is modified in place.
    """
    # Walk a snapshot so that erasing a node never touches the sequence
    # currently being iterated.
    for node in list(graph.nodes):
        if node.target is not torch.ops.aten.copy_.default:
            continue
        # A copy_ node that carries its source as a keyword argument has
        # fewer than two positional args; skip it instead of raising
        # IndexError on node.args[1].
        if len(node.args) < 2:
            continue
        destination, source = node.args[0], node.args[1]
        if destination == source:
            node.replace_all_uses_with(destination)
            graph.erase_node(node)
def unsafe_remove_auto_functionalized_pass(
    ep: ExportedProgram,
) -> ExportedProgram:
    """
    Remove all instances of the higher order ops 'auto_functionalized' and
    'auto_functionalized_v2' from the exported program, modifying the EP
    in place so that it calls the original mutator op instead.

    This pass is "unsafe" because it performs no checks that the in-place
    mutation is actually safe.

    Args:
        ep: The exported program to rewrite. Its graph is mutated in place.

    Returns:
        The same ``ep`` object that was passed in.

    Raises:
        AssertionError: If the first argument of an auto_functionalized
            node is not a ``torch._ops.OpOverload``.
    """
    replace_hook = ep.graph_signature.get_replace_hook()
    with ep.graph_module._set_replace_hook(replace_hook):
        for module in ep.graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue

            # NOTE(review): every iteration operates on ep.graph rather than
            # on module.graph; confirm whether nested GraphModules are
            # meant to be rewritten as well.
            for node in ep.graph.nodes:
                if node.op != "call_function":
                    continue
                if (
                    node.target is not auto_functionalized
                    and node.target is not auto_functionalized_v2
                ):
                    continue

                func = node.args[0]
                if not isinstance(func, torch._ops.OpOverload):
                    raise AssertionError(
                        f"Expected func to be an OpOverload, but got {type(func)}"
                    )
                # An empty clone list means: re-inplace everything.
                node.meta["only_clone_these_tensors"] = []

            decompose_auto_functionalized(ep.graph)
            remove_self_clone(ep.graph)
            ep.graph.eliminate_dead_code()

    return ep