
🐛 [Bug] Encountered bug when using Torch-TensorRT #2328

Closed
@balazon


Bug Description

This might not be a bug; it may be a feature request, I'm not sure.
I wanted to compile `torch.einsum` with `torch_tensorrt`, and I get an error.

I was reading this tutorial about compiling transformers:
https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/torch_compile_transformers_example.html
Based on it, I created a small example module containing `torch.einsum`, and I get this error:

```
torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt' raised:
RuntimeError: Autograd has not been implemented for operator

While executing %einsum_1 : [num_users=1] = call_function[target=torch.ops.tensorrt.einsum](args = (i,ij->i, (%l_x_, %l__self___b)), kwargs = {})
```

I think `torch.einsum` is either not supported yet or, if it is, buggy.
It's not listed under the supported ops here:
https://github.com/pytorch/TensorRT/blob/8ebb5991f8bc46fea6179593b882d5c160bc1a53/docs/_sources/indices/supported_ops.rst.txt

According to this page, TensorRT supports it (IEinsumLayer):
https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-861/operators/index.html
So I don't see why it wouldn't be supported in Torch-TensorRT.
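In the meantime, one possible workaround (I have not verified it against Torch-TensorRT) is to express this particular contraction without einsum: in `"i,ij->i"` the index `j` is absent from the output, so it is summed out, and the shared index `i` makes it an elementwise product with the row sums of `B`. The CPU-only sketch below checks that rewrite against `torch.einsum` using the same values as the reproduction script; the helper name `einsum_i_ij_to_i` is hypothetical.

```python
import torch

# Same values as MyModule.B and the input `a` in the reproduction script (on CPU here).
B = torch.arange(12, dtype=torch.float32).reshape(3, 4)
x = torch.tensor([0, 1, 2], dtype=torch.float32)


def einsum_i_ij_to_i(x, B):
    """Hypothetical einsum-free equivalent of torch.einsum("i,ij->i", x, B).

    Index j does not appear in the output, so B is summed over dim 1;
    index i is shared, so x scales those row sums elementwise.
    """
    return x * B.sum(dim=1)


reference = torch.einsum("i,ij->i", x, B)
rewritten = einsum_i_ij_to_i(x, B)
assert torch.allclose(reference, rewritten)  # both are [0., 22., 76.]
```

Whether the resulting `mul` and `sum` ops are then converted by the `torch_tensorrt` backend is an assumption worth checking with `debug=True`.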

I see some issues/PRs related to einsum, but I don't know whether they cover this case. The closest issue I found is
#277
but it was closed due to inactivity.
Other issues/PRs:
#1385
#1985
#1420
#1005

To Reproduce

Steps to reproduce the behavior:

  1. Run this script as-is (compile option 1):
```python
import torch
import torch_tensorrt  # registers the "torch_tensorrt" backend for torch.compile


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Keep B a registered nn.Parameter. Calling .cuda() on the Parameter itself
        # returns a plain tensor, so B would not be registered with the module.
        # The whole module is moved to the GPU in compile_my_model() instead.
        self.B = torch.nn.Parameter(torch.tensor([
            [0, 1, 2, 3],
            [4, 5, 6, 7],
            [8, 9, 10, 11]
        ], dtype=torch.float32))

    def forward(self, x):
        # "i,ij->i": out[i] = sum over j of x[i] * B[i, j]
        return torch.einsum("i,ij->i", x, self.B)


def compile_my_model():
    model = MyModule().eval().cuda()

    a = torch.tensor([0, 1, 2], dtype=torch.float32, device="cuda")
    inputs = [a]

    # Enabled precision for TensorRT optimization
    enabled_precisions = {torch.float}

    # Whether to print verbose logs
    debug = True

    # Workspace size for TensorRT (8 GiB)
    workspace_size = 8 << 30

    # Minimum number of operators per TensorRT engine block
    # (lower value allows more graph segmentation)
    min_block_size = 7

    # Operations to run in Torch, regardless of converter support
    torch_executed_ops = set()

    # Define backend compilation keyword arguments
    compilation_kwargs = {
        "enabled_precisions": enabled_precisions,
        "debug": debug,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
        "torch_executed_ops": torch_executed_ops,
    }

    # Compile option 1: torch.compile with the torch_tensorrt backend (fails)
    optimized_model = torch.compile(
        model,
        backend="torch_tensorrt",
        options=compilation_kwargs,
    )

    # Compile option 2: is it the same as option 1? It also fails
    # optimized_model = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs, **compilation_kwargs)

    # Compile option 3: works
    # optimized_model = torch_tensorrt.compile(model, inputs=inputs, **compilation_kwargs)

    res = optimized_model(*inputs)
    print("res:", res)

    torch._dynamo.reset()


if __name__ == "__main__":
    compile_my_model()
    print("done")
```

  2. Run it with compile option 2 instead (uncomment option 2 and comment out option 1).
  3. Run it with compile option 3 instead.

Only compile option 3 works, and I don't understand the difference between these three options. Can somebody clarify? I think options 1 and 2 are the same, but what does option 3 do differently?

Expected behavior

I expect all three options to work, but only compile option 3 does.
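To check all three options against eager PyTorch in one run, a small harness can compare each compiled module with the eager model and record failures instead of stopping at the first one. This is a sketch: `compare_to_eager` and `EinsumModule` are hypothetical names, and the stand-in candidates below replace the three compiled models so it runs on CPU without TensorRT. Note that `torch.compile` (option 1) compiles lazily, so backend errors only surface on the first call, which is where the harness catches them.

```python
import torch


def compare_to_eager(model, candidates, inputs, atol=1e-4):
    """Run each candidate and report whether it reproduces the eager output.

    `candidates` maps a label (e.g. "option 1") to a callable. Exceptions raised
    while running a candidate are recorded as strings instead of propagating,
    so one failing option does not hide the results of the others.
    """
    expected = model(*inputs)
    report = {}
    for label, candidate in candidates.items():
        try:
            actual = candidate(*inputs)
            report[label] = bool(torch.allclose(actual, expected, atol=atol))
        except Exception as exc:
            report[label] = f"{type(exc).__name__}: {exc}"
    return report


class EinsumModule(torch.nn.Module):
    """Hypothetical CPU copy of MyModule from the reproduction script."""

    def __init__(self):
        super().__init__()
        self.B = torch.nn.Parameter(torch.arange(12, dtype=torch.float32).reshape(3, 4))

    def forward(self, x):
        return torch.einsum("i,ij->i", x, self.B)


model = EinsumModule().eval()
inputs = [torch.tensor([0, 1, 2], dtype=torch.float32)]

# Stand-ins for compiled modules, so this runs without a GPU or TensorRT.
report = compare_to_eager(
    model,
    {
        "eager (should match)": model,
        "wrong output": lambda x: x,
        "raises": lambda x: 1 / 0,
    },
    inputs,
)
print(report)
```

In the real script, the candidates would be the modules produced by the three compile calls.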

Environment

I'm using this docker image:
nvcr.io/nvidia/pytorch:23.08-py3
https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-23-08.html

  • Torch-TensorRT Version: 2.0.0.dev0
  • PyTorch Version: 2.1.0a0+29c30b1
  • Python version: 3.10
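For anyone reproducing this, the versions above can be collected with a short standard-library script that reads package metadata without importing anything (so no GPU is needed). This is a sketch; the distribution names `torch-tensorrt` and `tensorrt` are my assumption about how the packages are registered in the container.

```python
import importlib.metadata
import platform


def collect_versions(distributions=("torch", "torch-tensorrt", "tensorrt")):
    """Return the Python version plus the installed version of each distribution.

    The distribution names are assumptions; missing ones are reported as
    "not installed" rather than raising.
    """
    info = {"python": platform.python_version()}
    for name in distributions:
        try:
            info[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


for key, value in collect_versions().items():
    print(f"{key}: {value}")
```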

Labels: bug (Something isn't working), component: dynamo (Issues relating to the `torch.compile` or `torch._dynamo.export` paths)
