🐛 [Bug] Support for modules with multiple outputs seems broken in v1.2.0 #1368

Closed
@cpjenkins

Bug Description

It appears that modules with multiple outputs no longer compile when using dynamic input shapes in v1.2.0.

The following example works in v1.1.1 but fails in v1.2.0:

import torch
import torch.nn as nn
import torch_tensorrt as trt

from torch import Tensor
from typing import List, Tuple

trt.logging.set_reportable_log_level(trt.logging.Level.Debug)

class Net(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)

        self.h = nn.Conv2d(1, 4, 3, padding=1)
        self.g = nn.Conv2d(1, 4, 3, padding=1)

    def forward(self, x) -> Tuple[Tensor, Tensor]:
        return self.h(x), self.g(x)

model = Net().eval()
model = torch.jit.trace(model, torch.randn(1, 1, 128, 128))
model = trt.compile(
    model.cuda(),
    inputs=[
        trt.Input(min_shape=(1, 1, 128, 128),
                  opt_shape=(4, 1, 256, 256),
                  max_shape=(8, 1, 512, 512))
    ],
    min_block_size=1,
    require_full_compilation=True
)

Fails with error:

RuntimeError: [Error thrown at core/conversion/conversion.cpp:230] Tuple type. Only a single tensor or a TensorList type is supported.

In v1.1.1, the graph returns two output tensors, while in v1.2.0 it creates an intermediate node (%13) and returns a single prim::TupleConstruct output. As a result, MarkOutputs in core/conversion/conversion.cpp now receives a single tuple output and throws the error above.

Graphs are given below:

v1.1.1

  %11 : Tensor = aten::_convolution(%x, %self.h.weight, %self.h.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %12 : Tensor = aten::_convolution(%x, %self.g.weight, %self.g.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.g # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  return (%11, %12)

v1.2.0

  %11 : Tensor = aten::_convolution(%x, %self.h.weight, %self.h.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %12 : Tensor = aten::_convolution(%x, %self.g.weight, %self.g.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.g # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %13 : (Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu), Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu)) = prim::TupleConstruct(%11, %12)
  return (%13)

Expected behavior

A return type of Tuple[Tensor, Tensor] should be treated as two separate outputs, not as a single tuple output.
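
To make the expected behavior concrete, here is a minimal sketch in plain Python of what "two separate outputs" would mean. It is a toy model only, not Torch-TensorRT's implementation: `flatten_outputs` is a hypothetical helper, and the strings `"%11"` and `"%12"` merely stand in for the graph values shown in the graphs above.

```python
from typing import Any, List


def flatten_outputs(value: Any) -> List[Any]:
    """Recursively unpack nested tuples/lists into a flat list of leaves.

    Hypothetical helper for illustration; it is not part of Torch-TensorRT.
    """
    if isinstance(value, (tuple, list)):
        flat: List[Any] = []
        for item in value:
            flat.extend(flatten_outputs(item))
        return flat
    # Anything that is not a tuple or list is treated as a single output.
    return [value]


# v1.2.0-style graph outputs: one entry, which is itself a tuple (%13).
single_tuple_output = [("%11", "%12")]

# After flattening, each tensor is its own output, as in the v1.1.1 graph.
separate_outputs = flatten_outputs(single_tuple_output)
print(separate_outputs)
```

Under this model, the single tuple output from v1.2.0 would be marked as two outputs, matching what v1.1.1 produced.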

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version: v1.1.1 and v1.2.0
  • PyTorch Version (e.g. 1.0): 1.12.1
  • CPU Architecture: AMD x86_64
  • OS (e.g., Linux): Linux / Ubuntu 22.04
  • How you installed PyTorch: source (w/ c++11 abi)
  • Build command you used (if compiling from source): n/a
  • Are you using local sources or building from archives: local source
  • Python version: 3.8
  • CUDA version: 11.7
  • GPU models and configuration: RTX A6000
  • Any other relevant information:

Additional context
