
Conversation

@Neelectric

Some .pt files on the Hugging Face space (e.g. this one; it looks like all pythia lenses might be affected, while gpt2-large seems to work) contain "cuda:0". See the following screenshot (forgive me for opening a .pt file in vim, it won't happen again):

[Screenshot: the .pt file opened in vim, with the string "cuda:0" visible in the pickled payload]

This means that when th.load() is called and unpickling occurs, the tensors are restored onto "cuda:0" regardless of the torch backend available on the machine or the device the model itself is loaded on. This is a shame, as Mac users relying on MPS won't be able to load these lenses.
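For illustration (the file name here is hypothetical), this is the failure mode and the standard remap:

import torch

# The checkpoint's pickled storages remember "cuda:0", so on a CUDA-less
# machine the default restore raises a RuntimeError:
#   state = torch.load("params.pt")
# Remapping the storages at load time sidesteps the baked-in device:
state = torch.load("params.pt", map_location=torch.device("cpu"))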

The 2-line change in this commit checks which device the loaded embedding is on, and then loads the state (and as such the .pt file) onto that same device using map_location. I think this is quite an elegant solution: it gives the user the flexibility to load their model onto the device of their choosing, and copies the lens onto that same device. Please let me know if another solution would be preferable, and do forgive me if I have gone against any conventions of contributing to OSS. This is my first pull request ever, I mean well!
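A minimal sketch of the idea (identifier names here are illustrative and may not match tuned_lens/nn/lenses.py exactly):

# Read off the device the unembedding's parameters live on, then map the
# checkpoint's storages to that device instead of the pickled "cuda:0".
device = next(unembed.parameters()).device
state = th.load(ckpt_path, map_location=device, **th_load_kwargs)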

Commit: …he same device that the model is on using map_location
@Neelectric
Author

FYI: without this change, the following lines from notebooks/interactive.ipynb (i.e. the default) work just fine:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

device = torch.device('cpu')  # or torch.device('mps')
model = AutoModelForCausalLM.from_pretrained('gpt2-large')
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained('gpt2-large')

However, when instead loading a pythia model...

model = AutoModelForCausalLM.from_pretrained('EleutherAI/pythia-2.8b-deduped')
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-2.8b-deduped')

... and then trying to load the corresponding lens, we get the failure below.
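The failing call is the one shown in the traceback; the import line is my addition, assuming the package's top-level re-export:

from tuned_lens import TunedLens

# Fetches the pythia lens checkpoint and th.load()s it; this is where the
# hard-coded "cuda:0" storage location bites on a CUDA-less machine.
tuned_lens = TunedLens.from_model_and_pretrained(model)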

/opt/homebrew/Caskroom/miniconda/base/envs/pythia_trials/lib/python3.12/site-packages/tuned_lens/nn/lenses.py:277: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state = th.load(ckpt_path, **th_load_kwargs)
Traceback (most recent call last):
  File "/Users/user/repos/pythia_trials/tuned_lens_example.py", line 27, in <module>
    tuned_lens = TunedLens.from_model_and_pretrained(model)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pythia_trials/lib/python3.12/site-packages/tuned_lens/nn/lenses.py", line 229, in from_model_and_pretrained
    return cls.from_unembed_and_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pythia_trials/lib/python3.12/site-packages/tuned_lens/nn/lenses.py", line 277, in from_unembed_and_pretrained
    state = th.load(ckpt_path, **th_load_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pythia_trials/lib/python3.12/site-packages/torch/serialization.py", line 1097, in load
    return _load(
           ^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pythia_trials/lib/python3.12/site-packages/torch/serialization.py", line 1525, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pythia_trials/lib/python3.12/site-packages/torch/serialization.py", line 1492, in persistent_load
    typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pythia_trials/lib/python3.12/site-packages/torch/serialization.py", line 1466, in load_tensor
    wrap_storage=restore_location(storage, location),
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pythia_trials/lib/python3.12/site-packages/torch/serialization.py", line 414, in default_restore_location
    result = fn(storage, location)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pythia_trials/lib/python3.12/site-packages/torch/serialization.py", line 391, in _deserialize
    device = _validate_device(location, backend_name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pythia_trials/lib/python3.12/site-packages/torch/serialization.py", line 364, in _validate_device
    raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

This occurs regardless of whether device = torch.device('cpu') or device = torch.device('mps') is used, in the following environment:

  • OS: macOS Sonoma build 23H124
  • Python version: 3.12.5
  • CUDA version: Not applicable
  • PyTorch version: 2.4.1

After the 2-line change in this PR, the lens is instead loaded onto the same device as the model, and the error goes away :)

@norabelrose
Collaborator

Thanks for the PR! I haven't had a chance to look at this in detail, but it looks like th_load_kwargs sometimes already contains a map_location key, in which case the tests fail because th.load receives two copies of the map_location kwarg. One way to fix this would be to just do

th_load_kwargs['map_location'] = device

but maybe it would be better to change the code that's generating the th_load_kwargs dict.
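A sketch of that one-liner in context (assuming device holds the model's device, as computed in the PR's change):

# Overwrite any map_location already present rather than passing the
# kwarg a second time, so th.load receives exactly one copy.
th_load_kwargs['map_location'] = device
state = th.load(ckpt_path, **th_load_kwargs)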

@Neelectric
Author

Thank you for replying so quickly! I didn't know this would cause tests to fail; very impressive test coverage from you folks. In hindsight it's not unexpected, because this approach is quite hacky; using th_load_kwargs['map_location'] = device does seem like a smarter workaround. Apologies for the newbie question, but is there a way for me to trigger the tests/pre-merge checks myself so I can find a solution that won't fail? Or does that require maintainer/collaborator privileges?
