Skip to content

Commit

Permalink
serialize modules as classes
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#23098

Test Plan: Imported from OSS

Differential Revision: D16383328

Pulled By: suo

fbshipit-source-id: 36389b8e45c3febb7f224cd9c630fe643fa90bef
  • Loading branch information
suo authored and facebook-github-bot committed Aug 11, 2019
1 parent 5ec1c29 commit 77c08aa
Show file tree
Hide file tree
Showing 21 changed files with 756 additions and 240 deletions.
2 changes: 1 addition & 1 deletion caffe2/proto/torch.proto
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ message LibDef {
}

enum ProtoVersion {
PROTO_VERSION_NEWEST = 0x0000000000000005;
PROTO_VERSION_NEWEST = 0x0000000000000006;
}

message ModelDef {
Expand Down
4 changes: 2 additions & 2 deletions test/cpp/jit/test_save_load.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void testSaveExtraFilesHook() {
{
std::stringstream ss;
{
Module m("m");
Module m("__torch__.m");
ExtraFilesMap extra;
extra["metadata.json"] = "abc";
m.save(ss, extra);
Expand All @@ -40,7 +40,7 @@ void testSaveExtraFilesHook() {
SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
return {{"secret.json", "topsecret"}};
});
Module m("m");
Module m("__torch__.m");
ExtraFilesMap extra;
extra["metadata.json"] = "abc";
m.save(ss, extra);
Expand Down
52 changes: 34 additions & 18 deletions test/jit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch.jit._logging
import torch.jit.frontend
import torch.jit.quantized
import zipfile

# Testing utils
from common_utils import TestCase, IS_WINDOWS, \
Expand Down Expand Up @@ -90,13 +91,39 @@ def _isHookExceptionOk(self, e):
return False

def _compared_saved_loaded(self, m):
import zipfile
if PY2:
# Disable for Python 2, which does not allow manipulation of multiple objects
# returned by zipfile.open().
# See: https://docs.python.org/2.7/library/zipfile.html#zipfile.ZipFile.open
return

def extract_files(buffer):
    """Return the code and debug records stored in a saved module archive.

    Args:
        buffer: file-like object holding the zip archive to inspect.

    Returns:
        A ``(code_files, debug_files)`` tuple of lists, both in archive
        order: the decoded text of every ``.py`` entry under
        ``archive/code/`` and the unpickled content of every
        ``.debug_pkl`` entry under the same prefix.

    ``self`` is the enclosing test case; it is used to assert that the
    archive contains no duplicate entry names.
    """
    # crack open the zip format to get at the main module code
    with zipfile.ZipFile(buffer) as archive:
        names = archive.namelist()
        # check that we have no duplicate names
        self.assertEqual(len(set(names)), len(names))
        files = [name for name in names if name.startswith('archive/code/')]

        # Read everything eagerly and return lists: lazy map() objects can
        # only be consumed once and would leave the debug files unread and
        # the member handles open.
        code_files = []
        debug_files = []
        for name in files:
            if name.endswith('.py'):
                # unwrap the code file into a string
                with archive.open(name) as member:
                    code_files.append(member.read().decode())
            elif name.endswith('.debug_pkl'):
                # unpickle the debug file
                # NOTE(review): pickle.load is only acceptable if the archive
                # is trusted test output — confirm against the caller.
                with archive.open(name) as member:
                    debug_files.append(pickle.load(member))
    return code_files, debug_files

# disable the hook while we parse code, otherwise we will re-enter the hook
with torch.jit._disable_emit_hooks():
try:
# short-circuit if this is an empty function or module
if len(m.code) == 0:
# short-circuit if this is an empty module
return
if isinstance(m, torch._C.ScriptModule):
if len(m._method_names()) == 0:
return

# save the module to a buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)
Expand All @@ -105,14 +132,7 @@ def _compared_saved_loaded(self, m):
# and it's easier to just work with a fresh copy each time.
buffer_copy = buffer.getvalue()

# crack open the zip format to get at the main module code
archive = zipfile.ZipFile(buffer)
# check that we have no duplicate names
self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
main_module = archive.open('archive/code/archive.py')
main_module_code = "".join([line.decode() for line in main_module])
main_module_debug_file = archive.open('archive/debug/archive.pkl')
main_module_debug = pickle.load(main_module_debug_file)
code_files, debug_files = extract_files(buffer)
except RuntimeError as e:
if not self._isHookExceptionOk(e):
raise
Expand All @@ -128,14 +148,10 @@ def _compared_saved_loaded(self, m):
torch.jit.save(imported, saved_module_buffer_2)

saved_module_buffer_2.seek(0)
archive2 = zipfile.ZipFile(saved_module_buffer_2)
main_module_2 = archive2.open('archive/code/archive.py')
main_module_2_code = "".join([line.decode() for line in main_module_2])
main_module_2_debug_file = archive.open('archive/debug/archive.pkl')
main_module_2_debug = pickle.load(main_module_2_debug_file)

self.assertMultiLineEqual(main_module_code, main_module_2_code)
self.assertEqual(main_module_debug, main_module_2_debug)
code_files_2, debug_files_2 = extract_files(saved_module_buffer_2)

for a, b in zip(code_files, code_files_2):
self.assertMultiLineEqual(a, b)


def emitFunctionHook(self, func):
Expand Down
95 changes: 57 additions & 38 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3398,6 +3398,7 @@ def foobar(xyz):
fc.run(scripted.graph)
fc.run(str(scripted.graph))

@unittest.skipIf(IS_SANDCASTLE, "[serialization forward compat]")
def test_file_line_save_load(self):
class Scripted(torch.jit.ScriptModule):
@torch.jit.script_method
Expand All @@ -3413,7 +3414,7 @@ def forward(self, xyz):
bytesio = io.BytesIO(buffer)
scripted = torch.jit.load(bytesio)

fc = FileCheck().check('code/archive.py:4:10')
fc = FileCheck().check('code/__torch__.py:6:12')
fc.run(scripted.graph)
fc.run(str(scripted.graph))

Expand Down Expand Up @@ -3466,6 +3467,7 @@ def forward(self):
loaded = self.getExportImportCopy(ft)
loaded()

@unittest.skipIf(IS_SANDCASTLE, "[serialization forward compat]")
def test_serialized_source_ranges_dont_jitter(self):
class FooTest3(torch.jit.ScriptModule):
@torch.jit.script_method
Expand Down Expand Up @@ -3498,7 +3500,7 @@ def debug_records_from_mod(mod):
torch.jit.save(ft3, buffer)
buffer.seek(0)
archive = zipfile.ZipFile(buffer)
debug_file = archive.open('archive/debug/archive.pkl')
debug_file = archive.open('archive/code/__torch__.py.debug_pkl')
return pickle.load(debug_file), buffer

records1, buffer = debug_records_from_mod(ft3)
Expand Down Expand Up @@ -3546,14 +3548,18 @@ def debug_records_from_mod(mod):
torch.jit.save(ft3, buffer)
buffer.seek(0)
archive = zipfile.ZipFile(buffer)
debug_file = archive.open('archive/debug/archive.pkl')
return pickle.load(debug_file), buffer

records, _ = debug_records_from_mod(ft3)
for i in range(len(records) - 1):
offset, source_range = records[i]
offset2, source_range2 = records[i + 1]
self.assertNotEqual(source_range, source_range2)
files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist()))
debug_files = filter(lambda f: f.endswith('.debug_pkl'), files)
debug_files = map(lambda f: archive.open(f), debug_files)
debug_files = map(lambda f: pickle.load(f), debug_files)
return list(debug_files)

debug_files = debug_records_from_mod(ft3)
for debug_file in debug_files:
for i in range(len(debug_file) - 1):
offset, source_range = debug_file[i]
offset2, source_range2 = debug_file[i + 1]
self.assertNotEqual(source_range, source_range2)

def test_tensor_shape(self):
x = torch.empty(34, 56, 78)
Expand Down Expand Up @@ -12480,30 +12486,47 @@ def test_bin(x):

@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
def test_get_set_state(self):
class Root(torch.jit.ScriptModule):
__constants__ = ['number']

def __init__(self, number):
super(Root, self).__init__()
self.register_buffer('buffer1', torch.ones(2, 2))
self.register_buffer('buffer2', torch.ones(2, 2))
self.number = number

@torch.jit.script_method
def __getstate__(self):
return (self.buffer1, self.buffer2, 74)

@torch.jit.script_method
def __setstate__(self, state):
self.buffer1 = state[0] + 10
self.buffer2 = state[1] + 10


class M(torch.jit.ScriptModule):
__constants__ = ['number']

def __init__(self, number, submodule=None):
def __init__(self, number, submodule):
super(M, self).__init__()
self.register_buffer('buffer1', torch.ones(2, 2))
self.register_buffer('buffer2', torch.ones(2, 2))
self.number = number
if submodule:
self.submodule = submodule
self.submodule = submodule

@torch.jit.script_method
def __getstate__(self):
# type: () -> Tuple[Tensor, Tensor, int]
return (self.buffer1, self.buffer2, 74)
return (self.buffer1, self.buffer2, 74, self.submodule)

@torch.jit.script_method
def __setstate__(self, state):
# type: (Tuple[Tensor, Tensor, int]) -> None
self.buffer1 = state[0] + 10
self.buffer2 = state[1] + 10
self.submodule = state[3]

with TemporaryFileName() as fname:
m = M(23, submodule=M(99))
m = M(23, submodule=Root(99))
m.save(fname)
loaded = torch.jit.load(fname)

Expand Down Expand Up @@ -12531,19 +12554,18 @@ def forward(self):

@torch.jit.export
def __getstate__(self):
return None
return 5

@torch.jit.export
def __setstate__(self, _):
# type: (None) -> None
self.buffer1 = torch.ones(2, 2) + 10
def __setstate__(self, state):
self.buffer1 = torch.ones(2, 2) + state
self.buffer2 = torch.ones(2, 2) + 10

with TemporaryFileName() as fname:
m = torch.jit.script(NoArgState())
m.save(fname)
loaded = torch.jit.load(fname)
self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 10)
self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 5)
self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10)


