Bug Description
This might not be a bug; maybe it's a feature request, I'm not sure.
I wanted to compile a module containing torch.einsum with torch_tensorrt, but I get back an error.
I was reading this tutorial about compiling transformers:
https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/torch_compile_transformers_example.html
Based on this, I created a small example module containing einsum, and I get this error:
torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt' raised:
RuntimeError: Autograd has not been implemented for operator
While executing %einsum_1 : [num_users=1] = call_function[target=torch.ops.tensorrt.einsum](args = (i,ij->i, (%l_x_, %l__self___b)), kwargs = {})
I think torch.einsum is either not a supported op yet or, if it is, it's buggy.
It's not listed under supported ops here:
https://github.com/pytorch/TensorRT/blob/8ebb5991f8bc46fea6179593b882d5c160bc1a53/docs/_sources/indices/supported_ops.rst.txt
TensorRT itself supports it (IEinsumLayer), according to this:
https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-861/operators/index.html
So I don't see why it wouldn't be supported in torch-tensorrt.
I see some issues/PRs that relate to einsum, but I don't know if they cover this case. The closest issue I found is
#277
but it's closed due to inactivity.
Other issues/PRs:
- #1385
- #1985
- #1420
- #1005
To Reproduce
Steps to reproduce the behavior:
- Run this script (compile option 1 is active):
import torch
import torch_tensorrt


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.B = torch.nn.Parameter(torch.tensor([
            [0, 1, 2, 3],
            [4, 5, 6, 7],
            [8, 9, 10, 11]
        ], dtype=torch.float32, device="cuda"))

    def forward(self, x):
        return torch.einsum("i,ij->i", x, self.B)


def compile_my_model():
    model = MyModule().eval()
    a = torch.tensor([0, 1, 2], dtype=torch.float32).cuda()
    inputs = [a]

    # Enabled precision for TensorRT optimization
    enabled_precisions = {torch.float}

    # Whether to print verbose logs
    debug = True

    # Workspace size for TensorRT
    workspace_size = 8 << 30

    # Maximum number of TRT Engines
    # (Lower value allows more graph segmentation)
    min_block_size = 7

    # Operations to Run in Torch, regardless of converter support
    torch_executed_ops = {}

    # Define backend compilation keyword arguments
    compilation_kwargs = {
        "enabled_precisions": enabled_precisions,
        "debug": debug,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
        "torch_executed_ops": torch_executed_ops,
    }

    # compile option 1: torch.compile with the torch_tensorrt backend (fails with the einsum error)
    optimized_model = torch.compile(
        model,
        backend="torch_tensorrt",
        options=compilation_kwargs,
    )

    # compile option 2: is it the same as option 1? still fails
    # optimized_model = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs, **compilation_kwargs)

    # compile option 3: success
    # optimized_model = torch_tensorrt.compile(model, inputs=inputs, **compilation_kwargs)

    res = optimized_model(*inputs)
    print("res:", res)

    torch._dynamo.reset()


if __name__ == "__main__":
    compile_my_model()
    print("done")
- Run it again with compile option 2 uncommented (and option 1 commented out)
- Run it again with compile option 3 uncommented
Only compile option 3 works, but I don't know what the difference between these 3 options is; can somebody clear that up? I think options 1 and 2 are the same, but what does option 3 do differently?
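One idea I have not verified: excluding einsum from conversion via the torch_executed_ops setting that is already in the script, so options 1 and 2 fall back to eager PyTorch for that op. The "torch.ops.aten.einsum.default" spelling below is my assumption about how the setting expects op names, and I'm not sure it helps if the failure happens in a lowering pass before partitioning:

# Untested sketch: reuses the model and compilation_kwargs from the script above,
# but asks the backend to keep einsum in Torch instead of converting it to TRT.
# Assumes torch_executed_ops accepts fully-qualified aten op names as strings.
compilation_kwargs["torch_executed_ops"] = {"torch.ops.aten.einsum.default"}

optimized_model = torch.compile(
    model,
    backend="torch_tensorrt",
    options=compilation_kwargs,
)
res = optimized_model(*inputs)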
Expected behavior
I expect all 3 options to work, but only the 3rd compile option seems to work.
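In the meantime, a workaround on my side could be to rewrite the "i,ij->i" contraction with plain broadcasting ops so no einsum converter is needed at all. This is just my own equivalent reformulation; I haven't checked whether TensorRT fuses it as efficiently as a native IEinsumLayer would:

import torch


class MyModuleNoEinsum(torch.nn.Module):
    """Same computation as MyModule above, but spelled without torch.einsum."""

    def __init__(self):
        super().__init__()
        self.B = torch.nn.Parameter(torch.tensor([
            [0, 1, 2, 3],
            [4, 5, 6, 7],
            [8, 9, 10, 11]
        ], dtype=torch.float32, device="cuda"))

    def forward(self, x):
        # "i,ij->i": broadcast-multiply x against each row of B, then sum over j.
        # (For this particular pattern, x * self.B.sum(dim=1) is also equivalent.)
        return (x.unsqueeze(1) * self.B).sum(dim=1)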
Environment
I'm using this Docker image:
nvcr.io/nvidia/pytorch:23.08-py3
https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-23-08.html
- Torch-TensorRT Version: 2.0.0.dev0
- PyTorch Version: 2.1.0a0+29c30b1
- Python version: 3.10