@@ -148,6 +148,7 @@ def __init__(
148
148
self ,
149
149
path : str ,
150
150
pg : Optional [dist .ProcessGroup ] = None ,
151
+ storage_options : Optional [Dict [str , Any ]] = None ,
151
152
) -> None :
152
153
"""
153
154
Initializes the reference to an existing snapshot.
@@ -158,10 +159,13 @@ def __init__(
158
159
When unspecified:
159
160
- If distributed is initialized, the global process group will be used.
160
161
- If distributed is not initialized, single process is assumed.
162
+ storage_options: Additional keyword options for the StoragePlugin to use.
163
+ See each StoragePlugin's documentation for customizations.
161
164
"""
162
165
self .path : str = path
163
166
self .pg : Optional [dist .ProcessGroup ] = pg
164
167
self ._metadata : Optional [SnapshotMetadata ] = None
168
+ self ._storage_options = storage_options
165
169
166
170
@classmethod
167
171
def take (
@@ -170,6 +174,7 @@ def take(
170
174
app_state : AppState ,
171
175
pg : Optional [dist .ProcessGroup ] = None ,
172
176
replicated : Optional [List [str ]] = None ,
177
+ storage_options : Optional [Dict [str , Any ]] = None ,
173
178
_custom_tensor_prepare_func : Optional [
174
179
Callable [[str , torch .Tensor , bool ], torch .Tensor ]
175
180
] = None ,
@@ -187,6 +192,8 @@ def take(
187
192
replicated: A list of glob patterns for hinting the matching paths
188
193
as replicated. Note that patterns not specified by all ranks
189
194
are ignored.
195
+ storage_options: Additional keyword options for the StoragePlugin to use.
196
+ See each StoragePlugin's documentation for customizations.
190
197
191
198
Returns:
192
199
The newly taken snapshot.
@@ -204,7 +211,7 @@ def take(
204
211
replicated = replicated or [],
205
212
)
206
213
storage = url_to_storage_plugin_in_event_loop (
207
- url_path = path , event_loop = event_loop
214
+ url_path = path , event_loop = event_loop , storage_options = storage_options
208
215
)
209
216
pending_io_work , metadata = cls ._take_impl (
210
217
path = path ,
@@ -229,7 +236,7 @@ def take(
229
236
230
237
storage .sync_close (event_loop = event_loop )
231
238
event_loop .close ()
232
- snapshot = cls (path = path , pg = pg )
239
+ snapshot = cls (path = path , pg = pg , storage_options = storage_options )
233
240
snapshot ._metadata = metadata
234
241
return snapshot
235
242
@@ -240,6 +247,7 @@ def async_take(
240
247
app_state : AppState ,
241
248
pg : Optional [dist .ProcessGroup ] = None ,
242
249
replicated : Optional [List [str ]] = None ,
250
+ storage_options : Optional [Dict [str , Any ]] = None ,
243
251
_custom_tensor_prepare_func : Optional [
244
252
Callable [[str , torch .Tensor , bool ], torch .Tensor ]
245
253
] = None ,
@@ -262,6 +270,8 @@ def async_take(
262
270
replicated: A list of glob patterns for hinting the matching paths
263
271
as replicated. Note that patterns not specified by all ranks
264
272
are ignored.
273
+ storage_options: Additional keyword options for the StoragePlugin to use.
274
+ See each StoragePlugin's documentation for customizations.
265
275
266
276
Returns:
267
277
A handle with which the newly taken snapshot can be obtained via
@@ -281,7 +291,7 @@ def async_take(
281
291
replicated = replicated or [],
282
292
)
283
293
storage = url_to_storage_plugin_in_event_loop (
284
- url_path = path , event_loop = event_loop
294
+ url_path = path , event_loop = event_loop , storage_options = storage_options
285
295
)
286
296
287
297
pending_io_work , metadata = cls ._take_impl (
@@ -302,6 +312,7 @@ def async_take(
302
312
metadata = metadata ,
303
313
storage = storage ,
304
314
event_loop = event_loop ,
315
+ storage_options = storage_options ,
305
316
)
306
317
307
318
@classmethod
@@ -430,6 +441,7 @@ def restore(self, app_state: AppState) -> None:
430
441
431
442
Args:
432
443
app_state: The program state to restore from the snapshot.
444
+
433
445
"""
434
446
torch ._C ._log_api_usage_once ("torchsnapshot.Snapshot.restore" )
435
447
self ._validate_app_state (app_state )
@@ -438,7 +450,9 @@ def restore(self, app_state: AppState) -> None:
438
450
pg_wrapper = PGWrapper (self .pg )
439
451
rank = pg_wrapper .get_rank ()
440
452
storage = url_to_storage_plugin_in_event_loop (
441
- url_path = self .path , event_loop = event_loop
453
+ url_path = self .path ,
454
+ event_loop = event_loop ,
455
+ storage_options = self ._storage_options ,
442
456
)
443
457
444
458
app_state = app_state .copy ()
@@ -480,7 +494,9 @@ def metadata(self) -> SnapshotMetadata:
480
494
if self ._metadata is None :
481
495
event_loop = asyncio .new_event_loop ()
482
496
storage = url_to_storage_plugin_in_event_loop (
483
- url_path = self .path , event_loop = event_loop
497
+ url_path = self .path ,
498
+ event_loop = event_loop ,
499
+ storage_options = self ._storage_options ,
484
500
)
485
501
self ._metadata = self ._read_snapshot_metadata (
486
502
storage = storage , event_loop = event_loop
@@ -550,7 +566,9 @@ def read_object(
550
566
event_loop = asyncio .new_event_loop ()
551
567
pg_wrapper = PGWrapper (self .pg )
552
568
storage = url_to_storage_plugin_in_event_loop (
553
- url_path = self .path , event_loop = event_loop
569
+ url_path = self .path ,
570
+ event_loop = event_loop ,
571
+ storage_options = self ._storage_options ,
554
572
)
555
573
entry = manifest [unranked_path ]
556
574
if isinstance (entry , PrimitiveEntry ):
@@ -848,12 +866,14 @@ def __init__(
848
866
metadata : SnapshotMetadata ,
849
867
storage : StoragePlugin ,
850
868
event_loop : asyncio .AbstractEventLoop ,
869
+ storage_options : Optional [Dict [str , Any ]] = None ,
851
870
) -> None :
852
871
self .path = path
853
872
self .pg : Optional [dist .ProcessGroup ] = pg_wrapper .pg
854
873
# pyre-ignore
855
874
self .exc_info : Optional [Any ] = None
856
875
self ._done = False
876
+ self ._storage_options = storage_options
857
877
858
878
self .thread = Thread (
859
879
target = self ._complete_snapshot ,
@@ -921,7 +941,9 @@ def wait(self) -> Snapshot:
921
941
raise RuntimeError (
922
942
f"Encountered exception while taking snapshot asynchronously:\n { formatted } "
923
943
)
924
- return Snapshot (path = self .path , pg = self .pg )
944
+ return Snapshot (
945
+ path = self .path , pg = self .pg , storage_options = self ._storage_options
946
+ )
925
947
926
948
def done (self ) -> bool :
927
949
return self ._done
0 commit comments