Expand Down Expand Up @@ -13577,32 +13599,28 @@ class M(torch.jit.ScriptModule):
def __init__(self):
super(M, self).__init__()
for name, value, the_type in tester.get_pickle_values():
setattr(self, name, torch.jit.Attribute(value, the_type))
setattr(self, "_" + name, torch.jit.Attribute(value, the_type))

@torch.jit.script_method
def forward(self):
return (self.dict, self.float, self.int, self.bool, self.tuple,
self.list, self.int_list, self.tensor_list, self.bool_list,
self.float_list, self.str_list, self.none)
return (self._dict, self._float, self._int, self._bool, self._tuple,
self._list, self._int_list, self._tensor_list, self._bool_list,
self._float_list, self._str_list, self._none)

with TemporaryFileName() as fname:
M().save(fname)
archive_name = os.path.basename(os.path.normpath(fname))
archive = zipfile.ZipFile(fname, 'r')
pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl'))
out = pickle.load(io.BytesIO(pickled_data))
loaded = torch.jit.load(fname)

def is_tensor_value(item):
    """Return True if ``item`` is a tensor or a list led by a tensor.

    Lists are inspected recursively through their first element only.
    An empty list contains no tensor, so it yields False instead of
    raising IndexError on ``item[0]``.
    """
    if isinstance(item, torch.Tensor):
        return True
    if isinstance(item, list) and item:
        return is_tensor_value(item[0])
    return False

