
[Bug Report] Gemma model tensors initialized on CPU instead of GPU during state dictionary conversion #904

@joaoncardoso

Description

Describe the bug
When an existing AutoModelForCausalLM instance of Gemma-2-2B that is already on a GPU is passed to HookedTransformer.from_pretrained through the hf_model argument, fold_value_biases raises a device-mismatch error. Some bias tensors in the converted state dict are on the CPU, while the weights they are combined with are on the GPU.
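
To see which entries are affected, a small diagnostic can list the state-dict tensors that are not on the expected device. This is a sketch only: tensors_off_device is a hypothetical helper, not part of TransformerLens, and it assumes you can get hold of the converted state dict (a plain dict of name to tensor).

```python
import torch


def tensors_off_device(state_dict: dict, device) -> list:
    """Return the sorted names of tensors in state_dict that are not on `device`.

    Hypothetical diagnostic helper. The comparison is exact, so
    torch.device("cuda") does not match a tensor on torch.device("cuda:0").
    """
    target = torch.device(device)
    return sorted(
        name
        for name, tensor in state_dict.items()
        if isinstance(tensor, torch.Tensor) and tensor.device != target
    )
```

For this bug, calling it with "cuda:0" on the converted Gemma state dict would be expected to list the CPU-resident bias entries.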

Code example

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from transformer_lens import HookedTransformer

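# Load the Hugging Face model directly onto the GPU
hf_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    device_map="cuda:0"
)

device = torch.device("cuda:0")

# Wrap the already-loaded GPU model; this call raises the error below
model = HookedTransformer.from_pretrained("gemma-2-2b", hf_model=hf_model, device=device)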

Error message:

  File ".../load_test.py", line 19, in <module>
    model = HookedTransformer.from_pretrained("gemma-2-2b", hf_model=hf_model, device=device)
  File "...site-packages/transformer_lens/HookedTransformer.py", line 1370, in from_pretrained
    model.load_and_process_state_dict(
  File "...site-packages/transformer_lens/HookedTransformer.py", line 1625, in load_and_process_state_dict
    state_dict = self.fold_value_biases(state_dict)
  File ".../site-packages/transformer_lens/HookedTransformer.py", line 1875, in fold_value_biases
    folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1])
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
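
The failing line folds each head's value bias through W_O into the output bias. The sketch below mirrors that single expression for one layer and aligns devices before multiplying. It is an illustration, not TransformerLens's fold_value_biases; fold_value_bias is a hypothetical name, and the shapes are assumed from the indexing in the traceback.

```python
import torch


def fold_value_bias(b_V: torch.Tensor, W_O: torch.Tensor, b_O: torch.Tensor) -> torch.Tensor:
    """Fold a per-head value bias into the output bias for one layer.

    Assumed shapes (inferred from the traceback's indexing):
      b_V: [n_heads, d_head], W_O: [n_heads, d_head, d_model], b_O: [d_model]
    """
    # Align devices and dtypes first; mixing a CPU b_V with a CUDA W_O is
    # what raises "Expected all tensors to be on the same device".
    b_V = b_V.to(device=W_O.device, dtype=W_O.dtype)
    b_O = b_O.to(device=W_O.device, dtype=W_O.dtype)
    # Each head's bias passes through W_O; summing over heads and d_head
    # gives a constant contribution to the output bias.
    return b_O + (b_V[:, :, None] * W_O).sum([0, 1])
```

Moving tensors inside the fold only hides the symptom; creating the biases on the right device during conversion addresses the cause.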

System Info
Describe the characteristics of your environment:

  • Describe how transformer_lens was installed (pip, docker, source, ...)
    Installed using pip, version 2.15.0
  • What OS are you using? (Linux, MacOS, Windows)
    Ubuntu 22.04.04 LTS
  • Python version (We support 3.7--3.10 currently)
    3.10.16

Additional context
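The problem appears to originate in the convert_gemma_weights weight conversion function, which creates some tensors (such as the value bias b_V used in the failing fold) on the CPU instead of on the device of the source model.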
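
A fix along these lines would create each zero bias on the same device, and with the same dtype, as a reference weight from the source model. This is a hypothetical sketch: zero_bias_like is not a TransformerLens function, and it does not reflect how convert_gemma_weights is actually structured.

```python
import torch


def zero_bias_like(shape, reference: torch.Tensor) -> torch.Tensor:
    """Create a zero bias on the same device and dtype as `reference`.

    Hypothetical helper: torch.zeros(shape) alone allocates on the default
    device (normally the CPU), which is how a bias can end up on the CPU
    while the converted weights stay on the GPU.
    """
    return torch.zeros(shape, dtype=reference.dtype, device=reference.device)
```

For example, a conversion step could build b_V as zero_bias_like((n_heads, d_head), W_V), where W_V is the already converted value weight.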

Checklist

  • I have checked that there is no similar issue in the repo (required)
