Skip to content

Commit

Permalink
Fix map_location in torch.load (pytorch#1006)
Browse files Browse the repository at this point in the history
  • Loading branch information
colesbury authored Mar 15, 2017
1 parent 379ae6d commit c4d1318
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 7 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
docs/src/**/*
test/data/legacy_modules.t7
test/data/gpu_tensors.pt
test/htmlcov
test/.coverage
*/*.pyc
Expand All @@ -31,4 +32,4 @@ test/.coverage
*/*.so*
*/**/*.so*
*/**/*.dylib*
test/data/legacy_serialized.pt
test/data/legacy_serialized.pt
24 changes: 23 additions & 1 deletion test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2636,7 +2636,7 @@ def test_serialization_backwards_compat(self):
b = [a[i % 2] for i in range(4)]
b += [a[0].storage()]
b += [a[0].storage()[1:4]]
DATA_URL = 'https://s3.amazonaws.com/pytorch/legacy_serialized.pt'
DATA_URL = 'https://download.pytorch.org/test_data/legacy_serialized.pt'
data_dir = os.path.join(os.path.dirname(__file__), 'data')
test_file_path = os.path.join(data_dir, 'legacy_serialized.pt')
succ = download_file(DATA_URL, test_file_path)
Expand Down Expand Up @@ -2694,6 +2694,28 @@ def import_module(name, filename):
self.assertEquals(len(w), 1)
self.assertTrue(w[0].category, 'SourceChangeWarning')

def test_serialization_map_location(self):
DATA_URL = 'https://download.pytorch.org/test_data/gpu_tensors.pt'
data_dir = os.path.join(os.path.dirname(__file__), 'data')
test_file_path = os.path.join(data_dir, 'gpu_tensors.pt')
succ = download_file(DATA_URL, test_file_path)
if not succ:
warnings.warn(
"Couldn't download the test file for map_location! "
"Tests will be incomplete!", RuntimeWarning)
return

def map_location(storage, loc):
return storage

tensor = torch.load(test_file_path, map_location=map_location)
self.assertEqual(type(tensor), torch.FloatTensor)
self.assertEqual(tensor, torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))

tensor = torch.load(test_file_path, map_location={'cuda:0': 'cpu'})
self.assertEqual(type(tensor), torch.FloatTensor)
self.assertEqual(tensor, torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))

def test_from_buffer(self):
a = bytearray([1, 2, 3, 4])
self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4])
Expand Down
2 changes: 1 addition & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def do_test(self):

@classmethod
def init(cls):
DATA_URL = 'https://s3.amazonaws.com/pytorch/legacy_modules.t7'
DATA_URL = 'https://download.pytorch.org/test_data/legacy_modules.t7'
data_dir = os.path.join(os.path.dirname(__file__), 'data')
test_file_path = os.path.join(data_dir, 'legacy_modules.t7')
succ = download_file(DATA_URL, test_file_path)
Expand Down
8 changes: 8 additions & 0 deletions torch/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import importlib


def _type(self, new_type=None, async=False):
Expand Down Expand Up @@ -64,6 +65,13 @@ def _cuda(self, device=None, async=False):
return new_type(self.size()).copy_(self, async)


def _rebuild_tensor(storage, storage_offset, size, stride):
class_name = storage.__class__.__name__.replace('Storage', 'Tensor')
module = importlib.import_module(storage.__module__)
tensor_class = getattr(module, class_name)
return tensor_class().set_(storage, storage_offset, size, stride)


def _range(*args, **kwargs):
return __builtins__['range'](*args, **kwargs)

Expand Down
11 changes: 9 additions & 2 deletions torch/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def persistent_id(obj):


def load(f, map_location=None, pickle_module=pickle):
"""Loads an object saved with torch.save from a disk file.
"""Loads an object saved with :func:`torch.save` from a file.
torch.load can dynamically remap storages to be loaded on a different device
using the map_location argument. If it's a callable, it will be called with
Expand All @@ -213,6 +213,13 @@ def load(f, map_location=None, pickle_module=pickle):
map_location: a function or a dict specifying how to remap storage locations
pickle_module: module used for unpickling metadata and objects (has to match
the pickle_module used to serialize file)
Example:
>>> torch.load('tensors.pt')
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
"""
new_fd = False
if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)):
Expand All @@ -237,7 +244,7 @@ def restore_location(storage, location):
else:
def restore_location(storage, location):
result = map_location(storage, location)
if not result:
if result is None:
result = default_restore_location(storage, location)
return result

Expand Down
6 changes: 4 additions & 2 deletions torch/tensor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from . import _tensor_str
from ._utils import _type, _cuda, _range
from ._utils import _type, _cuda, _range, _rebuild_tensor
import sys


Expand Down Expand Up @@ -104,7 +104,9 @@ def __deepcopy__(self, _memo):
return new_tensor

def __reduce__(self):
return (type(self), (), self.__getstate__())
# NOTE: _rebuild_tensor does not call __setstate__
args = self.__getstate__()
return (_rebuild_tensor, args)

def __getstate__(self):
return (self.storage(),
Expand Down

0 comments on commit c4d1318

Please sign in to comment.