Make lenses device agnostic using map_location in torch.load() #134
Conversation
…the same device that the model is on using map_location
FYI, without this change, the following lines in notebooks/interactive.ipynb (i.e. the default) work just fine: … However, when instead loading a pythia model… and then wanting to load the corresponding lens, we get … This occurs regardless of using …

After the 2-line change in this PR, the lens is loaded onto the same device as the model instead, and the error goes away :)
Thanks for the PR. I haven't had a chance to look at this in detail, but it looks like sometimes `th_load_kwargs['map_location'] = device` … maybe it would be better to change the code that's generating the …
Thank you for replying so quickly! I didn't know this would cause tests to fail; very impressive test coverage from you folks. In hindsight not unexpected, because this approach is very hacky, using …
Some .pt files on the Hugging Face space (e.g. this one; it looks like all pythia lenses might be affected, while gpt2-large seems to work) contain "cuda:0". See the following screenshot (forgive me for opening a .pt file in vim, it won't happen again):
This means that when `th.load()` gets called and the unpickling occurs, it tries to load the tensors onto "cuda:0" regardless of the torch backend on the machine or the device the model itself is loaded on. This is a shame, as Mac users relying on MPS won't be able to load these lenses.
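To illustrate the mechanism: `torch.load()` accepts a `map_location` argument that overrides whatever device string was baked into the pickle at save time. The sketch below can't reproduce a "cuda:0"-tagged checkpoint on a CPU-only machine (the tensors here are saved from CPU), but the call shape is the same; the in-memory buffer stands in for a downloaded .pt file.

```python
import io

import torch

# Simulate a saved lens checkpoint: a plain dict of tensors, round-tripped
# through an in-memory buffer instead of a real file.
buffer = io.BytesIO()
torch.save({"weight": torch.zeros(2, 2)}, buffer)
buffer.seek(0)

# map_location overrides any device tag stored in the pickle, so a
# checkpoint saved as "cuda:0" still loads on a CUDA-less / MPS-only box.
state = torch.load(buffer, map_location="cpu")
print(state["weight"].device)  # cpu
```

Without `map_location`, deserialization tries to place each tensor on the device recorded at save time, which is exactly the failure described above.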
The 2-line change in this commit checks which device the loaded embedding is on, and then loads the state (and with it the .pt file) onto that same device using map_location. I think this is quite an elegant solution: it gives the user the flexibility to load their model onto the device of their choosing, and the lens is copied onto the same device. Please let me know if another solution would be preferable, and do forgive me if I have gone against any conventions of contributing to OSS. This is my first pull request ever, I mean well!
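The approach described above can be sketched roughly as follows. The names (`load_lens_state`, the `nn.Linear` stand-in) are illustrative, not the actual tuned-lens code: infer the target device from a parameter that already lives on the model's device, then pass it as `map_location` so any "cuda:0" tag in the checkpoint is ignored.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_lens_state(lens: nn.Module, path: str) -> None:
    # Device the (already-placed) lens parameters live on.
    device = next(lens.parameters()).device
    # map_location remaps every tensor in the checkpoint onto `device`,
    # regardless of the device it was saved from.
    state_dict = torch.load(path, map_location=device)
    lens.load_state_dict(state_dict)


lens = nn.Linear(4, 4)  # stand-in for a real lens module
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "lens.pt")
    torch.save(lens.state_dict(), path)
    load_lens_state(lens, path)
print(next(lens.parameters()).device)  # cpu on a CUDA-less machine
```

Deriving the device from the module itself (rather than hard-coding "cpu") is what makes this device-agnostic: the same call works whether the model was placed on CPU, CUDA, or MPS.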