Bug Description
The indices are a constant tensor, which gets constant-folded into a frozen param. The `meta` of the frozen param node is an empty dict, so the converter validation check here fails. This makes `torch.ops.aten._embedding_bag.default` an unsupported op, and compilation fails.
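The failure mode can be illustrated with a simplified, hypothetical stand-in for the converter's validator. The names `FakeTensorMeta`, `FakeNode`, and `hypothetical_embedding_bag_validator` are made up for illustration, and the real check in Torch-TensorRT may inspect different meta keys. The sketch only shows why a node whose `meta` is `{}` (or whose indices are 2D) ends up reported as unsupported.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple


@dataclass
class FakeTensorMeta:
    # Hypothetical stand-in for the shape info an FX node keeps in node.meta["val"].
    shape: Tuple[int, ...]


@dataclass
class FakeNode:
    # Hypothetical stand-in for torch.fx.Node; only the `meta` dict matters here.
    name: str
    meta: Dict[str, Any] = field(default_factory=dict)


def hypothetical_embedding_bag_validator(indices_node: FakeNode) -> bool:
    """Accept the op only if the indices node exposes a 1D shape via its meta."""
    val: Optional[FakeTensorMeta] = indices_node.meta.get("val")
    if val is None:
        # A constant-folded (frozen) param with meta == {} has no shape to inspect,
        # so the validator rejects it and the op is treated as unsupported.
        return False
    # Only 1D indices are accepted, matching the limitation described in this issue.
    return len(val.shape) == 1


placeholder = FakeNode("indices", {"val": FakeTensorMeta((100,))})
frozen = FakeNode("frozen_param")  # empty meta, as after constant folding
print(hypothetical_embedding_bag_validator(placeholder))  # True
print(hypothetical_embedding_bag_validator(frozen))  # False
```

Note that even with meta present, the repro's indices are broadcast to 2D, which this kind of check would also reject.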
To Reproduce
```python
import torch
import torch_tensorrt
import tensorrt
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile

print(torch.__version__)  # 2.5.0+cu124
print(torch_tensorrt.__version__)  # 2.5.0
print(tensorrt.__version__)  # 10.3.0


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_bag_module = torch.nn.EmbeddingBag(100, 32, mode='sum')
        # Constant index buffer: this is what gets constant-folded into a frozen param.
        self.register_buffer("index_tensor", torch.tensor([x for x in range(100)], dtype=torch.long))

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        # Indices are broadcast to the 2D input shape (20, 100) before the lookup.
        out = self.embedding_bag_module(self.index_tensor.broadcast_to(tensor.shape), per_sample_weights=tensor)
        return out


error_model_input = (torch.randn(20, 100, dtype=torch.float32), )
error_model = ToyModel()
error_model_eval = error_model.eval()

with torch.no_grad():
    ep = torch.export.export(error_model_eval, args=error_model_input)
    compiled = dynamo_compile(
        exported_program=ep,
        disable_tf32=True,
        inputs=error_model_input,
        min_block_size=1,
        debug=True,
    )
```
Expected behavior
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.5
- PyTorch Version (e.g. 1.0): 2.5
- CPU Architecture: x86_64
- OS (e.g., Linux): Ubuntu
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.11.9
- CUDA version: 12.6
- GPU models and configuration: NVIDIA L4
- Any other relevant information:
Additional context
Due to some limitations, we need the node's meta to handle the data-dependent issue. In addition, we currently support only 1D indices/input. If this doesn't work for you, you will have to fall back to PyTorch for this op for now.
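The fallback mentioned above can be requested through the dynamo compiler's `torch_executed_ops` setting. Another option is to restructure the model so that it matches the stated limitations. The sketch below rests on that assumption and is untested against Torch-TensorRT. The indices become a `forward()` input, so they are not a constant that gets folded into a frozen param. They are also flattened to 1D with explicit offsets. `ToyModel1D` is a hypothetical name. The code only relies on `torch.nn.EmbeddingBag` semantics, where a 2D `(B, N)` input is equivalent to a flattened 1D input with offsets `[0, N, 2N, ...]`.

```python
import torch


class ToyModel1D(torch.nn.Module):
    """Hypothetical rewrite of ToyModel: 1D indices passed in at runtime."""

    def __init__(self):
        super().__init__()
        # per_sample_weights requires mode="sum".
        self.embedding_bag_module = torch.nn.EmbeddingBag(100, 32, mode="sum")

    def forward(self, tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        batch, bag_size = tensor.shape
        # Repeat the 1D index vector once per row, then flatten to 1D.
        flat_indices = indices.broadcast_to(tensor.shape).reshape(-1)
        # One bag per row: bags start at 0, bag_size, 2 * bag_size, ...
        offsets = torch.arange(0, batch * bag_size, bag_size, device=tensor.device)
        return self.embedding_bag_module(
            flat_indices, offsets=offsets, per_sample_weights=tensor.reshape(-1)
        )


model_1d = ToyModel1D().eval()
indices = torch.arange(100, dtype=torch.long)
x = torch.randn(20, 100)
with torch.no_grad():
    out = model_1d(x, indices)
print(out.shape)  # torch.Size([20, 32])
```

On CPU this produces the same result as the original `ToyModel` given the same embedding weights. Whether Torch-TensorRT then converts `_embedding_bag` instead of rejecting it has not been verified here.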