diff --git a/test/test_serialization.py b/test/test_serialization.py index 5c40c1285c0324..f260d47f4461b8 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -18,6 +18,7 @@ from torch.testing._internal.common_utils import TestCase, IS_WINDOWS, \ TEST_DILL, run_tests, download_file, BytesIOContext +from torch.testing._internal.common_device_type import instantiate_device_type_tests # These tests were all copied from `test/test_torch.py` at some point, so see # the actual blame, see this revision @@ -583,10 +584,10 @@ def wrapper(*args, **kwargs): def __exit__(self, *args, **kwargs): torch.save = self.torch_save -class TestBothSerialization(TestCase, SerializationMixin): +class TestBothSerialization(TestCase): @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") - def test_serialization_new_format_old_format_compat(self): - x = [torch.ones(200, 200) for i in range(30)] + def test_serialization_new_format_old_format_compat(self, device): + x = [torch.ones(200, 200, device=device) for i in range(30)] def test(filename): torch.save(x, filename, _use_new_zipfile_serialization=True) @@ -747,6 +748,7 @@ def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super(TestSerialization, self).run(*args, **kwargs) +instantiate_device_type_tests(TestBothSerialization, globals()) if __name__ == '__main__': run_tests() diff --git a/torch/csrc/generic/StorageMethods.cpp b/torch/csrc/generic/StorageMethods.cpp index 0b1be66cf355c8..396e9ffe5cddd7 100644 --- a/torch/csrc/generic/StorageMethods.cpp +++ b/torch/csrc/generic/StorageMethods.cpp @@ -221,9 +221,9 @@ static PyObject * THPStorage_(fromFile)(PyObject *_unused, PyObject *args, PyObj PyObject * THPStorage_(writeFile)(THPStorage *self, PyObject *args) { HANDLE_TH_ERRORS - PyObject *file = PyTuple_GET_ITEM(args, 0); - bool is_real_file = PyTuple_GET_ITEM(args, 1) == Py_True; - bool save_size = PyTuple_GET_ITEM(args, 2) == Py_True; + PyObject *file = PyTuple_GetItem(args, 0); + bool is_real_file = PyTuple_GetItem(args, 1) == Py_True; + bool save_size = PyTuple_GetItem(args, 2) == Py_True; if (!is_real_file) { THPStorage_(writeFileRaw)(self->cdata, file, save_size); diff --git a/torch/serialization.py b/torch/serialization.py index 1c05767922a85f..555c8dc0857b9c 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -488,7 +488,7 @@ def persistent_id(obj): else: # Copy to a buffer, then serialize that buf = io.BytesIO() - storage._write_file(buf, _should_read_directly(buf)) + storage._write_file(buf, _should_read_directly(buf), False) buf_value = buf.getvalue() zip_file.write_record(name, buf_value, len(buf_value))