
🐛 [Bug] 'aten::unbind' fails to compile when its input shape has two dynamic dimensions #2331

Closed
@xylcbd

Description


Bug Description

'aten::unbind' fails to compile when its input shape has two dynamic dimensions.

To Reproduce

  • code:
# coding: utf-8
import torch
import torch.nn as nn
import torch_tensorrt

class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        # batch_size * 3 * sec_dim * in_dim
        assert len(x.shape) == 4
        x1, x2, x3 = x.unbind(1)
        y = torch.cat([x1, x2, x3], dim=0)
        return y

def convert_to_trt(pytorch_model, dummy_input):
    max_batch_size, _, max_sec_dim, in_dim = dummy_input[0].shape

    with torch_tensorrt.logging.debug():
        with torch.jit.optimized_execution(False):
            _jit_model = torch.jit.trace(pytorch_model, dummy_input)
            
        trt_model = torch_tensorrt.compile(
            _jit_model,
            inputs=[
                # param1
                torch_tensorrt.Input(
                    min_shape=[1, 3, 1, in_dim],
                    opt_shape=[max_batch_size // 2, 3, max_sec_dim // 2, in_dim],
                    max_shape=[max_batch_size, 3, max_sec_dim, in_dim],
                )
            ],
            torch_executed_ops=[
                # 'aten::unbind'
            ],
            torch_executed_modules=[],
            min_block_size=1,
            enabled_precisions={torch.float32},
            truncate_long_and_double=True,
            allow_shape_tensors=True,
            num_avg_timing_iters=1,
            workspace_size=1 << 32,
        )

        return trt_model

def main():
    device = 'cuda'

    # load model
    model = MyModel()
    model = model.to(device)
    model.eval()

    max_batch_size = 16
    max_sec_dim = 16
    in_dim = 64
    dummy_input = [
        torch.rand(max_batch_size, 3, max_sec_dim, in_dim).float().to(device)
    ]
    
    print(model(dummy_input[0]).shape)
    trt_model = convert_to_trt(model, dummy_input)
    print(trt_model(dummy_input[0]))

with torch.no_grad():
    main()
  • log:
WARNING:torch_tensorrt._compile:Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript
DEBUG: [Torch-TensorRT] - TensorRT Compile Spec: {
    "Inputs": [
Input(min_shape=(1,3,1,64,), opt_shape=(8,3,8,64,), max_shape=(16,3,16,64,), dtype=Unknown data type, format=Contiguous/Linear/NCHW, tensor_domain=[0, 2))    ]
    "Enabled Precision": [Float, ]
    "TF32 Disabled": 0
    "Sparsity": 0
    "Refit": 0
    "Debug": 0
    "Device":  {
        "device_type": GPU
        "allow_gpu_fallback": False
        "gpu_id": 0
        "dla_core": -1
    }

    "Engine Capability": Default
    "Num Avg Timing Iters": 1
    "Workspace Size": 4294967296
    "DLA SRAM Size": 1048576
    "DLA Local DRAM Size": 1073741824
    "DLA Global DRAM Size": 536870912
    "Truncate long and double": 1
    "Allow Shape tensors": 1
    "Torch Fallback":  {
        "enabled": True
        "min_block_size": 1
        "forced_fallback_operators": [
        ]
        "forced_fallback_modules": [
        ]
    }
}
DEBUG: [Torch-TensorRT] - init_compile_spec with input vector
DEBUG: [Torch-TensorRT] - Settings requested for Lowering:
    torch_executed_modules: [
    ]
DEBUG: [Torch-TensorRT] - RemoveNOPs - Note: Removing operators that have no meaning in TRT
INFO: [Torch-TensorRT] - Lowered Graph: graph(%x : Tensor):
  %3 : int = prim::Constant[value=1]() # xxx.py:13:0
  %2 : int = prim::Constant[value=0]() # xxx.py:14:0
  %4 : Tensor[] = aten::unbind(%x, %3) # xxx.py:13:0
  %x1 : Tensor, %x2 : Tensor, %x3 : Tensor = prim::ListUnpack(%4)
  %8 : Tensor[] = prim::ListConstruct(%x1, %x2, %x3)
  %9 : Tensor = aten::cat(%8, %2) # xxx.py:14:0
  return (%9)

DEBUG: [Torch-TensorRT] - Unable to get schema for Node %3 : int = prim::Constant[value=1]() # xxx.py:13:0 (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %2 : int = prim::Constant[value=0]() # xxx.py:14:0 (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %x1 : Tensor, %x2 : Tensor, %x3 : Tensor = prim::ListUnpack(%4) (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %8 : Tensor[] = prim::ListConstruct(%x1, %x2, %x3) (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Found 1 inputs to graph
DEBUG: [Torch-TensorRT] - Handle input of debug name: x
DEBUG: [Torch-TensorRT] - Paring 0: x : Input(shape: [-1, 3, -1, 64], min: [1, 3, 1, 64], opt: [8, 3, 8, 64], max: [16, 3, 16, 64], dtype: Float, format: NCHW\Contiguous\Linear)
DEBUG: [Torch-TensorRT] - Found 1 inputs to graph
DEBUG: [Torch-TensorRT] - Handle input of debug name: x
DEBUG: [Torch-TensorRT] - In MapInputsAndDetermineDTypes, the g->inputs() size is 1, CollectionInputSpecMap size is1
WARNING: [Torch-TensorRT] - Cannot infer input type from calcuations in graph for input x. Assuming it is Float32. If not, specify input type explicity
INFO: [Torch-TensorRT] - Skipping partitioning since model is fully supported
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %3 : int = prim::Constant[value=1]() # xxx.py:13:0 (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %2 : int = prim::Constant[value=0]() # xxx.py:14:0 (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %x1 : Tensor, %x2 : Tensor, %x3 : Tensor = prim::ListUnpack(%4) (NodeConverterRegistry.Convertable)
DEBUG: [Torch-TensorRT] - Unable to get schema for Node %8 : Tensor[] = prim::ListConstruct(%x1, %x2, %x3) (NodeConverterRegistry.Convertable)
INFO: [Torch-TensorRT TorchScript Conversion Context] - [MemUsageChange] Init CUDA: CPU +1, GPU +0, now: CPU 226, GPU 515 (MiB)
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Trying to load shared library libnvinfer_builder_resource.so.8.6.1
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Loaded shared library libnvinfer_builder_resource.so.8.6.1
INFO: [Torch-TensorRT TorchScript Conversion Context] - [MemUsageChange] Init builder kernel library: CPU +1444, GPU +266, now: CPU 1747, GPU 781 (MiB)
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - CUDA lazy loading is enabled.
INFO: [Torch-TensorRT] - Settings requested for TensorRT engine:
    Enabled Precisions: Float32 
    TF32 Floating Point Computation Enabled: 1
    Truncate Long and Double: 1
    Make Refittable Engine: 0
    Debuggable Engine: 0
    GPU ID: 0
    Allow GPU Fallback (if running on DLA): 0
    Avg Timing Iterations: 1
    Max Workspace Size: 4294967296
    DLA SRAM Size: 1048576
    DLA Local DRAM Size: 1073741824
    DLA Global DRAM Size: 536870912
    Device Type: GPU
    GPU ID: 0
    Engine Capability: standard
    Calibrator Created: 0
INFO: [Torch-TensorRT TorchScript Conversion Context] - Converting Block
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - graph(%x : Tensor):
  %3 : int = prim::Constant[value=1]() # xxx.py:13:0
  %2 : int = prim::Constant[value=0]() # xxx.py:14:0
  %4 : Tensor[] = aten::unbind(%x, %3) # xxx.py:13:0
  %x1 : Tensor, %x2 : Tensor, %x3 : Tensor = prim::ListUnpack(%4)
  %8 : Tensor[] = prim::ListConstruct(%x1, %x2, %x3)
  %9 : Tensor = aten::cat(%8, %2) # xxx.py:14:0
  return (%9)

DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Input Dimension Specs: {
}
INFO: [Torch-TensorRT TorchScript Conversion Context] - Adding Input x (named: input_0): Input(shape: [-1, 3, -1, 64], min: [1, 3, 1, 64], opt: [8, 3, 8, 64], max: [16, 3, 16, 64], dtype: Float, format: NCHW\Contiguous\Linear) in engine (conversion.AddInputs)
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Evaluating %3 : int = prim::Constant[value=1]() # xxx.py:13:0
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Found the value to be: 1
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Evaluating %2 : int = prim::Constant[value=0]() # xxx.py:14:0
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Found the value to be: 0
INFO: [Torch-TensorRT TorchScript Conversion Context] - Adding Layer %4 : Tensor[] = aten::unbind(%x, %3) # xxx.py:13:0 (ctx.AddLayer)
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Node input is an already converted tensor
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [Torch-TensorRT] - ITensor name: input_0
DEBUG: [Torch-TensorRT] - ITensor shape: [-1, 3, -1, 64]
DEBUG: [Torch-TensorRT] - ITensor type: Float32
DEBUG: [Torch-TensorRT] - Number of split outputs: 3
DEBUG: [Torch-TensorRT] - Weights: [1]
    Data Type: Int32
    Number of input maps: 1
    Number of output maps: 1
    Element shape: [1]
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Freezing tensor 0x69a48158 as an IConstantLayer
DEBUG: [Torch-TensorRT] - Weights: [1]
    Data Type: Int32
    Number of input maps: 1
    Number of output maps: 1
    Element shape: [1]
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Freezing tensor 0x69a493b8 as an IConstantLayer
DEBUG: [Torch-TensorRT] - Weights: [1]
    Data Type: Int32
    Number of input maps: 1
    Number of output maps: 1
    Element shape: [1]
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Freezing tensor 0x69a4a7a8 as an IConstantLayer
DEBUG: [Torch-TensorRT] - Converted split op into a list of IValues
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Evaluating %x1 : Tensor, %x2 : Tensor, %x3 : Tensor = prim::ListUnpack(%4)
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Found the evaluated value(s) to be an ITensor of shape: [-1, 1, 64]
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Found the evaluated value(s) to be an ITensor of shape: [-1, 1, 64]
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Found the evaluated value(s) to be an ITensor of shape: [-1, 1, 64]
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Evaluating %8 : Tensor[] = prim::ListConstruct(%x1, %x2, %x3)
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Found the value to be: [<__torch__.torch.classes._torch_tensorrt_eval_ivalue_types.TensorContainer object at 0x69a4c0b0>, <__torch__.torch.classes._torch_tensorrt_eval_ivalue_types.TensorContainer object at 0x69a4c130>, <__torch__.torch.classes._torch_tensorrt_eval_ivalue_types.TensorContainer object at 0x69a4c1e0>]
INFO: [Torch-TensorRT TorchScript Conversion Context] - Adding Layer %9 : Tensor = aten::cat(%8, %2) # xxx.py:14:0 (ctx.AddLayer)
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Node input is a result of a previously evaluated value
DEBUG: [Torch-TensorRT] - Output tensor shape: [-1, 1, 64]
INFO: [Torch-TensorRT TorchScript Conversion Context] - Marking Output 9 named output_0 in engine (ctx.MarkOutput)
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Original: 10 layers
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - After dead-layer removal: 10 layers
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Graph construction completed in 0.000453138 seconds.
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - After Myelin optimization: 1 layers
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Applying ScaleNodes fusions.
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - After scale fusion: 1 layers
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - After dupe layer removal: 1 layers
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - After final dead-layer removal: 1 layers
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - After tensor merging: 1 layers
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - After vertical fusions: 1 layers
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - After dupe layer removal: 1 layers
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - After final dead-layer removal: 1 layers
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - After tensor merging: 1 layers
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - After slice removal: 1 layers
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - After concat removal: 1 layers
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Trying to split Reshape and strided tensor
INFO: [Torch-TensorRT TorchScript Conversion Context] - Graph optimization time: 0.000430204 seconds.
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Building graph using backend strategy 2
INFO: [Torch-TensorRT TorchScript Conversion Context] - Local timing cache in use. Profiling results in this builder pass will not be stored.
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Constructing optimization profile number 0 [1/1].
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Applying generic optimizations to the graph for inference.
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: kOPT values for profile 0 violate shape constraints: IShuffleLayer (Unnamed Layer* 2) [Shuffle]: reshaping failed for tensor: (Unnamed Layer* 1) [Gather]_output reshape would change volume 4096 to 512
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [shapeCompiler.cpp::evaluateShapeChecks::1276] Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: IShuffleLayer (Unnamed Layer* 2) [Shuffle]: reshaping failed for tensor: (Unnamed Layer* 1) [Gather]_output reshape would change volume 4096 to 512)
torch.Size([48, 16, 64])
Traceback (most recent call last):
  File "xxx.py", line 66, in <module>
    main()
  File "xxx.py", line 62, in main
    trt_model = convert_to_trt(model, dummy_input)
  File "xxx.py", line 24, in convert_to_trt
    trt_model = torch_tensorrt.compile(
  File "/home/bm/miniconda3/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 185, in compile
    compiled_ts_module: torch.jit.ScriptModule = torchscript_compile(
  File "/home/bm/miniconda3/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 151, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/conversion/conversionctx/ConversionCtx.cpp:169] Building serialized network failed in TensorRT 
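
Reading the log, the root cause appears to be in the aten::unbind converter: with the opt profile [8, 3, 8, 64], each slice along dim 1 should have shape [8, 8, 64], i.e. 8 * 8 * 64 = 4096 elements, yet the evaluated outputs are reported as ITensors of shape [-1, 1, 64], whose opt volume is 8 * 1 * 64 = 512. The converter seems to collapse the second dynamic dimension to 1, which is exactly TensorRT's complaint "reshape would change volume 4096 to 512".

A possible workaround, as the commented-out torch_executed_ops entry in the repro hints: run aten::unbind in Torch and let TensorRT take the rest. This is an untested sketch; whether the resulting hybrid graph actually compiles around this bug is an assumption, not something verified here.

# Sketch only: keep the failing op out of the TensorRT engine.
import torch
import torch_tensorrt

def compile_with_unbind_fallback(jit_model, max_batch_size, max_sec_dim, in_dim):
    return torch_tensorrt.compile(
        jit_model,
        ir="torchscript",
        inputs=[
            torch_tensorrt.Input(
                min_shape=[1, 3, 1, in_dim],
                opt_shape=[max_batch_size // 2, 3, max_sec_dim // 2, in_dim],
                max_shape=[max_batch_size, 3, max_sec_dim, in_dim],
            )
        ],
        # Force aten::unbind to execute in PyTorch rather than TensorRT.
        torch_executed_ops=["aten::unbind"],
        min_block_size=1,
        enabled_precisions={torch.float32},
        truncate_long_and_double=True,
    )

With min_block_size=1 the remaining aten::cat should still land in a TensorRT block.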

Expected behavior

Successful compilation.
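
Another angle, if op fallback is undesirable: rewrite the model so unbind never appears in the traced graph. A minimal sketch, assuming plain indexing along dim 1 (which lowers to aten::select) converts cleanly where aten::unbind does not; for this forward it is functionally identical.

# Sketch: replace unbind(1) with three selects along dim 1.
import torch
import torch.nn as nn

class MyModelNoUnbind(nn.Module):
    def forward(self, x):
        # batch_size * 3 * sec_dim * in_dim
        assert len(x.shape) == 4
        # x[:, i] drops dim 1 just like unbind(1), so cat is unchanged.
        x1, x2, x3 = x[:, 0], x[:, 1], x[:, 2]
        return torch.cat([x1, x2, x3], dim=0)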

Environment

Collecting environment information...
PyTorch version: 2.2.0.dev20230919+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.27.5
Libc version: glibc-2.31

Python version: 3.8.16 (default, Jan 17 2023, 23:13:24)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-83-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 4090
GPU 1: NVIDIA GeForce RTX 4090

Nvidia driver version: 535.86.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      43 bits physical, 48 bits virtual
CPU(s):                             48
On-line CPU(s) list:                0-47
Thread(s) per core:                 2
Core(s) per socket:                 24
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          AuthenticAMD
CPU family:                         23
Model:                              49
Model name:                         AMD Ryzen Threadripper 3960X 24-Core Processor
Stepping:                           0
Frequency boost:                    enabled
CPU MHz:                            2200.000
CPU max MHz:                        3800.0000
CPU min MHz:                        2200.0000
BogoMIPS:                           7585.95
Virtualization:                     AMD-V
L1d cache:                          768 KiB
L1i cache:                          768 KiB
L2 cache:                           12 MiB
L3 cache:                           128 MiB
NUMA node0 CPU(s):                  0-47
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es

Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.8
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] pytorch-lightning==1.8.6
[pip3] pytorch-triton==2.1.0+6e4932cda8
[pip3] torch==2.2.0.dev20230919+cu121
[pip3] torch-tensorrt==2.2.0.dev20230919+cu121
[pip3] torchaudio==2.2.0.dev20230919+cu121
[pip3] torchmetrics==1.1.2
[pip3] torchvision==0.17.0.dev20230919+cu121
[pip3] triton==2.0.0
[conda] msgpack-numpy             0.4.8                    pypi_0    pypi
[conda] numpy                     1.23.5                   pypi_0    pypi
[conda] pytorch-lightning         1.8.6                    pypi_0    pypi
[conda] pytorch-triton            2.1.0+6e4932cda8          pypi_0    pypi
[conda] torch                     2.2.0.dev20230919+cu121          pypi_0    pypi
[conda] torch-tensorrt            2.2.0.dev20230919+cu121          pypi_0    pypi
[conda] torchaudio                2.2.0.dev20230919+cu121          pypi_0    pypi
[conda] torchmetrics              1.1.2                    pypi_0    pypi
[conda] torchvision               0.17.0.dev20230919+cu121          pypi_0    pypi
[conda] triton                    2.0.0                    pypi_0    pypi

Labels

bug (Something isn't working)