
fix: Address .numpy() issue on fake tensors #1949

Merged
merged 1 commit on May 25, 2023
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/backend/__init__.py
@@ -46,6 +46,8 @@ def compile(
torch_executed_modules=[],
**kwargs,
):
if debug:
logger.setLevel(logging.DEBUG)

logger.warn(
"The Dynamo backend is an experimental feature, for which only the "
5 changes: 2 additions & 3 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -52,6 +52,7 @@ def aot_torch_tensorrt_aten_backend(
)


@fake_tensor_unsupported
Collaborator

This is the ultimate backend everything goes through, right? So does that mean we can't work on fake tensors? Is this different from symbolic shapes?

Collaborator

Also, could we just turn a "fake_tensor" into an "ITensor" immediately? Sounds like they are similar.

Collaborator Author

Yes, this is the backend everything goes through, and my understanding is that FakeTensors are for use at compile time and are distinct from SymInts, which are the symbolic shape representations.

The challenge with fake tensors right now is that any tensors instantiated during the call are fake, which means the constant tensors we need to provide to TensorRT, as in the code snippet below, are "fake" and thus contain no values to be parsed. I think a long-term solution could be to support fake tensors fully, but this change temporarily resolves the TRT/Torch compatibility issues.

if isinstance(value, int):
    value = torch.IntTensor([value])
if isinstance(value, float):
    value = torch.Tensor([value])
if dtype:
    value = value.to(dtype)
constant = network.add_constant(value.shape, to_numpy(value))
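
For illustration, here is a minimal sketch (not part of this PR) of the underlying failure: tensors created under FakeTensorMode carry no real storage, so converting one to NumPy, which to_numpy must do before network.add_constant, cannot yield values. The FakeTensorMode import path is an assumption based on recent PyTorch versions, and the exact exception text may vary:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode  # assumed import path

with FakeTensorMode():
    value = torch.ones(3)  # created under the mode, so this is a FakeTensor with no data
    try:
        value.numpy()      # expected to raise: a fake tensor has no storage to read values from
    except Exception as exc:
        print(f"Cannot materialize a fake tensor as numpy: {exc}")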

Collaborator Author

Opened #1951 with a feature proposal and additional discussion.

def _pretraced_backend(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[torch.Tensor],
@@ -120,9 +121,7 @@ def _compile_module(
trt_mod = convert_module(
submodule,
submodule_inputs,
debug=settings.debug,
workspace_size=settings.workspace_size,
precision=settings.precision,
settings=settings,
)

# Replace FX Module with TRT Module
18 changes: 7 additions & 11 deletions py/torch_tensorrt/dynamo/backend/conversion.py
@@ -2,45 +2,41 @@
import torch
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt import TRTModuleNext
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from torch_tensorrt.fx.fx2trt import (
InputTensorSpec,
TRTInterpreter,
)
from torch_tensorrt.fx.utils import LowerPrecision

import tensorrt as trt


def convert_module(
module: torch.fx.GraphModule,
inputs: Sequence[torch.Tensor],
debug: bool = False,
workspace_size: int = 20 << 30,
precision: LowerPrecision = LowerPrecision.FP32,
settings: CompilationSettings = CompilationSettings(),
) -> Union[TRTModuleNext, TRTModule]:
"""Convert an FX module to a TRT module
Args:
module: FX GraphModule to convert
inputs: Sequence of Tensors representing inputs to the module
debug: Whether to print out verbose debugging information
workspace_size: Maximum workspace TRT is allowed to use for the module
precision: Model Layer precision
settings: Compilation settings
Returns:
TRTModule or TRTModuleNext
"""
interp = TRTInterpreter(
module,
InputTensorSpec.from_tensors(inputs),
explicit_batch_dimension=True,
logger_level=(trt.Logger.VERBOSE if debug else trt.Logger.WARNING),
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
)

r = interp.run(
max_workspace_size=workspace_size,
lower_precision=precision,
max_workspace_size=settings.workspace_size,
lower_precision=settings.precision,
profiling_verbosity=(
trt.ProfilingVerbosity.VERBOSE
if debug
if settings.debug
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
),
)
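
For context, CompilationSettings is defined in torch_tensorrt.dynamo.backend._settings and is not shown in this diff; a minimal sketch of a dataclass exposing just the fields convert_module reads here (debug, workspace_size, precision), with defaults taken from the parameters this change removes, could look like the following (the real class may carry additional fields):

from dataclasses import dataclass
from torch_tensorrt.fx.utils import LowerPrecision

@dataclass
class CompilationSettings:
    # Sketch only: fields mirror settings.debug, settings.workspace_size,
    # and settings.precision as used in convert_module above.
    debug: bool = False
    workspace_size: int = 20 << 30                    # previous default workspace (20 GiB)
    precision: LowerPrecision = LowerPrecision.FP32   # previous default layer precision

Call sites then pass one object (settings=settings, as in backends.py above) instead of threading debug, workspace_size, and precision through individually.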
13 changes: 8 additions & 5 deletions py/torch_tensorrt/dynamo/backend/lowering/_partition.py
@@ -136,15 +136,18 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
)

logger.debug("\nSupported Nodes:")
# Reformat support messages for debugger to print node overview as a single string
supported_nodes_str = "\nSupported Nodes:\n"
for node_name in self.supported_operators:
logger.debug("-", node_name)
supported_nodes_str += f"- {node_name}\n"

logger.debug(supported_nodes_str)

if len(self.unsupported_operators) != 0:
logger.debug("\nUnsupported or Excluded Nodes:")
unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n"
for node_name in self.unsupported_operators:
logger.debug("-", node_name)
logger.debug("\n")
unsupported_nodes_str += f"- {node_name}\n"
logger.debug(unsupported_nodes_str)
else:
logger.debug("\nAll Nodes Supported\n")
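
A note on why the logging calls changed: logger.debug("-", node_name) passes node_name as a %-style formatting argument, but "-" contains no placeholder, so the name is never rendered and the logging module typically reports a formatting error instead. A minimal sketch of the distinction, using a hypothetical node name:

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

node_name = "aten.add.Tensor"  # hypothetical node name for illustration
logger.debug("- %s", node_name)   # lazy %-formatting: renders "- aten.add.Tensor"
logger.debug(f"- {node_name}")    # or pre-build the string, as the new code does

The change above goes a step further and aggregates all node names into one string, so the whole overview is emitted as a single debug message.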
