Description
I am trying to write generic code that loads a model's state dict on either CPU or GPU. It works fine for other vision models but fails for VGG models on GPU machines.
Here is the sample code I am trying:
```python
import torch
map_location = 'cuda' if torch.cuda.is_available() else 'cpu'  # GPU if available, else CPU
model_pt_path = 'vgg13-c768596a.pth'
state_dict = torch.load(model_pt_path, map_location=map_location)
```
The above code fails with the following error:
```
Traceback (most recent call last):
  File "test_vgg.py", line 4, in <module>
    state_dict = torch.load(model_pt_path, map_location=map_location)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 593, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 773, in _legacy_load
    result = unpickler.load()
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 147, in __setstate__
    self.set_(*state)
RuntimeError: Expected object of device type cpu but got device type cuda for argument #2 'source'
```
I tried the same code with DenseNet, AlexNet and SqueezeNet models, and it works fine on both GPU and CPU machines.
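Since the failure only shows up when the checkpoint is mapped to CUDA, one possible workaround is to deserialize onto the CPU and move the populated model to the GPU afterwards. This is a minimal sketch, assuming the VGG checkpoint loads correctly with `map_location='cpu'` (as the CPU-machine behaviour suggests); the helper name `load_weights` is hypothetical and not part of PyTorch or torchvision.

```python
import torch
import torch.nn as nn


def load_weights(model: nn.Module, path: str) -> nn.Module:
    """Hypothetical helper: load a checkpoint via the CPU, then move the model.

    Assumption: deserializing with map_location='cpu' avoids the failing
    CUDA remapping path; the weights reach the GPU through model.to().
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Deserialize every tensor onto the CPU first.
    state_dict = torch.load(path, map_location='cpu')
    # Copy the weights into the model's (CPU) parameters.
    model.load_state_dict(state_dict)
    # Move parameters and buffers to the selected device; returns the model.
    return model.to(device)
```

For the checkpoint in this report, usage would be something like `model = load_weights(torchvision.models.vgg13(), 'vgg13-c768596a.pth')`. Loading into the model before moving it (rather than moving the raw dict) also keeps any `_metadata` attached to the state dict intact for `load_state_dict`.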
Environment information:
OS: Ubuntu 18.04
PyTorch version: 1.5.1
TorchVision version: 0.6.1