for loaded_item, item in zip(out, self.get_pickle_values()):
if is_tensor_value(item[1]):
for name, value, the_type in self.get_pickle_values():
if is_tensor_value(value):
continue
self.assertEqual(item[1], loaded_item)
self.assertEqual(value, getattr(loaded, "_" + name))

@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
def test_old_models_bc(self):
Expand Down Expand Up @@ -13712,7 +13730,8 @@ def forward(self, x):
def test_script_scope(self):
scripted = torch.jit.script(torch.nn.functional.pad)

@unittest.skipIf(IS_WINDOWS, "NYI: TemporaryFileName on Windows")
# [serialization forward compat]
@unittest.skipIf(IS_SANDCASTLE or IS_WINDOWS, "NYI: TemporaryFileName on Windows")
def test_serialization_sharing(self):
class M(torch.jit.ScriptModule):
def __init__(self):
Expand All @@ -13737,7 +13756,7 @@ def forward(self, key):
m.save(fname)
archive_name = os.path.basename(os.path.normpath(fname))
archive = zipfile.ZipFile(fname, 'r')
pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl'))
pickled_data = archive.read(os.path.join(archive_name, 'data.pkl'))

out = StringIO()
pickletools.dis(pickled_data, out=out)
Expand Down Expand Up @@ -16298,7 +16317,7 @@ def foo(self, x1, x2):
return torch.neg(x1), self.param, self.const, torch.neg(x2), self.param

@torch.jit.script_method
def wait_script(self, x1, x2):
def forward(self, x1, x2):
fut = torch.jit._fork(self.foo, x1, x2)
y_hat = self.foo(x1, x2)
y = torch.jit._wait(fut)
Expand All @@ -16310,7 +16329,7 @@ def wait_script(self, x1, x2):
m = Mod()

with torch.jit.optimized_execution(False):
y, y_hat = m.wait_script(x1, x2)
y, y_hat = m.forward(x1, x2)

self.assertEqual(y, y_hat)

Expand Down
Loading

0 comments on commit 77c08aa

Please sign in to comment.