Commit 08a2ee4

feat: Improve Logging in Dynamo (#2194)
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
1 parent: b57d83e

File tree

5 files changed: +142 -150 lines changed


py/torch_tensorrt/_compile.py

+17 -15

@@ -1,12 +1,12 @@
 from __future__ import annotations
 
+import logging
 from enum import Enum
 from typing import Any, Callable, List, Optional, Sequence, Set
 
 import torch
 import torch.fx
 import torch_tensorrt.ts
-from torch_tensorrt import logging
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo.compile import compile as dynamo_compile
@@ -16,6 +16,13 @@
 from torch_tensorrt.ts._compiler import compile as torchscript_compile
 from typing_extensions import TypeGuard
 
+logger = logging.getLogger(__name__)
+
+__all__ = [
+    "compile",
+    "convert_method_to_trt_engine",
+]
+
 
 def _non_fx_input_interface(
     inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
@@ -30,7 +37,7 @@ def _fx_input_interface(
 
 
 class _IRType(Enum):
-    """Enum to set the minimum required logging level to print a message to stdout"""
+    """Enum to determine the type of IR selected for model compilation"""
 
     ts = 0
     fx = 1
@@ -39,7 +46,7 @@ class _IRType(Enum):
 
 
 class _ModuleType(Enum):
-    """Enum to set the minimum required logging level to print a message to stdout"""
+    """Enum to determine the type of model provided as input"""
 
     nn = 0
     ts = 1
@@ -81,14 +88,11 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
     if ir == "default":
         # Options are listed in order of preference
         if module_is_fxable:
-            logging.log(
-                logging.Level.Info, "ir was set to default, using dynamo as ir"
-            )
+            logger.info("ir was set to default, using dynamo as ir")
             return _IRType.dynamo
         elif module_is_tsable:
-            logging.log(
-                logging.Level.Warning,
-                "Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript",
+            logger.warning(
+                "Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript"
             )
             return _IRType.ts
         else:
@@ -151,9 +155,8 @@ def compile(
     if target_ir == _IRType.ts:
         ts_mod = module
         if module_type == _ModuleType.nn:
-            logging.log(
-                logging.Level.Info,
-                "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
+            logger.info(
+                "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
             )
             ts_mod = torch.jit.script(module)
         assert _non_fx_input_interface(input_list)
@@ -274,9 +277,8 @@ def convert_method_to_trt_engine(
     if target_ir == _IRType.ts:
         ts_mod = module
         if module_type == _ModuleType.nn:
-            logging.log(
-                logging.Level.Info,
-                "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
+            logger.info(
+                "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
             )
             ts_mod = torch.jit.script(module)
         return torch_tensorrt.ts.convert_method_to_trt_engine(  # type: ignore[no-any-return]
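
The pattern adopted across these files is Python's standard logging module with one logger per module, replacing the previous torch_tensorrt.logging calls. A minimal sketch of the idiom as it appears in the diff above; the demo() function and its messages are illustrative only, not part of the repository:

    import logging

    # One logger per module, named after the module's import path
    # (this resolves to e.g. "torch_tensorrt._compile" inside that file).
    logger = logging.getLogger(__name__)


    def demo(ir: str) -> None:
        # Severity is chosen per call site, as in the diff above:
        # informational notes use logger.info, recoverable issues logger.warning.
        if ir == "default":
            logger.info("ir was set to default, using dynamo as ir")
        else:
            logger.warning("ir=%s was provided explicitly", ir)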

py/torch_tensorrt/dynamo/__init__.py

+4 -0

@@ -1,7 +1,11 @@
+import logging
+
 from torch_tensorrt._utils import sanitized_torch_version
 
 from packaging import version
 
+logger = logging.getLogger(__name__)
+
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
     from ._settings import *  # noqa: F403
     from ._SourceIR import SourceIR  # noqa: F403

py/torch_tensorrt/dynamo/aten_tracer.py

+5 -2

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import copy
+import logging
 import sys
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
@@ -26,6 +27,8 @@
 
 Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]]
 
+logger = logging.getLogger(__name__)
+
 
 class DynamoConfig:
     """
@@ -145,13 +148,13 @@ def trace(
     ]
 
     fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic")
-    print(fx_module.graph)
+
     for passes in passes_list:
         pr: PassResult = passes(fx_module)
         fx_module = pr.graph_module
 
     fx_module(*inputs)
 
     fx_module = run_const_fold(fx_module)
-    print(fx_module.graph)
+    logger.info("Post export graph : %s\n", fx_module.graph)
     return fx_module
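
With the print(fx_module.graph) calls replaced by logger.info, the post-export graph dump is silent by default and is surfaced through standard logging configuration. A minimal sketch; because the loggers are created with logging.getLogger(__name__), their names follow the module paths, and the handler/format choices below are just one option:

    import logging

    # Attach a root handler and allow INFO-level records from all
    # torch_tensorrt.dynamo modules, including the
    # "Post export graph : ..." message emitted by aten_tracer.trace.
    logging.basicConfig(format="%(name)s [%(levelname)s] %(message)s")
    logging.getLogger("torch_tensorrt.dynamo").setLevel(logging.INFO)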

py/torch_tensorrt/dynamo/backend/backends.py

+3 -115

@@ -7,11 +7,8 @@
 import torch
 import torch._dynamo as td
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
-from torch_tensorrt.dynamo import CompilationSettings, partitioning
-from torch_tensorrt.dynamo.conversion import (
-    convert_module,
-    repair_long_or_double_inputs,
-)
+from torch_tensorrt.dynamo import CompilationSettings
+from torch_tensorrt.dynamo.compile import compile_module
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
@@ -69,7 +66,7 @@ def _pretraced_backend(
     try:
         logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-        trt_compiled = _compile_module(
+        trt_compiled = compile_module(
             gm,
             sample_inputs,
             settings=settings,
@@ -92,112 +89,3 @@ def _pretraced_backend(
                 + "specify pass_through_build_failures=False."
             )
             raise
-
-
-def _compile_module(
-    gm: torch.fx.GraphModule,
-    sample_inputs: Sequence[torch.Tensor],
-    settings: CompilationSettings = CompilationSettings(),
-) -> torch.fx.GraphModule:
-    """Compile a traced FX module
-
-    Includes: Partitioning + Conversion Phases
-
-    Args:
-        module: FX GraphModule to convert
-        inputs: Inputs to the module
-        settings: Compilation settings
-    Returns:
-        Compiled FX GraphModule
-    """
-    # Check the number of supported operations in the graph
-    num_supported_ops, total_ops = partitioning.get_graph_converter_support(
-        gm, settings.debug, settings.torch_executed_ops
-    )
-
-    # If the number of supported operations is 0 or less than the block size, skip the subgraph
-    # TODO: Add condition to second expression below when require_full_compilation is added
-    if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size):
-        logger.warning(
-            f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. "
-            f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}"
-        )
-        return gm
-    else:
-        logger.debug(
-            f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
-        )
-
-    # Partition module into components that can be TRT-accelerated
-    fast_partitioner_failed = False
-
-    # If specified, try using the fast partitioner and fall back to the global one on failure
-    if settings.use_fast_partitioner:
-        try:
-            partitioned_module = partitioning.fast_partition(
-                gm,
-                verbose=settings.debug,
-                min_block_size=settings.min_block_size,
-                torch_executed_ops=settings.torch_executed_ops,
-            )
-        except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
-            logger.error(
-                "Partitioning failed on the subgraph with fast partition. See trace above. "
-                + "Retrying with global partition.",
-                exc_info=True,
-            )
-
-            fast_partitioner_failed = True
-            settings.use_fast_partitioner = False
-
-    if not settings.use_fast_partitioner:
-        partitioned_module = partitioning.global_partition(
-            gm,
-            verbose=settings.debug,
-            min_block_size=settings.min_block_size,
-            torch_executed_ops=settings.torch_executed_ops,
-        )
-
-    # Store TRT replicas of Torch subgraphs
-    trt_modules = {}
-
-    # Iterate over all components that can be accelerated
-    # Generate the corresponding TRT Module for those
-    for name, _ in partitioned_module.named_children():
-        # Criteria for a module to be convertible to TRT
-        if settings.use_fast_partitioner and "_run_on_acc" not in name:
-            continue
-
-        submodule = getattr(partitioned_module, name)
-
-        # Get submodule inputs
-        submodule_inputs = partitioning.get_submod_inputs(
-            partitioned_module, submodule, sample_inputs
-        )
-
-        assert submodule_inputs is not None
-        # Handle long/double inputs if requested by the user
-        if settings.truncate_long_and_double:
-            submodule_inputs = repair_long_or_double_inputs(
-                partitioned_module, submodule, submodule_inputs, name
-            )
-
-        # Create TRT Module from submodule
-        trt_mod = convert_module(
-            submodule,
-            submodule_inputs,
-            settings=settings,
-            name=name,
-        )
-
-        trt_modules[name] = trt_mod
-
-    # Replace all FX Modules with TRT Modules
-    for name, trt_mod in trt_modules.items():
-        setattr(partitioned_module, name, trt_mod)
-
-    # Reset settings object to user specification after fallback to global partitioning mode
-    if fast_partitioner_failed:
-        settings.use_fast_partitioner = True
-
-    return partitioned_module
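
The _compile_module helper removed above now lives in torch_tensorrt.dynamo.compile as compile_module (presumably the fifth changed file, which is not shown in this excerpt), and the backend simply imports and calls it. A rough sketch of invoking it directly, assuming it keeps the signature of the removed helper; build_trt_graph_module and the settings values are illustrative, not from the repository:

    from typing import Sequence

    import torch
    from torch_tensorrt.dynamo import CompilationSettings
    from torch_tensorrt.dynamo.compile import compile_module


    def build_trt_graph_module(
        gm: torch.fx.GraphModule,
        sample_inputs: Sequence[torch.Tensor],
    ) -> torch.fx.GraphModule:
        # Partitioning and conversion happen inside compile_module,
        # mirroring the body of the removed _compile_module above.
        settings = CompilationSettings(min_block_size=5, debug=True)  # illustrative values
        return compile_module(gm, sample_inputs, settings=settings)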
