Bug Description
The `indices` argument is a constant tensor, which gets constant-folded into a frozen parameter. The `meta` of the resulting frozen-parameter node is an empty dict, which makes the converter validation check fail here. As a result, `torch.ops.aten._embedding_bag.default` is treated as an unsupported op and compilation fails.
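The failure mode can be illustrated with plain `torch.fx`, independent of Torch-TensorRT. This is a minimal sketch under stated assumptions: `torch.index_select` stands in for `aten._embedding_bag`, `indices_have_meta` is a hypothetical validator (not Torch-TensorRT's actual check), `_frozen_param0` is an illustrative name, and the constant-folding step is simulated by hand. A traced and shape-propagated node carries `tensor_meta`, while a freshly inserted `get_attr` node starts with an empty `meta` dict, so a metadata-based check rejects it.

```python
import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp


class IndexWithConstant(torch.nn.Module):
    """Stand-in for the repro: an op whose indices come from a constant buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("idx", torch.arange(4))

    def forward(self, x):
        return torch.index_select(x, 0, self.idx)


def indices_have_meta(node: torch.fx.Node) -> bool:
    """Hypothetical converter validator: require shape metadata on the indices input."""
    indices = node.args[2]
    return "tensor_meta" in indices.meta or "val" in indices.meta


gm = torch.fx.symbolic_trace(IndexWithConstant())
ShapeProp(gm).propagate(torch.randn(8, 3))  # fills node.meta["tensor_meta"] on every node
op = next(n for n in gm.graph.nodes if n.target is torch.index_select)
print("before folding:", indices_have_meta(op))  # True

# Simulated constant folding: replace the indices input with a new
# frozen-parameter get_attr node. Newly created nodes get no metadata.
gm.register_buffer("_frozen_param0", gm.idx.clone())
with gm.graph.inserting_before(op):
    frozen = gm.graph.get_attr("_frozen_param0")
op.replace_input_with(op.args[2], frozen)
gm.graph.eliminate_dead_code()
gm.recompile()

print("frozen param meta keys:", list(frozen.meta))
print("after folding:", indices_have_meta(op))  # False
```

The folded graph still runs correctly; only the metadata the validator relies on is missing.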
To Reproduce
```python
import torch
import torch_tensorrt
import tensorrt
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile

print(torch.__version__)  # 2.5.0+cu124
print(torch_tensorrt.__version__)  # 2.5.0
print(tensorrt.__version__)  # 10.3.0


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_bag_module = torch.nn.EmbeddingBag(100, 32, mode='sum')
        self.register_buffer("index_tensor", torch.tensor([x for x in range(100)], dtype=torch.long))

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        out = self.embedding_bag_module(self.index_tensor.broadcast_to(tensor.shape), per_sample_weights=tensor)
        return out


error_model_input = (torch.randn(20, 100, dtype=torch.float32), )
error_model = ToyModel()
error_model_eval = error_model.eval()

with torch.no_grad():
    ep = torch.export.export(error_model_eval, args=error_model_input)
    compiled = dynamo_compile(
        exported_program=ep,
        disable_tf32=True,
        inputs=error_model_input,
        min_block_size=1,
        debug=True,
    )
```
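Because `index_tensor` holds 0..99 and is broadcast to every row, each bag sums all 100 embedding rows weighted by the corresponding row of `tensor`, so `ToyModel` is mathematically equal to `tensor @ embedding_bag_module.weight`. The eager-mode sketch below (no TensorRT involved; `ToyModel` is the reproduction's module, with `torch.arange` in place of the list comprehension) checks that equivalence and could serve as a reference when comparing a compiled module's output. The tolerances are an assumption chosen to absorb float32 summation-order differences.

```python
import torch


class ToyModel(torch.nn.Module):
    # Same module as in the reproduction, using torch.arange for the index buffer
    def __init__(self):
        super().__init__()
        self.embedding_bag_module = torch.nn.EmbeddingBag(100, 32, mode="sum")
        self.register_buffer("index_tensor", torch.arange(100, dtype=torch.long))

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        indices = self.index_tensor.broadcast_to(tensor.shape)
        return self.embedding_bag_module(indices, per_sample_weights=tensor)


torch.manual_seed(0)
model = ToyModel().eval()
x = torch.randn(20, 100)
with torch.no_grad():
    out = model(x)
    # Bag b computes sum_i x[b, i] * W[i, :], which is a plain matrix product.
    reference = x @ model.embedding_bag_module.weight

print(out.shape)  # torch.Size([20, 32])
print(torch.allclose(out, reference, rtol=1e-4, atol=1e-4))  # True
```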
Expected behavior
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.5
- PyTorch Version (e.g. 1.0): 2.5
- CPU Architecture: x86_64
- OS (e.g., Linux): Ubuntu
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.11.9
- CUDA version: 12.6
- GPU models and configuration: Nvidia L4
- Any other relevant information: