Commit 4e308f1

Custom implementation of AOT for compile
1 parent: a94a075

File tree

1 file changed: +52, -4 lines

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 52 additions & 4 deletions
@@ -2,13 +2,15 @@
 
 import logging
 import unittest
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence
 
 import torch
 import torch._dynamo as td
+import torch.utils._pytree as pytree
 from torch._dynamo.utils import detect_fake_mode
-from torch._functorch.aot_autograd import aot_export_joint_simple
+from torch._functorch.aot_autograd import _aot_export_function
 from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
+from torch._ops import OpOverload
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
@@ -73,10 +75,9 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             # Invoke AOTAutograd to translate operators to aten
-            graph_module = aot_export_joint_simple(
+            graph_module = aot_export_for_compile(
                 gm,
                 sample_inputs,
-                trace_joint=False,
                 decompositions=get_decompositions(
                     settings.enable_experimental_decompositions
                 ),
@@ -131,3 +132,50 @@ def constant_fold(gm: torch.fx.GraphModule) -> Any:
     gm.graph.eliminate_dead_code()
     gm.graph.lint()
     gm.recompile()
+
+
+def aot_export_for_compile(
+    func: torch.fx.GraphModule,
+    args: Sequence[torch.Tensor],
+    *,
+    decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
+) -> torch.fx.GraphModule:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/054f3f1d8f9eb63ef8437991eba5b8f2aeee920f/torch/_functorch/aot_autograd.py#L4133-L4134
+
+    Removed check for input aliasing in resultant subgraph - TRT is functional-only
+    """
+    with torch.no_grad():
+        fx_g, metadata, in_spec, out_spec = _aot_export_function(
+            func,
+            args,
+            decompositions=decompositions,
+        )
+
+        # No input mutations
+        if (
+            len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
+            != 0
+        ):
+            raise RuntimeError(
+                f"aot_export_for_compile does not support input mutations. {str(metadata)}"
+            )
+        # No pytrees
+        if type(in_spec) == pytree.LeafSpec:
+            raise RuntimeError(
+                f"aot_export_for_compile requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
+            )
+        if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+            raise RuntimeError(
+                f"aot_export_for_compile requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
+            )
+        if type(out_spec) == pytree.LeafSpec:
+            raise RuntimeError(
+                f"aot_export_for_compile requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
+            )
+        if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+            raise RuntimeError(
+                f"aot_export_for_compile requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
+            )
+
+        return fx_g
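
A minimal usage sketch of the new helper (not part of the commit): the MatMulRelu module, input shapes, and import path are illustrative assumptions based on the file location above; any functional, non-mutating graph with flat tensor inputs and outputs should behave the same way.

    import torch

    from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile

    # Hypothetical toy module: purely functional, flat tensor inputs/outputs.
    class MatMulRelu(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.relu(x @ y)

    gm = torch.fx.symbolic_trace(MatMulRelu())
    sample_inputs = [torch.randn(4, 8), torch.randn(8, 16)]

    # Exports a flat, functional aten-level fx.GraphModule; raises RuntimeError
    # if the traced graph mutates its inputs or uses nested (pytree) I/O.
    aten_gm = aot_export_for_compile(gm, sample_inputs)
    print(aten_gm.graph)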
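
For context on the in_spec/out_spec guards (again a sketch, not from the commit): _aot_export_function reports input and output structure as torch.utils._pytree TreeSpec objects, and the checks above accept only a single flat list/tuple of tensors. The example values below are illustrative:

    import torch
    import torch.utils._pytree as pytree

    # A flat tuple of tensors: the spec itself is not a leaf, and every child is.
    _, flat_spec = pytree.tree_flatten((torch.randn(2), torch.randn(2)))
    print(type(flat_spec) == pytree.LeafSpec)                                 # False
    print(all(type(c) == pytree.LeafSpec for c in flat_spec.children_specs))  # True

    # A nested container: the dict child is its own TreeSpec, so the new
    # aot_export_for_compile would reject inputs shaped like this.
    _, nested_spec = pytree.tree_flatten((torch.randn(2), {"k": torch.randn(2)}))
    print(all(type(c) == pytree.LeafSpec for c in nested_spec.children_specs))  # False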

0 commit comments