feat: Improve Logging in Dynamo #2194

Merged 6 commits on Aug 16, 2023
32 changes: 17 additions & 15 deletions py/torch_tensorrt/_compile.py
@@ -1,12 +1,12 @@
from __future__ import annotations

import logging
from enum import Enum
from typing import Any, Callable, List, Optional, Sequence, Set

import torch
import torch.fx
import torch_tensorrt.ts
from torch_tensorrt import logging
from torch_tensorrt._enums import dtype
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.compile import compile as dynamo_compile
@@ -16,6 +16,13 @@
from torch_tensorrt.ts._compiler import compile as torchscript_compile
from typing_extensions import TypeGuard

logger = logging.getLogger(__name__)

__all__ = [
"compile",
"convert_method_to_trt_engine",
]


def _non_fx_input_interface(
inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
@@ -30,7 +37,7 @@ def _fx_input_interface(


class _IRType(Enum):
"""Enum to set the minimum required logging level to print a message to stdout"""
"""Enum to determine the type of IR selected for model compilation"""

ts = 0
fx = 1
@@ -39,7 +46,7 @@ class _IRType(Enum):


class _ModuleType(Enum):
"""Enum to set the minimum required logging level to print a message to stdout"""
"""Enum to determine the type of model provided as input"""

nn = 0
ts = 1
@@ -81,14 +88,11 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
if ir == "default":
# Options are listed in order of preference
if module_is_fxable:
logging.log(
logging.Level.Info, "ir was set to default, using dynamo as ir"
)
logger.info("ir was set to default, using dynamo as ir")
return _IRType.dynamo
elif module_is_tsable:
logging.log(
logging.Level.Warning,
"Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript",
logger.warning(
"Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript"
)
return _IRType.ts
else:
@@ -151,9 +155,8 @@ def compile(
if target_ir == _IRType.ts:
ts_mod = module
if module_type == _ModuleType.nn:
logging.log(
logging.Level.Info,
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
logger.info(
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
)
ts_mod = torch.jit.script(module)
assert _non_fx_input_interface(input_list)
@@ -274,9 +277,8 @@ def convert_method_to_trt_engine(
if target_ir == _IRType.ts:
ts_mod = module
if module_type == _ModuleType.nn:
logging.log(
logging.Level.Info,
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
logger.info(
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
)
ts_mod = torch.jit.script(module)
return torch_tensorrt.ts.convert_method_to_trt_engine( # type: ignore[no-any-return]
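Because the rewritten `_compile.py` logs through `logging.getLogger(__name__)` instead of the `torch_tensorrt.logging` wrapper, downstream code can tune verbosity with the standard library alone. A minimal sketch of what a caller might do; the handler and format choices here are illustrative, not part of the PR:

```python
import logging

# Send log records to stderr with a simple format (basicConfig is a no-op if
# handlers were already configured elsewhere).
logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s: %(message)s")

# Each module calls logging.getLogger(__name__), so logger names mirror the
# package layout (e.g. "torch_tensorrt._compile"). Raising the parent
# logger's level surfaces the new INFO messages, such as
# "ir was set to default, using dynamo as ir".
logging.getLogger("torch_tensorrt").setLevel(logging.INFO)
```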
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/__init__.py
@@ -1,7 +1,11 @@
import logging

from torch_tensorrt._utils import sanitized_torch_version

from packaging import version

logger = logging.getLogger(__name__)

if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
from ._settings import * # noqa: F403
from ._SourceIR import SourceIR # noqa: F403
7 changes: 5 additions & 2 deletions py/torch_tensorrt/dynamo/aten_tracer.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import logging
import sys
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
@@ -26,6 +27,8 @@

Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]]

logger = logging.getLogger(__name__)


class DynamoConfig:
"""
@@ -145,13 +148,13 @@ def trace(
]

fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic")
print(fx_module.graph)

for passes in passes_list:
pr: PassResult = passes(fx_module)
fx_module = pr.graph_module

fx_module(*inputs)

fx_module = run_const_fold(fx_module)
print(fx_module.graph)
logger.info("Post export graph : %s\n", fx_module.graph)
return fx_module
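In `trace`, the bare `print(fx_module.graph)` calls give way to a lazily formatted logger call, so the graph is only rendered to a string when an INFO-level record is actually emitted. A small, self-contained illustration; the `report` helper is hypothetical:

```python
import logging

logger = logging.getLogger(__name__)


def report(graph: object) -> None:
    # %-style arguments are formatted by the logging machinery only if the
    # record passes the level check, unlike an f-string, which would call
    # str(graph) unconditionally before logging even looks at the level.
    logger.info("Post export graph : %s\n", graph)


if __name__ == "__main__":
    report("graph():  # stand-in for an FX graph")  # dropped: INFO disabled
    logging.basicConfig(level=logging.INFO)
    report("graph():  # stand-in for an FX graph")  # now emitted
```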
118 changes: 3 additions & 115 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -7,11 +7,8 @@
import torch
import torch._dynamo as td
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
from torch_tensorrt.dynamo import CompilationSettings, partitioning
from torch_tensorrt.dynamo.conversion import (
convert_module,
repair_long_or_double_inputs,
)
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo.compile import compile_module
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
@@ -69,7 +66,7 @@ def _pretraced_backend(
try:
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

trt_compiled = _compile_module(
trt_compiled = compile_module(
gm,
sample_inputs,
settings=settings,
@@ -92,112 +89,3 @@ def _pretraced_backend(
+ "specify pass_through_build_failures=False."
)
raise


def _compile_module(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[torch.Tensor],
settings: CompilationSettings = CompilationSettings(),
) -> torch.fx.GraphModule:
"""Compile a traced FX module

Includes: Partitioning + Conversion Phases

Args:
module: FX GraphModule to convert
inputs: Inputs to the module
settings: Compilation settings
Returns:
Compiled FX GraphModule
"""
# Check the number of supported operations in the graph
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
gm, settings.debug, settings.torch_executed_ops
)

# If the number of supported operations is 0 or less than the block size, skip the subgraph
# TODO: Add condition to second expression below when require_full_compilation is added
if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size):
logger.warning(
f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. "
f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}"
)
return gm
else:
logger.debug(
f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
)

# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False

# If specified, try using the fast partitioner and fall back to the global one on failure
if settings.use_fast_partitioner:
try:
partitioned_module = partitioning.fast_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)
except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
logger.error(
"Partitioning failed on the subgraph with fast partition. See trace above. "
+ "Retrying with global partition.",
exc_info=True,
)

fast_partitioner_failed = True
settings.use_fast_partitioner = False

if not settings.use_fast_partitioner:
partitioned_module = partitioning.global_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)

# Store TRT replicas of Torch subgraphs
trt_modules = {}

# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those
for name, _ in partitioned_module.named_children():
# Criteria for a module to be convertible to TRT
if settings.use_fast_partitioner and "_run_on_acc" not in name:
continue

submodule = getattr(partitioned_module, name)

# Get submodule inputs
submodule_inputs = partitioning.get_submod_inputs(
partitioned_module, submodule, sample_inputs
)

assert submodule_inputs is not None
# Handle long/double inputs if requested by the user
if settings.truncate_long_and_double:
submodule_inputs = repair_long_or_double_inputs(
partitioned_module, submodule, submodule_inputs, name
)

# Create TRT Module from submodule
trt_mod = convert_module(
submodule,
submodule_inputs,
settings=settings,
name=name,
)

trt_modules[name] = trt_mod

# Replace all FX Modules with TRT Modules
for name, trt_mod in trt_modules.items():
setattr(partitioned_module, name, trt_mod)

# Reset settings object to user specification after fallback to global partitioning mode
if fast_partitioner_failed:
settings.use_fast_partitioner = True

return partitioned_module
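The 112 deleted lines move the partition-and-convert pipeline into `compile_module` in `torch_tensorrt.dynamo.compile`, which is now imported at the top of this file. One detail worth noting from the removed body is the fast-partitioner fallback: the failure is logged with `exc_info=True` and the user's `use_fast_partitioner` choice is restored afterwards. A runnable sketch of that pattern, with a placeholder settings type and partition functions standing in for the real `partitioning` helpers:

```python
import logging
from dataclasses import dataclass

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("torch_tensorrt.dynamo")


@dataclass
class Settings:
    # Stand-in for CompilationSettings; only the fields used here.
    use_fast_partitioner: bool = True
    min_block_size: int = 5


def fast_partition(gm):
    # Placeholder for partitioning.fast_partition; fails to trigger fallback.
    raise RuntimeError("splitter internal error")


def global_partition(gm):
    # Placeholder for partitioning.global_partition.
    return gm


def partition_with_fallback(gm, settings: Settings):
    """Fallback-and-restore pattern from the removed _compile_module body."""
    fast_partitioner_failed = False
    partitioned = None
    if settings.use_fast_partitioner:
        try:
            partitioned = fast_partition(gm)
        except RuntimeError:
            # exc_info=True attaches the active traceback to the log record,
            # so the failure is fully reported without aborting compilation.
            logger.error(
                "Partitioning failed on the subgraph with fast partition. "
                "Retrying with global partition.",
                exc_info=True,
            )
            fast_partitioner_failed = True
            settings.use_fast_partitioner = False
    if not settings.use_fast_partitioner:
        partitioned = global_partition(gm)
    if fast_partitioner_failed:
        # Reset to the user's original preference after the fallback.
        settings.use_fast_partitioner = True
    return partitioned


print(partition_with_fallback("dummy-graph", Settings()))
```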