@@ -26,8 +26,12 @@ public static Hashtable UnpickleStateDict(string file) {
2626 /// </summary>
2727 /// <param name="stream">Stream of the file to load</param>
2828 /// <param name="leaveOpen">true to leave the stream open after saving the file</param>
29+ /// <param name="skipTensorRead">true to return descriptor objects and streams instead of tensors so that they can be loaded later</param>
2930 /// <returns>The loaded state_dict</returns>
30- public static Hashtable UnpickleStateDict ( Stream stream , bool leaveOpen = false ) {
31+ public static Hashtable UnpickleStateDict ( Stream stream , bool leaveOpen = false , bool skipTensorRead = false ) {
32+ if ( skipTensorRead && ! leaveOpen )
33+ throw new ArgumentException ( "leaveOpen must be true when skipTensorRead is true" ) ;
34+
3135 // Make sure it's a zip file
3236 // If it's not, then it was saved using legacy torch save and we don't support it (yet, at least)
3337 // Check the local file signature
@@ -45,7 +49,7 @@ public static Hashtable UnpickleStateDict(Stream stream, bool leaveOpen = false)
4549
4650 // Create our unpickler with the archive, so it can pull all the relevant files
4751 // using the persistentId
48- var unpickler = new CustomUnpickler ( archive ) ;
52+ var unpickler = new CustomUnpickler ( archive , skipTensorRead ) ;
4953 // The unpickle returns a hash mapping ["key"] to the tensor
5054 return ( Hashtable ) unpickler . load ( pklEntry . Open ( ) ) ;
5155 }
@@ -61,8 +65,11 @@ public static Hashtable UnpickleStateDict(Stream stream, bool leaveOpen = false)
6165 class CustomUnpickler : Unpickler {
6266 readonly ZipArchive _archive ;
6367
64- public CustomUnpickler ( ZipArchive archive ) {
68+ readonly bool _skipTensorRead ;
69+
/// <summary>
/// Creates an unpickler bound to the checkpoint archive. When
/// <paramref name="skipTensorRead"/> is true, tensor storages are returned as
/// deferred-read descriptors instead of being materialized immediately.
/// </summary>
/// <param name="archive">The zip archive containing the pickled state_dict and its data entries.</param>
/// <param name="skipTensorRead">true to defer reading tensor bytes until later.</param>
public CustomUnpickler(ZipArchive archive, bool skipTensorRead) {
    _archive = archive;
    _skipTensorRead = skipTensorRead;
}
6774
6875 protected override object persistentLoad ( object pid ) {
@@ -79,20 +86,24 @@ protected override object persistentLoad(object pid) {
7986 string storageType = ( ( ClassDictConstructor ) opid [ 1 ] ) . name ;
8087 // Tuple Item2: key (filename in the archive)
8188 string archiveKey = ( string ) opid [ 2 ] ;
82- // Tuple Item3: location (cpu/gpu), but we always load onto CPU.
89+ // Tuple Item3: location (cpu/gpu), but we always load onto CPU.
8390 // Tuple Item4: numElems (the number of elements in the tensor)
84-
91+
8592 // Convert the storage name into the relevant scalar type (e.g., LongStorage => torch.long)
8693 // and then check how many bytes each element is
8794 var dtype = GetScalarTypeFromStorageName ( storageType ) ;
88-
95+
8996 // Retrieve the entry from the archive
90- var entry = _archive . Entries . First ( f => f . FullName . EndsWith ( $ "data/{ archiveKey } ") ) ;
91-
97+ var entry = _archive . Entries
98+ . Select ( ( archiveEntry , index ) => ( archiveEntry , index ) )
99+ . First ( e => e . archiveEntry . FullName . EndsWith ( $ "data/{ archiveKey } ") ) ;
100+
92101 // Send this back, so our TensorObjectConstructor can create our torch.tensor from the object.
93- return new TensorObject ( ) {
94- data = entry ! . Open ( ) ,
95- dtype = dtype
102+ return new TensorStream {
103+ ArchiveIndex = entry ! . index ,
104+ ArchiveEntry = entry ! . archiveEntry ,
105+ DType = dtype ,
106+ SkipTensorRead = _skipTensorRead ,
96107 } ;
97108 }
98109
@@ -118,7 +129,7 @@ static torch.ScalarType GetScalarTypeFromStorageName(string storage) {
118129 /// <summary>
119130 /// The unpickler implementation requires a __setstate__ function for unpickling an ordered dict, due
120131 /// to the way it was saved. This class is just a regular Hashtable with an implementation for the
121- /// __setstate__.
132+ /// __setstate__.
122133 /// </summary>
123134 class OrderedDict : Hashtable {
124135 public void __setstate__ ( Hashtable arg ) {
@@ -145,27 +156,29 @@ public object construct(object[] args) {
145156 /// </summary>
/// <summary>
/// Rebuilds a torch tensor (or a deferred-construction descriptor) from the
/// storage object produced by the custom pickler plus the pickled view metadata.
/// </summary>
class TensorObjectConstructor : IObjectConstructor {
    public object construct(object[] args) {
        // Arg 0: the storage descriptor returned by our custom pickler.
        var tensorStream = (TensorStream) args[0];

        // Args 1-4: view metadata recorded by torch.save.
        var storageOffset = (int) args[1];                                          // storage_offset
        var shape = ((object[]) args[2]).Select(i => (long) (int) i).ToArray();     // tensor_shape
        var stride = ((object[]) args[3]).Select(i => (long) (int) i).ToArray();    // stride
        var requiresGrad = (bool) args[4];                                          // requires_grad
        // Arg 5: backward_hooks — intentionally ignored; PyTorch discourages
        // serializing them, and we don't support adding them back in.

        var constructorArgs = new TensorConstructorArgs {
            ArchiveIndex = tensorStream.ArchiveIndex,
            Data = tensorStream.ArchiveEntry!.Open(),
            DType = tensorStream.DType,
            StorageOffset = storageOffset,
            Shape = shape,
            Stride = stride,
            RequiresGrad = requiresGrad,
        };

        // Either hand back the descriptor (with its open stream) for later
        // loading, or materialize the tensor right now.
        if (tensorStream.SkipTensorRead) {
            return constructorArgs;
        }
        return constructorArgs.ReadTensorFromStream();
    }
}
171184
@@ -182,15 +195,43 @@ public object construct(object[] args) {
182195 }
183196 }
184197
/// <summary>
/// Everything needed to materialize a tensor from a checkpoint archive entry:
/// the entry's data stream and dtype plus the view metadata (shape, stride,
/// storage offset). Instances are returned directly to the caller when
/// skipTensorRead was requested, so the tensor bytes can be read later.
/// </summary>
internal record TensorConstructorArgs
{
    public int ArchiveIndex { get; init; }

    public Stream Data { get; init; }

    public torch.ScalarType DType { get; init; }

    public int StorageOffset { get; init; }

    public long[] Shape { get; init; }

    public long[] Stride { get; init; }

    // NOTE(review): RequiresGrad is recorded but never applied to the tensor
    // built below — confirm callers set requires_grad themselves.
    public bool RequiresGrad { get; init; }

    /// <summary>
    /// Reads the tensor bytes from <see cref="Data"/> into a freshly allocated
    /// CPU tensor with the recorded shape/stride/offset. The stream is closed
    /// when done, even if the read fails.
    /// </summary>
    /// <returns>The materialized tensor.</returns>
    public torch.Tensor ReadTensorFromStream() {
        // empty() followed by as_strided() creates an intermediate tensor; run
        // both inside a wrapped dispose scope so the temporary is disposed and
        // only the final result escapes (this mirrors the original loader code).
        var tensor = torch.WrappedTensorDisposeScope(() =>
            torch.empty(Shape, DType, device: torch.CPU).as_strided(Shape, Stride, StorageOffset));

        try {
            tensor.ReadBytesFromStream(Data);
        } finally {
            // Close in finally so the archive entry stream doesn't leak when
            // the read throws (e.g., truncated archive).
            Data.Close();
        }

        return tensor;
    }
}
185224
186225 /// <summary>
187226 /// When the unpickler first loads in the tensor, it only has access to metadata about the storage
188227 /// of the tensor, but not the info about stride/shape etc. That part is done in the TensorReconstructor.
189228 /// Therefore, this class is a wrapper for the archive entry, the storage dtype, and the deferred-read flag.
190229 /// </summary>
/// <summary>
/// Storage-level metadata handed from persistentLoad to TensorObjectConstructor:
/// which archive entry holds the tensor bytes, the storage dtype, and whether
/// the actual read should be deferred.
/// </summary>
class TensorStream {
    public int ArchiveIndex { get; init; }

    public ZipArchiveEntry ArchiveEntry { get; init; }

    public torch.ScalarType DType { get; init; }

    public bool SkipTensorRead { get; init; }
}
195236 }
196- }
237+ }
0 commit comments