🐛 [Bug] Wrong output shape from cross compiled exported program #3400

Closed
@HolyWu

Bug Description

The first dimension of the output disappears when performing inference with the cross-compiled exported program on Windows.

To Reproduce

Run on Linux

from __future__ import annotations

import os

import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"

class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + x


with torch.inference_mode():
    model = MyModule().eval().cuda()
    inputs = (torch.zeros(2, 4, 6, 8, dtype=torch.float, device="cuda"),)
    exported_program = torch.export.export(model, inputs)

    trt_model = torch_tensorrt.dynamo.cross_compile_for_windows(
        exported_program,
        inputs,
        enabled_precisions={torch.float},
        debug=True,
        min_block_size=1,
    )
    torch_tensorrt.dynamo.save_cross_compiled_exported_program(trt_model, "trt_windows.ep")
TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops
[02/16/2025-22:50:47] [TRT] [W] Functionality provided through tensorrt.plugin module is experimental.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
    return (add,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
    return (add,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
    return (add,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
    return (add,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
    return (add,)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.add.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.add.Tensor
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.add.Tensor + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.add.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.add.Tensor
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.add.Tensor + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(2, 4, 6, 8)]
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
    return add
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.add.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.add.Tensor
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[2, 4, 6, 8], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (2, 4, 6, 8)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /add (kind: aten.add.Tensor, args: ('x <Node>', 'x <Node>'))
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.add.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.add.Tensor
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /add [aten.add.Tensor] (Inputs: (x: (2, 4, 6, 8)@torch.float32, x: (2, 4, 6, 8)@torch.float32) | Outputs: (add: (2, 4, 6, 8)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('add <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(2, 4, 6, 8), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (add: (2, 4, 6, 8)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.006162
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Setting runtime_platform as trt.RuntimePlatform.WINDOWS_AMD64
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.195916
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 16356 bytes of Memory
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, refit_identical_engine_weights=False, strip_engine_weights=False, immutable_weights=True, enable_weight_streaming=False, enable_cross_compile_for_windows=True, use_aot_joint_export=True)

  Graph Structure:

   Inputs: List[Tensor: (2, 4, 6, 8)@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (2, 4, 6, 8)@float32]
     Number of Operators in Engine: 1
     Engine Outputs: List[Tensor: (2, 4, 6, 8)@float32]
    ...
   Outputs: List[Tensor: (2, 4, 6, 8)@float32]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 1.0
   Most Operators in a TRT Engine: 1

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)
DEBUG:torch_tensorrt.dynamo._compiler:successfully saved the module for windows at trt_windows.ep

Run on Windows

from __future__ import annotations

import os

import torch
import torch_tensorrt


with torch.inference_mode():
    model = torch_tensorrt.dynamo.load_cross_compiled_exported_program("trt_windows.ep").module()
    inputs = (torch.randn(2, 4, 6, 8, dtype=torch.float, device="cuda"),)
    output = model(*inputs)
    print(f"{output.shape=}")
TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops
[02/16/2025-22:51:54] [TRT] [W] Functionality provided through tensorrt.plugin module is experimental.
output.shape=torch.Size([4, 6, 8])

Expected behavior

output.shape should be [2, 4, 6, 8].
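
Since `forward` returns `x + x`, the output shape should equal the input shape. A minimal sketch of a check that could be appended to the Windows script to flag this symptom explicitly; the helper name `describe_shape_mismatch` is hypothetical (not part of torch or torch_tensorrt), and it works on plain tuples so it runs without a GPU:

```python
from __future__ import annotations


def describe_shape_mismatch(
    expected: tuple[int, ...], actual: tuple[int, ...]
) -> str | None:
    """Return None when the shapes match, else a short description of the difference."""
    if expected == actual:
        return None
    # Special-case the symptom reported here: the leading dimension is missing
    # while all remaining dimensions are unchanged.
    if len(actual) == len(expected) - 1 and expected[1:] == actual:
        return f"leading dim {expected[0]} dropped: expected {expected}, got {actual}"
    return f"shape mismatch: expected {expected}, got {actual}"


# Shapes taken from this report: input (2, 4, 6, 8), observed output (4, 6, 8).
print(describe_shape_mismatch((2, 4, 6, 8), (4, 6, 8)))
```

In the Windows script this would be called as `describe_shape_mismatch(tuple(inputs[0].shape), tuple(output.shape))`.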

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): 2.6.0+cu126
  • PyTorch Version (e.g. 1.0): 2.6.0+cu126
  • CPU Architecture: x64
  • OS (e.g., Linux): Ubuntu 24.04 and Windows 11
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.12.3
  • CUDA version: 12.6
  • GPU models and configuration: RTX 4060 Ti
  • Any other relevant information:
