Skip to content

Commit 7fd1ece

Browse files
reyoungfacebook-github-bot
authored andcommitted
Add storage_options to Snapshot StoragePlugin (#108)
Summary: Pull Request resolved: #108 Reviewed By: ananthsub Differential Revision: D40669655 Pulled By: yifuwang fbshipit-source-id: e0b9f0d36d081c1e18fe60568d3b84eef9e49adb
1 parent 5a15771 commit 7fd1ece

File tree

5 files changed

+61
-21
lines changed

5 files changed

+61
-21
lines changed

torchsnapshot/snapshot.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(
148148
self,
149149
path: str,
150150
pg: Optional[dist.ProcessGroup] = None,
151+
storage_options: Optional[Dict[str, Any]] = None,
151152
) -> None:
152153
"""
153154
Initializes the reference to an existing snapshot.
@@ -158,10 +159,13 @@ def __init__(
158159
When unspecified:
159160
- If distributed is initialized, the global process group will be used.
160161
- If distributed is not initialized, single process is assumed.
162+
storage_options: Additional keyword options for the StoragePlugin to use.
163+
See each StoragePlugin's documentation for customizations.
161164
"""
162165
self.path: str = path
163166
self.pg: Optional[dist.ProcessGroup] = pg
164167
self._metadata: Optional[SnapshotMetadata] = None
168+
self._storage_options = storage_options
165169

166170
@classmethod
167171
def take(
@@ -170,6 +174,7 @@ def take(
170174
app_state: AppState,
171175
pg: Optional[dist.ProcessGroup] = None,
172176
replicated: Optional[List[str]] = None,
177+
storage_options: Optional[Dict[str, Any]] = None,
173178
_custom_tensor_prepare_func: Optional[
174179
Callable[[str, torch.Tensor, bool], torch.Tensor]
175180
] = None,
@@ -187,6 +192,8 @@ def take(
187192
replicated: A list of glob patterns for hinting the matching paths
188193
as replicated. Note that patterns not specified by all ranks
189194
are ignored.
195+
storage_options: Additional keyword options for the StoragePlugin to use.
196+
See each StoragePlugin's documentation for customizations.
190197
191198
Returns:
192199
The newly taken snapshot.
@@ -204,7 +211,7 @@ def take(
204211
replicated=replicated or [],
205212
)
206213
storage = url_to_storage_plugin_in_event_loop(
207-
url_path=path, event_loop=event_loop
214+
url_path=path, event_loop=event_loop, storage_options=storage_options
208215
)
209216
pending_io_work, metadata = cls._take_impl(
210217
path=path,
@@ -229,7 +236,7 @@ def take(
229236

230237
storage.sync_close(event_loop=event_loop)
231238
event_loop.close()
232-
snapshot = cls(path=path, pg=pg)
239+
snapshot = cls(path=path, pg=pg, storage_options=storage_options)
233240
snapshot._metadata = metadata
234241
return snapshot
235242

@@ -240,6 +247,7 @@ def async_take(
240247
app_state: AppState,
241248
pg: Optional[dist.ProcessGroup] = None,
242249
replicated: Optional[List[str]] = None,
250+
storage_options: Optional[Dict[str, Any]] = None,
243251
_custom_tensor_prepare_func: Optional[
244252
Callable[[str, torch.Tensor, bool], torch.Tensor]
245253
] = None,
@@ -262,6 +270,8 @@ def async_take(
262270
replicated: A list of glob patterns for hinting the matching paths
263271
as replicated. Note that patterns not specified by all ranks
264272
are ignored.
273+
storage_options: Additional keyword options for the StoragePlugin to use.
274+
See each StoragePlugin's documentation for customizations.
265275
266276
Returns:
267277
A handle with which the newly taken snapshot can be obtained via
@@ -281,7 +291,7 @@ def async_take(
281291
replicated=replicated or [],
282292
)
283293
storage = url_to_storage_plugin_in_event_loop(
284-
url_path=path, event_loop=event_loop
294+
url_path=path, event_loop=event_loop, storage_options=storage_options
285295
)
286296

287297
pending_io_work, metadata = cls._take_impl(
@@ -302,6 +312,7 @@ def async_take(
302312
metadata=metadata,
303313
storage=storage,
304314
event_loop=event_loop,
315+
storage_options=storage_options,
305316
)
306317

307318
@classmethod
@@ -430,6 +441,7 @@ def restore(self, app_state: AppState) -> None:
430441
431442
Args:
432443
app_state: The program state to restore from the snapshot.
444+
433445
"""
434446
torch._C._log_api_usage_once("torchsnapshot.Snapshot.restore")
435447
self._validate_app_state(app_state)
@@ -438,7 +450,9 @@ def restore(self, app_state: AppState) -> None:
438450
pg_wrapper = PGWrapper(self.pg)
439451
rank = pg_wrapper.get_rank()
440452
storage = url_to_storage_plugin_in_event_loop(
441-
url_path=self.path, event_loop=event_loop
453+
url_path=self.path,
454+
event_loop=event_loop,
455+
storage_options=self._storage_options,
442456
)
443457

444458
app_state = app_state.copy()
@@ -480,7 +494,9 @@ def metadata(self) -> SnapshotMetadata:
480494
if self._metadata is None:
481495
event_loop = asyncio.new_event_loop()
482496
storage = url_to_storage_plugin_in_event_loop(
483-
url_path=self.path, event_loop=event_loop
497+
url_path=self.path,
498+
event_loop=event_loop,
499+
storage_options=self._storage_options,
484500
)
485501
self._metadata = self._read_snapshot_metadata(
486502
storage=storage, event_loop=event_loop
@@ -550,7 +566,9 @@ def read_object(
550566
event_loop = asyncio.new_event_loop()
551567
pg_wrapper = PGWrapper(self.pg)
552568
storage = url_to_storage_plugin_in_event_loop(
553-
url_path=self.path, event_loop=event_loop
569+
url_path=self.path,
570+
event_loop=event_loop,
571+
storage_options=self._storage_options,
554572
)
555573
entry = manifest[unranked_path]
556574
if isinstance(entry, PrimitiveEntry):
@@ -848,12 +866,14 @@ def __init__(
848866
metadata: SnapshotMetadata,
849867
storage: StoragePlugin,
850868
event_loop: asyncio.AbstractEventLoop,
869+
storage_options: Optional[Dict[str, Any]] = None,
851870
) -> None:
852871
self.path = path
853872
self.pg: Optional[dist.ProcessGroup] = pg_wrapper.pg
854873
# pyre-ignore
855874
self.exc_info: Optional[Any] = None
856875
self._done = False
876+
self._storage_options = storage_options
857877

858878
self.thread = Thread(
859879
target=self._complete_snapshot,
@@ -921,7 +941,9 @@ def wait(self) -> Snapshot:
921941
raise RuntimeError(
922942
f"Encountered exception while taking snapshot asynchronously:\n{formatted}"
923943
)
924-
return Snapshot(path=self.path, pg=self.pg)
944+
return Snapshot(
945+
path=self.path, pg=self.pg, storage_options=self._storage_options
946+
)
925947

926948
def done(self) -> bool:
927949
return self._done

torchsnapshot/storage_plugin.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import asyncio
9+
from typing import Any, Dict, Optional
910

1011
from importlib_metadata import entry_points
1112

@@ -14,13 +15,17 @@
1415
from .storage_plugins.s3 import S3StoragePlugin
1516

1617

17-
def url_to_storage_plugin(url_path: str) -> StoragePlugin:
18+
def url_to_storage_plugin(
19+
url_path: str, storage_options: Optional[Dict[str, Any]] = None
20+
) -> StoragePlugin:
1821
"""
1922
Initialize storage plugin from url path.
2023
2124
Args:
2225
url_path: The url path following the pattern [protocol]://[path].
2326
The protocol defaults to `fs` if unspecified.
27+
storage_options: Additional keyword options for the StoragePlugin to use.
28+
See each StoragePlugin's documentation for customizations.
2429
2530
Returns:
2631
The initialized storage plugin.
@@ -32,23 +37,26 @@ def url_to_storage_plugin(url_path: str) -> StoragePlugin:
3237
else:
3338
protocol, path = "fs", url_path
3439

40+
if storage_options is None:
41+
storage_options = {}
42+
3543
# Built-in storage plugins
3644
if protocol == "fs":
37-
return FSStoragePlugin(root=path)
45+
return FSStoragePlugin(root=path, storage_options=storage_options)
3846
elif protocol == "s3":
39-
return S3StoragePlugin(root=path)
47+
return S3StoragePlugin(root=path, storage_options=storage_options)
4048
elif protocol == "gs":
4149
from torchsnapshot.storage_plugins.gcs import GCSStoragePlugin
4250

43-
return GCSStoragePlugin(root=path)
51+
return GCSStoragePlugin(root=path, storage_options=storage_options)
4452

4553
# Registered storage plugins
4654
eps = entry_points(group="storage_plugins")
4755
registered_plugins = {ep.name: ep for ep in eps}
4856
if protocol in registered_plugins:
4957
entry = registered_plugins[protocol]
5058
factory = entry.load()
51-
plugin = factory(path)
59+
plugin = factory(path, storage_options)
5260
if not isinstance(plugin, StoragePlugin):
5361
raise RuntimeError(
5462
f"The factory function for {protocol} ({entry.value}) "
@@ -60,9 +68,11 @@ def url_to_storage_plugin(url_path: str) -> StoragePlugin:
6068

6169

6270
def url_to_storage_plugin_in_event_loop(
63-
url_path: str, event_loop: asyncio.AbstractEventLoop
71+
url_path: str,
72+
event_loop: asyncio.AbstractEventLoop,
73+
storage_options: Optional[Dict[str, Any]] = None,
6474
) -> StoragePlugin:
65-
async def _url_to_storage_plugin(url_path: str) -> StoragePlugin:
66-
return url_to_storage_plugin(url_path=url_path)
75+
async def _url_to_storage_plugin() -> StoragePlugin:
76+
return url_to_storage_plugin(url_path=url_path, storage_options=storage_options)
6777

68-
return event_loop.run_until_complete(_url_to_storage_plugin(url_path=url_path))
78+
return event_loop.run_until_complete(_url_to_storage_plugin())

torchsnapshot/storage_plugins/fs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import io
99
import os
1010
import pathlib
11-
from typing import Set
11+
from typing import Any, Dict, Optional, Set
1212

1313
import aiofiles
1414
import aiofiles.os
@@ -17,7 +17,9 @@
1717

1818

1919
class FSStoragePlugin(StoragePlugin):
20-
def __init__(self, root: str) -> None:
20+
def __init__(
21+
self, root: str, storage_options: Optional[Dict[str, Any]] = None
22+
) -> None:
2123
self.root = root
2224
self._dir_cache: Set[pathlib.Path] = set()
2325

torchsnapshot/storage_plugins/gcs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import random
1717
import time
1818
from concurrent.futures import ThreadPoolExecutor
19-
from typing import Any, Awaitable, Callable, Optional, TypeVar
19+
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar
2020
from urllib.parse import quote
2121

2222
import google.auth.exceptions # @manual
@@ -59,7 +59,9 @@ class GCSStoragePlugin(StoragePlugin):
5959
"{bucket}/o/{blob_name}?alt=media"
6060
)
6161

62-
def __init__(self, root: str) -> None:
62+
def __init__(
63+
self, root: str, storage_options: Optional[Dict[str, Any]] = None
64+
) -> None:
6365
components = root.split("/")
6466
if len(components) < 2:
6567
raise RuntimeError(

torchsnapshot/storage_plugins/s3.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77

88
import io
99
import os
10+
from typing import Any, Dict, Optional
1011

1112
from torchsnapshot.io_types import ReadIO, StoragePlugin, WriteIO
1213
from torchsnapshot.memoryview_stream import MemoryviewStream
1314

1415

1516
class S3StoragePlugin(StoragePlugin):
16-
def __init__(self, root: str) -> None:
17+
def __init__(
18+
self, root: str, storage_options: Optional[Dict[str, Any]] = None
19+
) -> None:
1720
try:
1821
from aiobotocore.session import get_session # @manual
1922
except ImportError:
@@ -30,6 +33,7 @@ def __init__(self, root: str) -> None:
3033
self.bucket: str = components[0]
3134
self.root: str = "/".join(components[1:])
3235
# pyre-ignore
36+
# TODO: read AWS tokens from storage_options?
3337
self.session = get_session()
3438

3539
async def write(self, write_io: WriteIO) -> None:

0 commit comments

Comments
 (0)