
🐛 [Bug] Const indices failed with embedding bag #3263

Open
@sean-xiang-applovin

Description


Bug Description

The indices are a constant tensor, which gets constant-folded into a frozen parameter. The meta of the resulting frozen param node is an empty dict, so the converter's validation check for the op fails.

As a result, torch.ops.aten._embedding_bag.default is reported as an unsupported op, and compilation fails.
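To illustrate why an empty meta dict is enough to reject the op, here is a simplified, hypothetical model of a metadata-dependent capability validator. FakeNode, FakeTensorMeta, embedding_bag_validator and restore_constant_meta are illustrative names, not Torch-TensorRT's actual code; the sketch only assumes, as the description above states, that the validator reads shape/dtype information from the indices node's meta. The last part shows one possible fix direction (also an assumption): recording the folded constant's shape and dtype on the frozen-param node.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple


# Hypothetical, simplified stand-in for a torch.fx.Node (name, args, meta only).
@dataclass
class FakeNode:
    name: str
    args: Tuple[Any, ...] = ()
    meta: Dict[str, Any] = field(default_factory=dict)


# Hypothetical stand-in for the shape/dtype info a validator reads from meta.
@dataclass
class FakeTensorMeta:
    shape: Tuple[int, ...]
    dtype: str


def embedding_bag_validator(node: FakeNode) -> bool:
    """Hypothetical validator: inspects the indices argument's metadata."""
    indices = node.args[1]  # aten._embedding_bag args: (weight, indices, ...)
    info: Optional[FakeTensorMeta] = indices.meta.get("val")
    if info is None:
        # Empty meta (the frozen-param case): nothing to check, so reject.
        return False
    # Example constraint: 1-D or 2-D integer indices.
    return len(info.shape) in (1, 2) and info.dtype in ("int32", "int64")


def restore_constant_meta(node: FakeNode, shape: Tuple[int, ...], dtype: str) -> None:
    """Hypothetical fix direction: give a folded constant its metadata back."""
    if not node.meta:
        node.meta["val"] = FakeTensorMeta(shape, dtype)


weight = FakeNode("weight", meta={"val": FakeTensorMeta((100, 32), "float32")})

# Indices that keep their metadata: the validator accepts the op.
indices_input = FakeNode("indices", meta={"val": FakeTensorMeta((20, 100), "int64")})
bag_ok = FakeNode("embedding_bag", args=(weight, indices_input))
print(embedding_bag_validator(bag_ok))  # True

# Indices constant-folded into a frozen param with meta == {}: rejected.
indices_frozen = FakeNode("_frozen_param0")
bag_bad = FakeNode("embedding_bag", args=(weight, indices_frozen))
print(embedding_bag_validator(bag_bad))  # False

# After restoring the metadata, the same node passes the check.
restore_constant_meta(indices_frozen, (20, 100), "int64")
print(embedding_bag_validator(bag_bad))  # True
```

The point of the model is that the validator cannot distinguish "unsupported indices" from "indices with no metadata", so a constant-folding pass that drops meta turns a supported op into an unsupported one.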

To Reproduce

```python
import torch
import torch_tensorrt
import tensorrt
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
print(torch.__version__) # 2.5.0+cu124
print(torch_tensorrt.__version__) # 2.5.0
print(tensorrt.__version__) # 10.3.0

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.embedding_bag_module = torch.nn.EmbeddingBag(100, 32, mode='sum')
        # Constant indices 0..99 stored as a buffer; these are the indices that
        # end up constant-folded into a frozen param with an empty meta dict.
        self.register_buffer("index_tensor", torch.tensor([x for x in range(100)], dtype=torch.long))

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        # Broadcast the (100,) indices to the input shape (20, 100) so each row
        # is one bag, and use the input values as per-sample weights.
        out = self.embedding_bag_module(self.index_tensor.broadcast_to(tensor.shape), per_sample_weights=tensor)
        return out

error_model_input = (torch.randn(20, 100, dtype=torch.float32), )
error_model = ToyModel()
error_model_eval = error_model.eval()
with torch.no_grad():
    ep = torch.export.export(error_model_eval, args=error_model_input)
    # Fails: aten._embedding_bag.default is reported as unsupported.
    compiled = dynamo_compile(
        exported_program=ep,
        disable_tf32=True,
        inputs=error_model_input,
        min_block_size=1,
        debug=True,
    )
```

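Expected behavior

The model compiles, with torch.ops.aten._embedding_bag.default converted by Torch-TensorRT instead of being rejected as unsupported because of the constant-folded indices.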

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.5
  • PyTorch Version (e.g. 1.0): 2.5
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Ubuntu
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.11.9
  • CUDA version: 12.6
  • GPU models and configuration: Nvidia L4
  • Any other relevant information:

Additional context

Metadata


Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
