🐛 [Bug] Part of the weights are placed to CPU during compilation #3450

Open
@cehongwang

Description

Bug Description

When compiling BERT, a device mismatch error occurs. It appears to be caused by some of the model weights being moved to the CPU during compilation.

To Reproduce

Steps to reproduce the behavior:

Run this script:

```python
import torch
import torch_tensorrt as torchtrt
from transformers import BertModel

inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
enabled_precisions = {torch.float}
debug = True
min_block_size = 1
use_python_runtime = False

exp_program = torch.export.export(model, tuple(inputs))

trt_gm = torchtrt.dynamo.compile(
    exp_program,
    tuple(inputs),
    use_python_runtime=use_python_runtime,
    enabled_precisions=enabled_precisions,
    debug=debug,
    min_block_size=min_block_size,
    immutable_weights=False,
)
```
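To confirm which weights end up on the CPU after compilation, a small device-audit helper can be run over the compiled module's parameters. This is a hypothetical diagnostic sketch (not part of Torch-TensorRT); it only assumes an iterable of `(name, tensor)` pairs such as the one returned by `named_parameters()`:

```python
# Hypothetical helper: report tensors that are NOT on the expected device.
# Works on any (name, tensor) pairs where tensor exposes a `.device` attribute,
# e.g. trt_gm.named_parameters() in the repro script above.
def find_misplaced_weights(named_tensors, expected_device="cuda"):
    """Return (name, device) pairs for tensors off the expected device."""
    misplaced = []
    for name, tensor in named_tensors:
        # Compare device strings, so "cuda:0" still matches "cuda".
        if not str(tensor.device).startswith(expected_device):
            misplaced.append((name, str(tensor.device)))
    return misplaced
```

For example, `find_misplaced_weights(trt_gm.named_parameters())` run after `torchtrt.dynamo.compile` would list any weights that were silently moved to the CPU, which is what the device mismatch at inference time suggests.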


Expected behavior

Compilation should complete without a device mismatch, with all weights remaining on the GPU.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): main branch
  • PyTorch Version (e.g. 1.0): nightly
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip

Additional context
