Skip to content

Commit 4ad4f6d

Browse files
hold references to storages during TorchScript serializaiton (pytorch#59672)
Fixes issue for serialization problem caused by using memory address of storages for mobile and torch.package models. - pytorch#59642 hold references to storages during TorchScript serialization Uses StorageContext to hold a reference to all storages seen during TorchScript serialization to allow for tensors to be created/destroyed during serialization process. Tracking of the storages solves for the ABA memory problem.
1 parent 90e6773 commit 4ad4f6d

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

torch/csrc/jit/serialization/export.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <caffe2/serialize/inline_container.h>
44
#include <torch/csrc/jit/api/module.h>
55
#include <torch/csrc/jit/ir/ir.h>
6+
#include <torch/csrc/jit/serialization/import.h>
67
#include <torch/csrc/jit/serialization/pickler.h>
78
#include <torch/csrc/jit/serialization/python_print.h>
89
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
@@ -97,6 +98,10 @@ class TORCH_API ScriptModuleSerializer {
9798
// qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be
9899
// created
99100
OrderedDict<std::string, PythonPrint> file_streams_;
101+
// Used to keep references of storages around during serialization to solve
102+
// for ABA memory reuse problem hit when storages are created/destroyed
103+
// during serializaiton process.
104+
StorageContext storage_context_;
100105

101106
// Uniquely identifies a SourceRange in a model.
102107
// SourceRanges are associated with Nodes of Graphs.

torch/csrc/jit/serialization/export_module.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -411,10 +411,11 @@ void ScriptModuleSerializer::writeArchive(
411411
[&](const at::Tensor& tensor) {
412412
// returns a string to use in picker.cpp as storage obj key
413413
if (tensor_cdata_naming_scheme) {
414-
tensor_names.push_back(
414+
std::string string_id =
415415
std::to_string(reinterpret_cast<std::intptr_t>(
416-
tensor.storage().unsafeGetStorageImpl())) +
417-
".storage");
416+
tensor.storage().unsafeGetStorageImpl()));
417+
tensor_names.push_back(string_id + ".storage");
418+
storage_context_.addStorage(string_id, tensor.storage());
418419
} else {
419420
tensor_names.push_back(std::to_string(tensor_names.size()));
420421
}

0 commit comments

Comments
 (0)