1 | 1 | from __future__ import annotations |
2 | 2 |
3 | | -import copy |
4 | 3 | import logging |
5 | | -import sys |
6 | | -from contextlib import contextmanager |
7 | | -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union |
| 4 | +import unittest.mock |
| 5 | +from typing import Any, Tuple |
8 | 6 |
9 | 7 | import torch |
10 | | -import torch._dynamo as torchdynamo |
11 | | -from torch.fx.passes.infra.pass_base import PassResult |
12 | | -from torch_tensorrt.dynamo.utils import req_torch_version |
13 | | -from torch_tensorrt.fx.passes.lower_basic_pass_aten import ( |
14 | | - compose_bmm, |
15 | | - compose_chunk, |
16 | | - compose_getitem_slice, |
17 | | - remove_ops, |
18 | | - replace_aten_op_with_indices, |
19 | | - replace_aten_reshape_alias_with_replace, |
20 | | - replace_builtin_ops, |
21 | | - replace_inplace_ops, |
22 | | - replace_native_layernorm_with_layernorm, |
23 | | - replace_transpose_mm_op_with_linear, |
24 | | - run_const_fold, |
25 | | -) |
26 | | -from typing_extensions import TypeAlias |
27 | | - |
28 | | -Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]] |
| 8 | +from torch._export import export |
| 9 | +from torch_tensorrt.dynamo.backend.backends import constant_fold |
| 10 | +from torch_tensorrt.dynamo.lowering import get_decompositions |
| 11 | +from torch_tensorrt.dynamo.utils import set_log_level |
29 | 12 |
30 | 13 | logger = logging.getLogger(__name__) |
31 | 14 |
32 | 15 |
33 | | -class DynamoConfig: |
34 | | - """ |
35 | | - Manage Exir-specific configurations of Dynamo. |
36 | | - """ |
37 | | - |
38 | | - def __init__( |
39 | | - self, |
40 | | - capture_scalar_outputs: bool = True, |
41 | | - guard_nn_modules: bool = True, |
42 | | - dynamic_shapes: bool = True, |
43 | | - specialize_int: bool = True, |
44 | | - verbose: bool = True, |
45 | | - ) -> None: |
46 | | - self.capture_scalar_outputs = capture_scalar_outputs |
47 | | - self.guard_nn_modules = guard_nn_modules |
48 | | - self.dynamic_shapes = dynamic_shapes |
49 | | - self.specialize_int = specialize_int |
50 | | - self.verbose = verbose |
51 | | - |
52 | | - def activate(self) -> None: |
53 | | - torchdynamo.config.capture_scalar_outputs = self.capture_scalar_outputs |
54 | | - torchdynamo.config.guard_nn_modules = self.guard_nn_modules |
55 | | - torchdynamo.config.dynamic_shapes = self.dynamic_shapes |
56 | | - torchdynamo.config.specialize_int = self.specialize_int |
57 | | - torchdynamo.config.verbose = self.verbose |
58 | | - |
59 | | - def deactivate(self) -> None: |
60 | | - torchdynamo.config.capture_scalar_outputs = True |
61 | | - torchdynamo.config.guard_nn_modules = True |
62 | | - torchdynamo.config.dynamic_shapes = True |
63 | | - torchdynamo.config.specialize_int = True |
64 | | - torchdynamo.config.verbose = True |
65 | | - |
66 | | - |
67 | | -@contextmanager |
68 | | -def using_config(config: DynamoConfig) -> Generator[DynamoConfig, None, None]: |
69 | | - config.activate() |
70 | | - try: |
71 | | - yield config |
72 | | - finally: |
73 | | - config.deactivate() |
74 | | - |
75 | | - |
76 | | -@contextmanager |
77 | | -def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None, None]: |
78 | | - """ |
79 | | - Temporarily increase the python interpreter stack recursion limit. |
80 | | - This is mostly used for pickling large scale modules. |
81 | | - """ |
82 | | - default = sys.getrecursionlimit() |
83 | | - if limit > default: |
84 | | - sys.setrecursionlimit(limit) |
85 | | - try: |
86 | | - yield |
87 | | - finally: |
88 | | - sys.setrecursionlimit(default) |
89 | | - |
90 | | - |
91 | | -@req_torch_version("2.dev") |
92 | | -def dynamo_trace( |
93 | | - f: Callable[..., Value], |
94 | | - # pyre-ignore |
95 | | - args: Tuple[Any, ...], |
96 | | - aten_graph: bool, |
97 | | - tracing_mode: str = "real", |
98 | | - dynamo_config: Optional[DynamoConfig] = None, |
99 | | -) -> Any: # Tuple[torch.fx.GraphModule, Set[_guards.Guard]]: |
100 | | - """ |
101 | | - TODO: Once we fully migrate to torchdynamo frontend, we will remove |
102 | | -    this config option altogether. For now, it helps with quick
103 | | - experiments with playing around with TorchDynamo |
104 | | - """ |
105 | | - if dynamo_config is None: |
106 | | - dynamo_config = DynamoConfig() |
107 | | - with using_config(dynamo_config), setting_python_recursive_limit(2000): |
108 | | - torchdynamo.reset() |
109 | | - try: |
110 | | - return torchdynamo.export( |
111 | | - f, |
112 | | - *copy.deepcopy(args), |
113 | | - aten_graph=aten_graph, |
114 | | - tracing_mode=tracing_mode, |
115 | | - ) |
116 | | - except torchdynamo.exc.Unsupported as exc: |
117 | | - raise RuntimeError( |
118 | | - "The user code is using a feature we don't support. " |
119 | | -                "Please try torchdynamo.explain() to get the possible reasons",
120 | | - ) from exc |
121 | | - except Exception as exc: |
122 | | - raise RuntimeError( |
123 | | -                "torchdynamo internal error occurred. Please see above stacktrace"
124 | | - ) from exc |
125 | | - |
126 | | - |
127 | | -@req_torch_version("2.dev") |
128 | 16 | def trace( |
129 | 17 | model: torch.nn.Module | torch.fx.GraphModule, |
130 | 18 | inputs: Tuple[Any, ...], |
131 | 19 | **kwargs: Any, |
132 | 20 | ) -> torch.fx.GraphModule: |
133 | | - """ |
134 | | - Optimized trace with necessary passes which re-compose some ops or replace some ops |
135 | | - These passes should be general and functional purpose |
136 | | - """ |
137 | | - passes_list = [ |
138 | | - compose_bmm, |
139 | | - compose_chunk, |
140 | | - compose_getitem_slice, |
141 | | - replace_aten_reshape_alias_with_replace, |
142 | | - replace_aten_op_with_indices, |
143 | | - replace_transpose_mm_op_with_linear, # after compose_bmm |
144 | | - replace_native_layernorm_with_layernorm, |
145 | | - remove_ops, |
146 | | - replace_builtin_ops, # after replace_native_layernorm_with_layernorm |
147 | | - replace_inplace_ops, # remove it once functionalization is enabled |
148 | | - ] |
149 | | - |
150 | | - fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic") |
151 | | - |
152 | | - for passes in passes_list: |
153 | | - pr: PassResult = passes(fx_module) |
154 | | - fx_module = pr.graph_module |
155 | | - |
156 | | - fx_module(*inputs) |
157 | | - |
158 | | - fx_module = run_const_fold(fx_module) |
159 | | - logger.info("Post export graph : %s\n", fx_module.graph) |
160 | | - return fx_module |
| 21 | + # Set log level at the top of compilation (torch_tensorrt.dynamo) |
| 22 | + if "debug" in kwargs and kwargs["debug"]: |
| 23 | + set_log_level(logger.parent, logging.DEBUG) |
| 24 | + |
| 25 | + experimental_decompositions = kwargs.get( |
| 26 | + "enable_experimental_decompositions", False |
| 27 | + ) |
| 28 | + with unittest.mock.patch( |
| 29 | + "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions) |
| 30 | + ): |
| 31 | + graph_module = export(model, tuple(inputs)).module() |
| 32 | + constant_fold(graph_module) |
| 33 | + logger.debug("Post export graph: " + str(graph_module.graph)) |
| 34 | + return graph_module |
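
For reference, a minimal usage sketch of the new export-based trace path introduced above, assuming trace() is imported from this module (the import path is not shown in the diff). The toy model, input shapes, and keyword arguments are illustrative assumptions; only the trace() signature comes from this commit.

import torch

# Hypothetical toy module, used only to illustrate the call; any nn.Module works.
class ToyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x + 1.0)

inputs = (torch.randn(2, 3),)

# trace() exports the module with torch._export.export under the patched
# DECOMP_TABLE, constant-folds the resulting GraphModule, and returns it;
# debug=True raises the torch_tensorrt.dynamo log level to DEBUG.
graph_module = trace(ToyModel().eval(), inputs, debug=True)
print(graph_module.graph)

Compared with the removed torchdynamo.export path, decomposition selection and constant folding are now handled by torch._export plus the shared Dynamo lowering utilities, so the Exir-specific DynamoConfig and recursion-limit context managers are no longer needed.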