
🐛 [Bug] RuntimeError: linear convolution has bias of type <class 'tensorrt.tensorrt.ITensor'>, Expect Optional[Tensor] when using torch_tensorrt as backend in torch.compile #2506

Closed
@airalcorn2


Bug Description

When I compile a simple PyTorch module with torch.compile(model, backend="torch_tensorrt") and run it, I get:

RuntimeError: linear convolution has bias of type <class 'tensorrt.tensorrt.ITensor'>, Expect Optional[Tensor]

I don't get any errors when using torch.compile(model) with the default backend or torch_tensorrt.compile(model, inputs=inputs).

To Reproduce

Steps to reproduce the behavior:

import torch
import torch_tensorrt

from torch import nn


class PointNetLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.linear = nn.Conv2d(in_feats, out_feats, 1)
        self.norm = nn.BatchNorm2d(out_feats)
        self.relu = nn.ReLU()

    def forward(self, points):
        pn_feats = self.relu(self.norm(self.linear(points)))
        return pn_feats


def main():
    device = torch.device("cuda:0")
    model = PointNetLayer(3, 64).to(device)
    model.eval()
    print(model)

    points = torch.rand((1, 3, 12000, 200), device=device)
    # Works.
    with torch.no_grad():
        _ = model(points)

    model_opt = torch.compile(model)
    # Works.
    with torch.no_grad():
        _ = model_opt(points)

    torch._dynamo.reset()
    model_opt = torch.compile(model, backend="torch_tensorrt")
    # RuntimeError.
    with torch.no_grad():
        _ = model_opt(points)

    torch._dynamo.reset()
    inputs = [torch_tensorrt.Input(points.shape)]
    # Works.
    trt_ts_module = torch_tensorrt.compile(model, inputs=inputs)


if __name__ == "__main__":
    main()
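To narrow down which part of PointNetLayer triggers the error, one option is to run the same compile-and-call check on progressively larger slices of the layer. This is a debugging sketch under stated assumptions, not part of the original report: build_variants, try_variants, and compile_with_trt are hypothetical helper names, and the conv_no_bias variant is included only to test whether the failure is tied to the convolution bias mentioned in the error message.

```python
import torch
from torch import nn


def build_variants(in_feats=3, out_feats=64):
    """Hypothetical helper: slices of PointNetLayer, from a bare conv to the full layer."""
    return {
        "conv": nn.Conv2d(in_feats, out_feats, 1),
        "conv_no_bias": nn.Conv2d(in_feats, out_feats, 1, bias=False),
        "conv_bn": nn.Sequential(
            nn.Conv2d(in_feats, out_feats, 1), nn.BatchNorm2d(out_feats)
        ),
        "conv_bn_relu": nn.Sequential(
            nn.Conv2d(in_feats, out_feats, 1), nn.BatchNorm2d(out_feats), nn.ReLU()
        ),
    }


def try_variants(compile_fn, points):
    """Compile and run each variant in eval mode; map name -> "ok" or the error text."""
    results = {}
    for name, module in build_variants(in_feats=points.shape[1]).items():
        module = module.to(points.device).eval()
        try:
            compiled = compile_fn(module)
            with torch.no_grad():
                # torch.compile compiles lazily, so errors usually surface on this call
                compiled(points)
            results[name] = "ok"
        except Exception as exc:  # keep going so every variant gets a result
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


def compile_with_trt(module):
    """compile_fn for the failing path in the report (needs CUDA and torch_tensorrt)."""
    import torch._dynamo
    import torch_tensorrt  # noqa: F401  imported before using the backend, as in the repro

    torch._dynamo.reset()  # clear cached graphs between variants, as the repro does
    return torch.compile(module, backend="torch_tensorrt")
```

With the CUDA tensor from main(), try_variants(compile_with_trt, points) returns a dict mapping each variant to "ok" or the exception text; passing lambda m: m instead runs the same variants eagerly as a baseline.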

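Expected behavior

The model compiles and runs with torch.compile(model, backend="torch_tensorrt"), as it does with the default torch.compile backend and with torch_tensorrt.compile.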

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 1.4.0
  • PyTorch Version (e.g. 1.0): 2.0.1+cu117
  • CPU Architecture: i7-12800H
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.10.10
  • CUDA version: 12.2
  • GPU models and configuration: GeForce RTX 3080 Ti
  • Any other relevant information:
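The version fields above can also be collected programmatically. This is a minimal sketch assuming only that torch is installed; collect_env is a hypothetical name, and torch_tensorrt and tensorrt are reported as None when they cannot be imported. Note that torch.version.cuda is the CUDA version PyTorch was built against (11.7 for a +cu117 wheel), which is not necessarily the same as the system CUDA version listed above. PyTorch also ships python -m torch.utils.collect_env, which prints a more complete report.

```python
import platform

import torch


def collect_env():
    """Gather the version fields requested in the Environment template."""
    info = {
        "python": platform.python_version(),
        "torch": torch.__version__,
        # CUDA version this PyTorch build was compiled against, not the system toolkit
        "torch_cuda": torch.version.cuda,
        "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
    }
    try:
        import torch_tensorrt

        info["torch_tensorrt"] = torch_tensorrt.__version__
    except (ImportError, OSError):  # not installed, or its native libraries failed to load
        info["torch_tensorrt"] = None
    try:
        import tensorrt

        info["tensorrt"] = tensorrt.__version__
    except (ImportError, OSError):
        info["tensorrt"] = None
    return info
```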

Additional context

Labels

bug (Something isn't working), component: converters (Issues re: Specific op converters)
