Bug Description
When compiling BERT, a device-mismatch error occurs. This appears to be caused by weights being moved to the CPU during compilation.
To Reproduce
Steps to reproduce the behavior:
Run this script:
```python
import torch
import torch_tensorrt as torchtrt
from transformers import BertModel

inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
enabled_precisions = {torch.float}
debug = True
min_block_size = 1
use_python_runtime = False

exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torchtrt.dynamo.compile(
    exp_program,
    tuple(inputs),
    use_python_runtime=use_python_runtime,
    enabled_precisions=enabled_precisions,
    debug=debug,
    min_block_size=min_block_size,
    immutable_weights=False,
)
```
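A quick way to confirm whether compilation moved weights off the GPU is to inspect the parameter devices before and after each step (a minimal sketch; `devices_of` is a helper name introduced here, not part of Torch-TensorRT):

```python
import torch

def devices_of(module: torch.nn.Module) -> set:
    """Collect the device types of all parameters in a module."""
    return {p.device.type for p in module.parameters()}

# Example on CPU for illustration; in the repro above you would expect
# {"cuda"} both before torch.export.export and after torchtrt.dynamo.compile.
layer = torch.nn.Linear(4, 4)
print(devices_of(layer))  # {'cpu'}
```

Calling this on `model` right before `torchtrt.dynamo.compile` and on the failing module afterwards would narrow down where the weights end up on the CPU.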
Expected behavior
Compilation completes successfully, with all weights remaining on the GPU and no device-mismatch error.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): main branch
- PyTorch Version (e.g. 1.0): nightly
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source): pip