Skip to content

Commit a9463d7

Browse files
Merge pull request #20 from ejhg/ejhg-optimize-load_py
Optimize load_py for memory and speed
2 parents 70f4c74 + 568d0ae commit a9463d7

2 files changed

Lines changed: 161 additions & 51 deletions

File tree

TorchSharp.PyBridge/PyBridgeModuleExtensions.cs

Lines changed: 86 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using System.Collections;
2-
using System.IO;
31
using System.Text.Json;
42
using System.Text.Json.Nodes;
53
using TqdmSharp;
@@ -116,35 +114,106 @@ public static Module load_py(this Module module, string location, bool strict =
116114
/// <remarks>
117115
/// This method only supports loading the newer format used by `torch.save`, using a zip file.
118116
/// The model will be fully loaded and all the validation checks will only run after the state
119-
/// dictionary has been fully loaded.
117+
/// dictionary has been fully loaded.
120118
/// </remarks>
121119
public static Module load_py(this Module module, System.IO.Stream stream, bool strict = true, IList<string>? skip = null, Dictionary<string, bool>? loadedParameters = null, bool leaveOpen = false) {
122-
// Create a dispose score so that we don't keep anyof the loaded tensors past this function
120+
// Create a dispose scope so that we don't keep any of the loaded tensors past this function
123121
using var d = torch.NewDisposeScope();
124122
using var d2 = torch.no_grad(); // To circumvent a bug introduced in 0.102.0
125123

126-
// Unpickle the state dictionary into memory
127-
var stateHashtable = PyTorchUnpickler.UnpickleStateDict(stream, leaveOpen);
124+
// Unpickle the state dictionary into memory.
125+
// Keep stream open because tensors will not get deserialized yet.
126+
var unpickled = PyTorchUnpickler.UnpickleStateDict(stream, leaveOpen: true, skipTensorRead: true);
128127

129128
// Convert the hashtable to a dictionary of string->tensor constructor args (the tensor data is read later)
130-
var stateDict = new Dictionary<string, torch.Tensor>();
131-
foreach (string key in stateHashtable.Keys)
132-
stateDict.Add(key, (torch.Tensor)stateHashtable[key]!);
129+
var unpickledConstructors = new Dictionary<string, PyTorchUnpickler.TensorConstructorArgs>();
133130

134-
// Load it in using the builtin function
135-
var (_, unexpectedKeys) = module.load_state_dict(stateDict, strict, skip);
131+
foreach (string key in unpickled.Keys) {
132+
unpickledConstructors.Add(key, (PyTorchUnpickler.TensorConstructorArgs)unpickled[key]!);
133+
}
136134

137-
// Fill in the loadedParameters dictionary, if relevant
138-
if (loadedParameters is not null) {
139-
foreach (string key in stateDict.Keys)
140-
loadedParameters[key] = true;
141-
foreach (string key in unexpectedKeys)
142-
loadedParameters[key] = false;
135+
var (_, unexpectedKeys) = load_state_dict(module, unpickledConstructors, strict, skip);
136+
137+
if (!leaveOpen) {
138+
// Close stream now that tensor streams have been read.
139+
stream.Close ();
140+
}
141+
142+
if (loadedParameters is null) {
143+
return module;
144+
}
145+
146+
// Fill in the loadedParameters dictionary
147+
foreach (var key in unpickledConstructors.Keys) {
148+
loadedParameters[key] = true;
149+
}
150+
151+
foreach (var key in unexpectedKeys) {
152+
loadedParameters[key] = false;
143153
}
144154

145155
return module;
146156
}
147157

158+
/// <summary>
/// Mirrors the implementation of module.load_state_dict but performs tensor reading
/// with less intermediate memory overhead: whenever possible the bytes of each archive
/// entry are read straight into the module's existing tensor instead of into a temporary.
/// </summary>
/// <param name="module">Module whose state_dict() tensors are overwritten in place.</param>
/// <param name="unpickled">Deferred tensor descriptors (metadata + open data stream), keyed by state-dict key.</param>
/// <param name="strict">When true, any missing or unexpected key (not listed in <paramref name="skip"/>) is an error.</param>
/// <param name="skip">Keys that are neither validated nor loaded. May be null.</param>
/// <returns>The keys present in the module but not in the input, and the keys present in the input but not in the module.</returns>
/// <exception cref="InvalidOperationException">
/// Thrown when <paramref name="strict"/> is true and the key sets differ.
/// </exception>
/// <remarks>
/// Every data stream held by <paramref name="unpickled"/> is disposed before this method
/// returns or throws, including the streams of entries that were skipped or unexpected.
/// </remarks>
static (IList<string> missing_keys, IList<string> unexpected_keys) load_state_dict(
    Module module,
    Dictionary<string, PyTorchUnpickler.TensorConstructorArgs> unpickled,
    bool strict = true,
    IList<string>? skip = null
) {
    var missingKeys = new List<string>();
    var unexpectedKeys = new List<string>();
    // A set keeps the membership tests below O(1) instead of scanning the list per key.
    var skipped = new HashSet<string>(skip ?? Array.Empty<string>());

    try {
        var state = module.state_dict();

        foreach (string key in unpickled.Keys) {
            if (!skipped.Contains(key) && !state.ContainsKey(key)) {
                unexpectedKeys.Add(key);
            }
        }

        foreach (string key in state.Keys) {
            if (!skipped.Contains(key) && !unpickled.ContainsKey(key)) {
                missingKeys.Add(key);
            }
        }

        if (strict && (missingKeys.Count > 0 || unexpectedKeys.Count > 0)) {
            throw new InvalidOperationException("The loaded state_dict is not identical to the target dictionary.");
        }

        var pending = unpickled
            // Skipped keys must not be loaded, only keys the module actually owns are.
            .Where(e => !skipped.Contains(e.Key) && state.ContainsKey(e.Key))
            // Avoid random stream seeks by reading archive files in the order that they are stored.
            .OrderBy(e => e.Value.ArchiveIndex)
            .ToArray();

        foreach (var (key, constructor) in pending) {
            var target = state[key];
            target.with_requires_grad(constructor.RequiresGrad);

            if (constructor.DType == target.dtype && target.shape.SequenceEqual(constructor.Shape)) {
                // Same dtype and shape: read directly into the target tensor's storage.
                using var view = target.as_strided(constructor.Shape, constructor.Stride, constructor.StorageOffset);
                view.ReadBytesFromStream(constructor.Data);
            }
            else {
                // Type conversion (or a shape mismatch, which copy_ reports or broadcasts)
                // requires an intermediate tensor.
                // This will load onto cpu first before copying to target.
                using torch.Tensor temp = constructor.ReadTensorFromStream();
                target.copy_(temp);
            }
        }

        return (missingKeys, unexpectedKeys);
    }
    finally {
        // Release every archive entry stream, including the ones that were never read.
        // Disposing a stream that was already closed is a no-op.
        foreach (var constructor in unpickled.Values) {
            constructor.Data?.Dispose();
        }
    }
}
216+
148217
/// <summary>
149218
/// Load the parameters and buffers from a file saved using the safetensors format (https://github.com/huggingface/safetensors)
150219
/// </summary>

TorchSharp.PyBridge/PyTorchUnpickler.cs

Lines changed: 75 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,12 @@ public static Hashtable UnpickleStateDict(string file) {
2626
/// </summary>
2727
/// <param name="stream">Stream of the file to load</param>
2828
/// <param name="leaveOpen">true to leave the stream open after reading the file</param>
29+
/// <param name="skipTensorRead">true to return descriptor objects and streams instead of tensors so that they can be loaded later</param>
2930
/// <returns>The loaded state_dict</returns>
30-
public static Hashtable UnpickleStateDict(Stream stream, bool leaveOpen = false) {
31+
public static Hashtable UnpickleStateDict(Stream stream, bool leaveOpen = false, bool skipTensorRead = false) {
32+
if (skipTensorRead && !leaveOpen)
33+
throw new ArgumentException("leaveOpen must be true when skipTensorRead is true");
34+
3135
// Make sure it's a zip file
3236
// If it's not, then it was saved using legacy torch save and we don't support it (yet, at least)
3337
// Check the local file signature
@@ -45,7 +49,7 @@ public static Hashtable UnpickleStateDict(Stream stream, bool leaveOpen = false)
4549

4650
// Create our unpickler with the archive, so it can pull all the relevant files
4751
// using the persistentId
48-
var unpickler = new CustomUnpickler(archive);
52+
var unpickler = new CustomUnpickler(archive, skipTensorRead);
4953
// The unpickle returns a hash mapping ["key"] to the tensor
5054
return (Hashtable)unpickler.load(pklEntry.Open());
5155
}
@@ -61,8 +65,11 @@ public static Hashtable UnpickleStateDict(Stream stream, bool leaveOpen = false)
6165
class CustomUnpickler : Unpickler {
6266
readonly ZipArchive _archive;
6367

64-
public CustomUnpickler(ZipArchive archive) {
68+
readonly bool _skipTensorRead;
69+
70+
public CustomUnpickler(ZipArchive archive, bool skipTensorRead) {
6571
_archive = archive;
72+
_skipTensorRead = skipTensorRead;
6673
}
6774

6875
protected override object persistentLoad(object pid) {
@@ -79,20 +86,24 @@ protected override object persistentLoad(object pid) {
7986
string storageType = ((ClassDictConstructor)opid[1]).name;
8087
// Tuple Item2: key (filename in the archive)
8188
string archiveKey = (string)opid[2];
82-
// Tuple Item3: location (cpu/gpu), but we always load onto CPU.
89+
// Tuple Item3: location (cpu/gpu), but we always load onto CPU.
8390
// Tuple Item4: numElems (the number of elements in the tensor)
84-
91+
8592
// Convert the storage name into the relevant scalar type (e.g., LongStorage => torch.long)
8693
// and then check how many bytes each element is
8794
var dtype = GetScalarTypeFromStorageName(storageType);
88-
95+
8996
// Retrieve the entry from the archive
90-
var entry = _archive.Entries.First(f => f.FullName.EndsWith($"data/{archiveKey}"));
91-
97+
var entry = _archive.Entries
98+
.Select((archiveEntry, index) => (archiveEntry, index))
99+
.First(e => e.archiveEntry.FullName.EndsWith($"data/{archiveKey}"));
100+
92101
// Send this back, so our TensorObjectConstructor can create our torch.tensor from the object.
93-
return new TensorObject() {
94-
data = entry!.Open(),
95-
dtype = dtype
102+
return new TensorStream {
103+
ArchiveIndex = entry!.index,
104+
ArchiveEntry = entry!.archiveEntry,
105+
DType = dtype,
106+
SkipTensorRead = _skipTensorRead,
96107
};
97108
}
98109

@@ -118,7 +129,7 @@ static torch.ScalarType GetScalarTypeFromStorageName(string storage) {
118129
/// <summary>
119130
/// The unpickler implementation requires a __setstate__ function for unpickling an ordered dict, due
120131
/// to the way it was saved. This class is just a regular Hashtable with an implementation for the
121-
/// __setstate__.
132+
/// __setstate__.
122133
/// </summary>
123134
class OrderedDict : Hashtable {
124135
public void __setstate__(Hashtable arg) {
@@ -145,27 +156,29 @@ public object construct(object[] args) {
145156
/// </summary>
146157
class TensorObjectConstructor : IObjectConstructor {
147158
public object construct(object[] args) {
148-
// Arg 0: (byte[] data, ScalarType dtype) // returned from our custom pickler
149-
var arg0 = (TensorObject)args[0];
150-
// Arg 1: storage_offset
151-
int storageOffset = (int)args[1];
152-
// Arg 2: tensor_shape
153-
var shape = ((object[])args[2]).Select(i => (long)(int)i).ToArray();
154-
// Arg 3: stride
155-
var stride = ((object[])args[3]).Select(i => (long)(int)i).ToArray();
156-
// Arg 4: requires_grad
157-
var requiresGrad = (bool)args[4];
159+
// Arg 0: returned from our custom unpickler (see persistentLoad)
160+
var tensorStream = (TensorStream)args[0];
161+
162+
var constructor = new TensorConstructorArgs {
163+
ArchiveIndex = tensorStream.ArchiveIndex,
164+
Data = tensorStream.ArchiveEntry!.Open(),
165+
DType = tensorStream.DType,
166+
// Arg 1: storage_offset
167+
StorageOffset = (int)args[1],
168+
// Arg 2: tensor_shape
169+
Shape = ((object[])args[2]).Select(i => (long)(int)i).ToArray(),
170+
// Arg 3: stride
171+
Stride = ((object[])args[3]).Select(i => (long)(int)i).ToArray(),
172+
// Arg 4: requires_grad
173+
RequiresGrad = (bool)args[4],
174+
};
175+
158176
// Arg 5: backward_hooks, we don't support adding them in and it's not recommended
159177
// in PyTorch to serialize them.
160178

161-
// If there is no shape, then the shape is just 1
162-
// Since we have two operations here - we want to make sure to dispose the temporary.
163-
torch.Tensor t = torch.WrappedTensorDisposeScope(() =>
164-
torch.empty(shape, arg0.dtype).as_strided(shape, stride, storageOffset));
165-
166-
t.ReadBytesFromStream(arg0.data);
167-
arg0.data.Close();
168-
return t;
179+
return tensorStream.SkipTensorRead
180+
? constructor
181+
: constructor.ReadTensorFromStream();
169182
}
170183
}
171184

@@ -182,15 +195,43 @@ public object construct(object[] args) {
182195
}
183196
}
184197

198+
/// <summary>
/// Everything needed to materialize one pickled tensor at a later point: the open stream
/// of its archive entry plus the dtype, shape, stride, storage offset and requires_grad flag
/// that were passed to the tensor constructor during unpickling.
/// </summary>
internal record TensorConstructorArgs
{
    /// <summary>Index of the data entry within the archive's entry list; used to read entries in stored order.</summary>
    public int ArchiveIndex { get; init; }

    /// <summary>Open stream over the archive entry holding the raw tensor bytes.</summary>
    public Stream Data { get; init; }

    /// <summary>Element type of the stored tensor.</summary>
    public torch.ScalarType DType { get; init; }

    /// <summary>Storage offset passed to as_strided.</summary>
    public int StorageOffset { get; init; }

    /// <summary>Shape of the stored tensor.</summary>
    public long[] Shape { get; init; }

    /// <summary>Stride of the stored tensor.</summary>
    public long[] Stride { get; init; }

    /// <summary>Value of the requires_grad flag stored with the tensor.</summary>
    public bool RequiresGrad { get; init; }

    /// <summary>
    /// Creates a CPU tensor with the recorded dtype/shape/stride and fills it with the bytes
    /// from <see cref="Data"/>. The stream is closed afterwards, whether or not the read succeeds.
    /// </summary>
    /// <returns>The newly created tensor; the caller owns it.</returns>
    public torch.Tensor ReadTensorFromStream() {
        try {
            // Two operations are chained here, so make sure the temporary produced by
            // torch.empty is disposed and only the strided result survives.
            var result = torch.WrappedTensorDisposeScope(() =>
                torch
                    .empty(Shape, DType, device: torch.CPU)
                    .as_strided(Shape, Stride, StorageOffset));

            try {
                result.ReadBytesFromStream(Data);
            }
            catch {
                // Do not hand back (or leak) a partially filled tensor.
                result.Dispose();
                throw;
            }

            return result;
        }
        finally {
            Data.Close();
        }
    }
}
185224

186225
/// <summary>
187226
/// When the unpickler first loads in the tensor, it only has access to metadata about the storage
188227
/// of the tensor, but not the info about stride/shape etc. That part is done in the TensorReconstructor.
189228
/// Therefore, this class is a simple wrapper for the archive entry + dtype of the storage.
190229
/// </summary>
191-
class TensorObject {
192-
public Stream data { get; set; }
193-
public torch.ScalarType dtype { get; set; }
230+
class TensorStream {
231+
public int ArchiveIndex { get; init; }
232+
public ZipArchiveEntry ArchiveEntry { get; init; }
233+
public torch.ScalarType DType { get; init; }
234+
public bool SkipTensorRead { get; init; }
194235
}
195236
}
196-
}
237+
}

0 commit comments

Comments
 (0)