Skip to content

Add storage_options to Snapshot StoragePlugin #108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions torchsnapshot/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def __init__(
self,
path: str,
pg: Optional[dist.ProcessGroup] = None,
storage_options: Optional[Dict[str, Any]] = None,
) -> None:
"""
Initializes the reference to an existing snapshot.
Expand All @@ -165,10 +166,13 @@ def __init__(
When unspecified:
- If distributed is initialized, the global process group will be used.
- If distributed is not initialized, single process is assumed.
storage_options: Additional keyword options for the StoragePlugin to use.
See each StoragePlugin's documentation for customizations.
"""
self.path: str = path
self.pg: Optional[dist.ProcessGroup] = pg
self._metadata: Optional[SnapshotMetadata] = None
self._storage_options = storage_options

@classmethod
def take(
Expand All @@ -177,6 +181,7 @@ def take(
app_state: AppState,
pg: Optional[dist.ProcessGroup] = None,
replicated: Optional[List[str]] = None,
storage_options: Optional[Dict[str, Any]] = None,
_custom_tensor_prepare_func: Optional[
Callable[[str, torch.Tensor, bool], torch.Tensor]
] = None,
Expand All @@ -194,6 +199,8 @@ def take(
replicated: A list of glob patterns for hinting the matching paths
as replicated. Note that patterns not specified by all ranks
are ignored.
storage_options: Additional keyword options for the StoragePlugin to use.
See each StoragePlugin's documentation for customizations.

Returns:
The newly taken snapshot.
Expand All @@ -211,7 +218,7 @@ def take(
replicated=replicated or [],
)
storage = url_to_storage_plugin_in_event_loop(
url_path=path, event_loop=event_loop
url_path=path, event_loop=event_loop, storage_options=storage_options
)
pending_io_work, metadata = cls._take_impl(
path=path,
Expand All @@ -236,7 +243,7 @@ def take(

storage.sync_close(event_loop=event_loop)
event_loop.close()
snapshot = cls(path=path, pg=pg)
snapshot = cls(path=path, pg=pg, storage_options=storage_options)
snapshot._metadata = metadata
return snapshot

Expand All @@ -247,6 +254,7 @@ def async_take(
app_state: AppState,
pg: Optional[dist.ProcessGroup] = None,
replicated: Optional[List[str]] = None,
storage_options: Optional[Dict[str, Any]] = None,
_custom_tensor_prepare_func: Optional[
Callable[[str, torch.Tensor, bool], torch.Tensor]
] = None,
Expand All @@ -269,6 +277,8 @@ def async_take(
replicated: A list of glob patterns for hinting the matching paths
as replicated. Note that patterns not specified by all ranks
are ignored.
storage_options: Additional keyword options for the StoragePlugin to use.
See each StoragePlugin's documentation for customizations.

Returns:
A handle with which the newly taken snapshot can be obtained via
Expand All @@ -288,7 +298,7 @@ def async_take(
replicated=replicated or [],
)
storage = url_to_storage_plugin_in_event_loop(
url_path=path, event_loop=event_loop
url_path=path, event_loop=event_loop, storage_options=storage_options
)

pending_io_work, metadata = cls._take_impl(
Expand All @@ -309,6 +319,7 @@ def async_take(
metadata=metadata,
storage=storage,
event_loop=event_loop,
storage_options=storage_options,
)

@classmethod
Expand Down Expand Up @@ -437,6 +448,7 @@ def restore(self, app_state: AppState) -> None:

Args:
app_state: The program state to restore from the snapshot.

"""
torch._C._log_api_usage_once("torchsnapshot.Snapshot.restore")
self._validate_app_state(app_state)
Expand All @@ -445,7 +457,9 @@ def restore(self, app_state: AppState) -> None:
pg_wrapper = PGWrapper(self.pg)
rank = pg_wrapper.get_rank()
storage = url_to_storage_plugin_in_event_loop(
url_path=self.path, event_loop=event_loop
url_path=self.path,
event_loop=event_loop,
storage_options=self._storage_options,
)

app_state = app_state.copy()
Expand Down Expand Up @@ -487,7 +501,9 @@ def metadata(self) -> SnapshotMetadata:
if self._metadata is None:
event_loop = asyncio.new_event_loop()
storage = url_to_storage_plugin_in_event_loop(
url_path=self.path, event_loop=event_loop
url_path=self.path,
event_loop=event_loop,
storage_options=self._storage_options,
)
self._metadata = self._read_snapshot_metadata(
storage=storage, event_loop=event_loop
Expand Down Expand Up @@ -557,7 +573,9 @@ def read_object(
event_loop = asyncio.new_event_loop()
pg_wrapper = PGWrapper(self.pg)
storage = url_to_storage_plugin_in_event_loop(
url_path=self.path, event_loop=event_loop
url_path=self.path,
event_loop=event_loop,
storage_options=self._storage_options,
)
entry = manifest[unranked_path]
if isinstance(entry, PrimitiveEntry):
Expand Down Expand Up @@ -855,12 +873,14 @@ def __init__(
metadata: SnapshotMetadata,
storage: StoragePlugin,
event_loop: asyncio.AbstractEventLoop,
storage_options: Optional[Dict[str, Any]] = None,
) -> None:
self.path = path
self.pg: Optional[dist.ProcessGroup] = pg_wrapper.pg
# pyre-ignore
self.exc_info: Optional[Any] = None
self._done = False
self._storage_options = storage_options

self.thread = Thread(
target=self._complete_snapshot,
Expand Down Expand Up @@ -928,7 +948,9 @@ def wait(self) -> Snapshot:
raise RuntimeError(
f"Encountered exception while taking snapshot asynchronously:\n{formatted}"
)
return Snapshot(path=self.path, pg=self.pg)
return Snapshot(
path=self.path, pg=self.pg, storage_options=self._storage_options
)

def done(self) -> bool:
return self._done
26 changes: 18 additions & 8 deletions torchsnapshot/storage_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

import asyncio
from typing import Any, Dict, Optional

from importlib_metadata import entry_points

Expand All @@ -14,13 +15,17 @@
from .storage_plugins.s3 import S3StoragePlugin


def url_to_storage_plugin(url_path: str) -> StoragePlugin:
def url_to_storage_plugin(
url_path: str, storage_options: Optional[Dict[str, Any]] = None
) -> StoragePlugin:
"""
Initialize storage plugin from url path.

Args:
url_path: The url path following the pattern [protocol]://[path].
The protocol defaults to `fs` if unspecified.
storage_options: Additional keyword options for the StoragePlugin to use.
See each StoragePlugin's documentation for customizations.

Returns:
The initialized storage plugin.
Expand All @@ -32,15 +37,18 @@ def url_to_storage_plugin(url_path: str) -> StoragePlugin:
else:
protocol, path = "fs", url_path

if storage_options is None:
storage_options = dict()

# Built-in storage plugins
if protocol == "fs":
return FSStoragePlugin(root=path)
return FSStoragePlugin(root=path, storage_options=storage_options)
elif protocol == "s3":
return S3StoragePlugin(root=path)
return S3StoragePlugin(root=path, storage_options=storage_options)
elif protocol == "gs":
from torchsnapshot.storage_plugins.gcs import GCSStoragePlugin

return GCSStoragePlugin(root=path)
return GCSStoragePlugin(root=path, storage_options=storage_options)

# Registered storage plugins
eps = entry_points(group="storage_plugins")
Expand All @@ -60,9 +68,11 @@ def url_to_storage_plugin(url_path: str) -> StoragePlugin:


def url_to_storage_plugin_in_event_loop(
url_path: str, event_loop: asyncio.AbstractEventLoop
url_path: str,
event_loop: asyncio.AbstractEventLoop,
storage_options: Optional[Dict[str, Any]] = None,
) -> StoragePlugin:
async def _url_to_storage_plugin(url_path: str) -> StoragePlugin:
return url_to_storage_plugin(url_path=url_path)
async def _url_to_storage_plugin() -> StoragePlugin:
return url_to_storage_plugin(url_path=url_path, storage_options=storage_options)

return event_loop.run_until_complete(_url_to_storage_plugin(url_path=url_path))
return event_loop.run_until_complete(_url_to_storage_plugin())
6 changes: 4 additions & 2 deletions torchsnapshot/storage_plugins/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import io
import os
import pathlib
from typing import Set
from typing import Any, Dict, Optional, Set

import aiofiles
import aiofiles.os
Expand All @@ -17,7 +17,9 @@


class FSStoragePlugin(StoragePlugin):
def __init__(self, root: str) -> None:
def __init__(
self, root: str, storage_options: Optional[Dict[str, Any]] = None
) -> None:
self.root = root
self._dir_cache: Set[pathlib.Path] = set()

Expand Down
6 changes: 4 additions & 2 deletions torchsnapshot/storage_plugins/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import random
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Awaitable, Callable, Optional, TypeVar
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar
from urllib.parse import quote

import google.auth.exceptions # @manual
Expand Down Expand Up @@ -59,7 +59,9 @@ class GCSStoragePlugin(StoragePlugin):
"{bucket}/o/{blob_name}?alt=media"
)

def __init__(self, root: str) -> None:
def __init__(
self, root: str, storage_options: Optional[Dict[str, Any]] = None
) -> None:
components = root.split("/")
if len(components) < 2:
raise RuntimeError(
Expand Down
6 changes: 5 additions & 1 deletion torchsnapshot/storage_plugins/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@

import io
import os
from typing import Any, Dict, Optional

from torchsnapshot.io_types import ReadIO, StoragePlugin, WriteIO
from torchsnapshot.memoryview_stream import MemoryviewStream


class S3StoragePlugin(StoragePlugin):
def __init__(self, root: str) -> None:
def __init__(
self, root: str, storage_options: Optional[Dict[str, Any]] = None
) -> None:
try:
from aiobotocore.session import get_session # @manual
except ImportError:
Expand All @@ -30,6 +33,7 @@ def __init__(self, root: str) -> None:
self.bucket: str = components[0]
self.root: str = "/".join(components[1:])
# pyre-ignore
# TODO: read AWS tokens from storage_options?
self.session = get_session()

async def write(self, write_io: WriteIO) -> None:
Expand Down