 import torch
 import torch._dynamo as td
 import torch.utils._pytree as pytree
-import torch_tensorrt
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import _aot_export_function
 from torch._ops import OpOverload
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
-from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
 
-from packaging import version
-
-# Modify import location of utilities based on Torch version
-if version.parse(torch_tensorrt.sanitized_torch_version()) <= version.parse("2.1.0"):
-    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
-else:
-    from torch._inductor.constant_folding import (
-        ConstantFolder,
-        replace_node_with_constant,
-    )
-
 logger = logging.getLogger(__name__)
 
 
@@ -86,7 +74,7 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             # Invoke AOTAutograd to translate operators to aten
-            graph_module = aot_export_for_compile(
+            gm = aot_export_for_compile(
                 gm,
                 sample_inputs,
                 decompositions=get_decompositions(
@@ -96,10 +84,10 @@ def _pretraced_backend(
 
             logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-            constant_fold(graph_module)
+            gm = apply_lowering_passes(gm)
 
             trt_compiled = compile_module(
-                graph_module,
+                gm,
                 sample_inputs,
                 settings=settings,
             )
@@ -123,35 +111,6 @@ def _pretraced_backend(
         raise
 
 
-@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
-def constant_fold(gm: torch.fx.GraphModule) -> Any:
-    """Adapted from:
-    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
-
-    Folds constants in the graph module, not skipping constructors
-
-    Modifies the graph in-place and replaces node with constants
-    """
-    cf = ConstantFolder(gm, skip_constructors=False)
-    cf.run()
-
-    for node, constant in cf.node_replacements.items():
-        replace_node_with_constant(gm, node, constant)
-
-    erased_params = []
-    for node in gm.graph.nodes:
-        if node.op == "get_attr" and len(node.users) == 0:
-            delattr(gm, node.target)
-            erased_params.append(node)
-
-    for node in erased_params:
-        gm.graph.erase_node(node)
-
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
-
-
 def aot_export_for_compile(
     func: torch.fx.GraphModule,
     args: Sequence[torch.Tensor],