Describe the bug
When an existing AutoModelForCausalLM instance of gemma-2-2b that is already on a GPU is passed to HookedTransformer.from_pretrained, fold_value_biases raises a device-mismatch error because the model biases in the state_dict are on the CPU while the weights are on the GPU.
Code example
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from transformer_lens import HookedTransformer

hf_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    device_map="cuda:0"
)
device = torch.device("cuda:0")
model = HookedTransformer.from_pretrained("gemma-2-2b", hf_model=hf_model, device=device)
Error message:
  File ".../load_test.py", line 19, in <module>
    model = HookedTransformer.from_pretrained("gemma-2-2b", hf_model=hf_model, device=device)
  File "...site-packages/transformer_lens/HookedTransformer.py", line 1370, in from_pretrained
    model.load_and_process_state_dict(
  File "...site-packages/transformer_lens/HookedTransformer.py", line 1625, in load_and_process_state_dict
    state_dict = self.fold_value_biases(state_dict)
  File ".../site-packages/transformer_lens/HookedTransformer.py", line 1875, in fold_value_biases
    folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1])
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
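To narrow down which entries end up on the CPU, a small diagnostic along these lines may help. This is only a sketch: `group_keys_by_device` is a hypothetical helper name, not part of transformer_lens, and it assumes nothing beyond each value in the state dict exposing a `.device` attribute, as torch tensors do.

```python
from collections import defaultdict
from typing import Any, Dict, List, Mapping


def group_keys_by_device(state_dict: Mapping[str, Any]) -> Dict[str, List[str]]:
    """Group state_dict keys by the device their tensors live on.

    Hypothetical helper for debugging: it only reads `.device` from
    each value, so it accepts any mapping of tensor-like objects.
    """
    groups: Dict[str, List[str]] = defaultdict(list)
    for name, tensor in state_dict.items():
        # str() turns torch.device("cuda:0") into "cuda:0" for easy comparison.
        groups[str(tensor.device)].append(name)
    return dict(groups)
```

Calling it on the converted state dict just before folding (for example from a debugger stopped in the `fold_value_biases` frame) and printing the keys listed under `"cpu"` should show whether only the bias entries are affected.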
System Info
Describe the characteristic of your environment:
- Describe how transformer_lens was installed (pip, docker, source, ...)
  Installed using pip, version 2.15.0
- What OS are you using? (Linux, MacOS, Windows)
  Ubuntu 22.04.04 LTS
- Python version (We support 3.7--3.10 currently)
  3.10.16
Additional context
The issue appears to be with the convert_gemma_weights weight conversion function.