🐛 [Bug] Const indices failed with embedding bag #3263

Open
sean-xiang-applovin opened this issue Oct 24, 2024 · 1 comment
Labels: bug (Something isn't working)


@sean-xiang-applovin

Bug Description

The indices are a constant tensor, which gets constant-folded into a frozen parameter. The meta of the frozen-parameter node is an empty dict, so the converter validation check fails (at the location linked in the original issue). As a result, torch.ops.aten._embedding_bag.default is treated as an unsupported op and compilation fails.
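To illustrate the failure mode, here is a minimal pure-Python sketch of a validator that rejects an op when the indices node has no metadata. It is not the Torch-TensorRT source: `FakeNode`, `FakeTensorMeta`, and `toy_indices_validator` are hypothetical stand-ins, and the `"tensor_meta"` key is an assumption about which metadata entry the real check reads.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple


@dataclass
class FakeTensorMeta:
    """Hypothetical stand-in for the shape metadata attached to a graph node."""
    shape: Tuple[int, ...]


@dataclass
class FakeNode:
    """Hypothetical stand-in for a graph node; only `meta` matters here."""
    name: str
    meta: Dict[str, Any] = field(default_factory=dict)


def toy_indices_validator(indices_node: FakeNode) -> bool:
    """Accept the op only if the indices node describes a 1D tensor."""
    tensor_meta = indices_node.meta.get("tensor_meta")  # assumed key
    if tensor_meta is None:
        # Empty meta, as on the constant-folded frozen parameter: there is
        # nothing to validate against, so the op is reported as unsupported.
        return False
    return len(tensor_meta.shape) == 1


traced = FakeNode("indices", meta={"tensor_meta": FakeTensorMeta(shape=(100,))})
folded = FakeNode("frozen_param", meta={})

print(toy_indices_validator(traced))  # True
print(toy_indices_validator(folded))  # False
```

In this toy version the folded node is rejected purely because its `meta` is empty, regardless of what the underlying tensor looks like.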

To Reproduce

import torch
import torch_tensorrt
import tensorrt
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
print(torch.__version__) # 2.5.0+cu124
print(torch_tensorrt.__version__) # 2.5.0
print(tensorrt.__version__) # 10.3.0

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.embedding_bag_module = torch.nn.EmbeddingBag(100, 32, mode='sum')
        self.register_buffer("index_tensor", torch.tensor([x for x in range(100)], dtype=torch.long))

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        out = self.embedding_bag_module(self.index_tensor.broadcast_to(tensor.shape), per_sample_weights=tensor)
        return out

error_model_input = (torch.randn(20, 100, dtype=torch.float32), )
error_model = ToyModel()
error_model_eval = error_model.eval()
with torch.no_grad():
    ep = torch.export.export(error_model_eval, args=error_model_input)
    compiled = dynamo_compile(
        exported_program=ep,
        disable_tf32=True,
        inputs=error_model_input,
        min_block_size=1,
        debug=True,
    )
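Note that `self.index_tensor.broadcast_to(tensor.shape)` produces 2D indices of shape (20, 100). `torch.nn.EmbeddingBag` also accepts a 1D `input` together with `offsets` marking where each bag starts. The sketch below shows that conversion in plain Python; `flatten_bags` is a hypothetical helper name, and whether a 1D formulation changes the outcome for this repro has not been tested here.

```python
from typing import List, Sequence, Tuple


def flatten_bags(indices_2d: Sequence[Sequence[int]]) -> Tuple[List[int], List[int]]:
    """Convert (B, N) bag indices into a flat index list plus per-bag start offsets."""
    flat: List[int] = []
    offsets: List[int] = []
    for bag in indices_2d:
        offsets.append(len(flat))  # this bag starts where the previous one ended
        flat.extend(bag)
    return flat, offsets


flat, offsets = flatten_bags([[0, 1, 2], [0, 1, 2]])
print(flat)     # [0, 1, 2, 0, 1, 2]
print(offsets)  # [0, 3]
```

With real tensors, `per_sample_weights` would need to be flattened the same way, since it must match the shape of the indices.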

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.5
  • PyTorch Version (e.g. 1.0): 2.5
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Ubuntu
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.11.9
  • CUDA version: 12.6
  • GPU models and configuration: Nvidia L4
  • Any other relevant information:

Additional context

@zewenli98
Collaborator

Due to some limitations, we need the node's meta to handle data-dependent issues. In addition, we currently support only 1D indices/input. If this doesn't work for you, I think you will have to fall this op back to PyTorch for now.
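One way to try the suggested fallback is the `torch_executed_ops` argument of the dynamo compile entry point, which lists ops to leave in PyTorch. The sketch below is untested against this repro: `compile_with_fallback` is a hypothetical wrapper, and passing the op as a string target is an assumption about what the installed version accepts.

```python
# Op named in this issue. Passing it as a string target is an assumption
# about what `torch_executed_ops` accepts in the installed version.
FALLBACK_OPS = {"torch.ops.aten._embedding_bag.default"}


def compile_with_fallback(exported_program, inputs):
    """Compile with Torch-TensorRT while leaving the embedding-bag op in PyTorch."""
    # Imported lazily so this sketch can be loaded without a GPU environment.
    from torch_tensorrt.dynamo._compiler import compile as dynamo_compile

    return dynamo_compile(
        exported_program=exported_program,
        inputs=inputs,
        min_block_size=1,
        torch_executed_ops=FALLBACK_OPS,  # run these ops in PyTorch
    )
```

With the repro above, this would be called as `compile_with_fallback(ep, error_model_input)` inside the `torch.no_grad()` block.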
