 
 import logging
 import unittest
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence
 
 import torch
 import torch._dynamo as td
+import torch.utils._pytree as pytree
 from torch._dynamo.utils import detect_fake_mode
-from torch._functorch.aot_autograd import aot_export_joint_simple
+from torch._functorch.aot_autograd import _aot_export_function
 from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
+from torch._ops import OpOverload
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
@@ -73,10 +75,9 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             # Invoke AOTAutograd to translate operators to aten
-            graph_module = aot_export_joint_simple(
+            graph_module = aot_export_for_compile(
                 gm,
                 sample_inputs,
-                trace_joint=False,
                 decompositions=get_decompositions(
                     settings.enable_experimental_decompositions
                 ),
@@ -131,3 +132,50 @@ def constant_fold(gm: torch.fx.GraphModule) -> Any:
     gm.graph.eliminate_dead_code()
     gm.graph.lint()
     gm.recompile()
+
+
+def aot_export_for_compile(
+    func: torch.fx.GraphModule,
+    args: Sequence[torch.Tensor],
+    *,
+    decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
+) -> torch.fx.GraphModule:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/054f3f1d8f9eb63ef8437991eba5b8f2aeee920f/torch/_functorch/aot_autograd.py#L4133-L4134
+
+    Removed check for input aliasing in resultant subgraph - TRT is functional-only
+    """
+    with torch.no_grad():
+        fx_g, metadata, in_spec, out_spec = _aot_export_function(
+            func,
+            args,
+            decompositions=decompositions,
+        )
+
+        # No input mutations
+        if (
+            len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
+            != 0
+        ):
+            raise RuntimeError(
+                f"aot_export_for_compile does not support input mutations. {str(metadata)}"
+            )
+        # No pytrees
+        if type(in_spec) == pytree.LeafSpec:
+            raise RuntimeError(
+                f"aot_export_for_compile requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
+            )
+        if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+            raise RuntimeError(
+                f"aot_export_for_compile requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
+            )
+        if type(out_spec) == pytree.LeafSpec:
+            raise RuntimeError(
+                f"aot_export_for_compile requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
+            )
+        if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+            raise RuntimeError(
+                f"aot_export_for_compile requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
+            )
+
+    return fx_g
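
A minimal sketch of what the in_spec/out_spec checks above accept and reject, assuming only the torch.utils._pytree helpers the diff already imports; the example inputs are made up for illustration.

# Sketch only: illustrates the pytree specs validated by aot_export_for_compile.
import torch
import torch.utils._pytree as pytree

# A flat tuple of tensors: the top-level spec is a container spec and every
# child spec is a LeafSpec, so both input checks pass.
_, flat_spec = pytree.tree_flatten((torch.rand(2), torch.rand(2)))
assert type(flat_spec) != pytree.LeafSpec
assert all(type(c) == pytree.LeafSpec for c in flat_spec.children_specs)

# A nested input (dict inside the tuple): one child spec is itself a
# container spec, so aot_export_for_compile would raise a RuntimeError.
_, nested_spec = pytree.tree_flatten((torch.rand(2), {"scale": torch.rand(1)}))
assert any(type(c) != pytree.LeafSpec for c in nested_spec.children_specs)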
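And a hedged usage sketch of the new helper, mirroring how _pretraced_backend calls it but without the fake-mode plumbing; the MatMul module and inputs are hypothetical, and the exact aten ops in the exported graph depend on the PyTorch build.

# Hypothetical usage sketch; assumes aot_export_for_compile (defined above)
# is in scope.
import torch

class MatMul(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y)

gm = torch.fx.symbolic_trace(MatMul())
sample_inputs = [torch.rand(4, 4), torch.rand(4, 4)]

# Returns an aten-level torch.fx.GraphModule; with no decompositions passed,
# matmul appears as its default aten overload.
exported = aot_export_for_compile(gm, sample_inputs)
print(exported.graph)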