fix: boto3 session options #604

Merged 4 commits on May 24, 2025

Changes from all commits
7 changes: 6 additions & 1 deletion README.md
@@ -431,7 +431,12 @@ aws_storage_options={
"aws_access_key_id": os.environ['AWS_ACCESS_KEY_ID'],
"aws_secret_access_key": os.environ['AWS_SECRET_ACCESS_KEY'],
}
dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options)
# You can also pass session options (supported for boto3 only)
aws_session_options = {
"profile_name": os.environ['AWS_PROFILE_NAME'], # Required only for custom profiles
"region_name": os.environ['AWS_REGION_NAME'], # Required only for custom regions
}
dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options, session_options=aws_session_options)


# Read data from GCS
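A note on the new `session_options` example above: the dictionary is handed straight to `boto3.Session(...)`, so any keyword accepted by that constructor should work here. A minimal sketch of the mapping, assuming placeholder profile, region, and bucket names (none of these values come from this PR):

```python
import os

import boto3
import litdata as ld

# Keys mirror boto3.Session keyword arguments; all values are illustrative.
aws_session_options = {
    "profile_name": os.environ.get("AWS_PROFILE_NAME", "default"),  # profile must exist locally
    "region_name": os.environ.get("AWS_REGION_NAME", "us-east-1"),
}

# Roughly what the streaming S3 client builds internally from these options:
session = boto3.Session(**aws_session_options)

dataset = ld.StreamingDataset(
    "s3://my-bucket/my-data",  # placeholder bucket
    session_options=aws_session_options,
)
```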
3 changes: 3 additions & 0 deletions src/litdata/streaming/cache.py
@@ -46,6 +46,7 @@ def __init__(
serializers: Optional[Dict[str, Serializer]] = None,
writer_chunk_index: Optional[int] = None,
storage_options: Optional[Dict] = {},
session_options: Optional[Dict] = {},
max_pre_download: int = 2,
):
"""The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
@@ -64,6 +65,7 @@ def __init__(
serializers: Provide your own serializers.
writer_chunk_index: The index of the chunk to start from when writing.
storage_options: Additional connection options for accessing storage services.
session_options: Additional options for the S3 session.
max_pre_download: Maximum number of chunks that can be pre-downloaded while filling up the cache.

"""
@@ -92,6 +94,7 @@ def __init__(
item_loader=item_loader,
serializers=serializers,
storage_options=storage_options,
session_options=session_options,
max_pre_download=max_pre_download,
)
self._is_done = False
12 changes: 9 additions & 3 deletions src/litdata/streaming/client.py
@@ -26,19 +26,25 @@
class S3Client:
# TODO: Generalize to support more cloud providers.

def __init__(self, refetch_interval: int = 3300, storage_options: Optional[Dict] = {}) -> None:
def __init__(
self,
refetch_interval: int = 3300,
storage_options: Optional[Dict] = {},
session_options: Optional[Dict] = {},
) -> None:
self._refetch_interval = refetch_interval
self._last_time: Optional[float] = None
self._client: Optional[Any] = None
self._storage_options: dict = storage_options or {}
self._session_options: dict = session_options or {}

def _create_client(self) -> None:
has_shared_credentials_file = (
os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials"
)

if has_shared_credentials_file or not _IS_IN_STUDIO or self._storage_options:
session = boto3.Session()
if has_shared_credentials_file or not _IS_IN_STUDIO or self._storage_options or self._session_options:
session = boto3.Session(**self._session_options) # Apply any user-provided boto3 session options
self._client = session.client(
"s3",
**{
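The effect of this change is that the client keeps two separate option dictionaries: `session_options` is applied to `boto3.Session(...)`, while `storage_options` continues to be applied to `session.client("s3", ...)`. A rough standalone sketch of that split (the profile name and endpoint are assumptions for illustration, not values from this diff):

```python
import boto3

# Session-level settings (profile, region, tokens) -> boto3.Session
session_options = {"profile_name": "my-profile", "region_name": "us-east-1"}
# Client-level settings (e.g., a custom endpoint) -> session.client("s3", ...)
storage_options = {"endpoint_url": "https://s3.example.com"}

# "my-profile" must exist in ~/.aws/config for this to run.
session = boto3.Session(**session_options)
client = session.client("s3", **storage_options)
print(client.meta.region_name)  # -> us-east-1
```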
19 changes: 16 additions & 3 deletions src/litdata/streaming/config.py
@@ -39,6 +39,7 @@ def __init__(
subsampled_files: Optional[List[str]] = None,
region_of_interest: Optional[List[Tuple[int, int]]] = None,
storage_options: Optional[Dict] = {},
session_options: Optional[Dict] = {},
) -> None:
"""Reads the index files associated a chunked dataset and enables to map an index to its chunk.

@@ -51,6 +52,7 @@ def __init__(
subsampled_files: List of subsampled chunk files loaded from `input_dir/index.json` file.
region_of_interest: List of tuples of {start,end} of region of interest for each chunk.
storage_options: Additional connection options for accessing storage services.
session_options: Additional options for the S3 session.

"""
self._cache_dir = cache_dir
@@ -60,6 +62,7 @@ def __init__(
self._remote_dir = remote_dir
self._item_loader = item_loader or PyTreeLoader()
self._storage_options = storage_options
self._session_options = session_options

# load data from `index.json` file
data = load_index_file(self._cache_dir)
@@ -84,7 +87,9 @@ def __init__(
self._downloader = None

if remote_dir:
self._downloader = get_downloader(remote_dir, cache_dir, self._chunks, self._storage_options)
self._downloader = get_downloader(
remote_dir, cache_dir, self._chunks, self._storage_options, self._session_options
)

self._compressor_name = self._config["compression"]
self._compressor: Optional[Compressor] = None
@@ -286,6 +291,7 @@ def load(
subsampled_files: Optional[List[str]] = None,
region_of_interest: Optional[List[Tuple[int, int]]] = None,
storage_options: Optional[dict] = {},
session_options: Optional[dict] = {},
) -> Optional["ChunksConfig"]:
cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)

@@ -298,14 +304,21 @@
f"This should not have happened. No index.json file found in cache: {cache_index_filepath}"
)
else:
downloader = get_downloader(remote_dir, cache_dir, [], storage_options)
downloader = get_downloader(remote_dir, cache_dir, [], storage_options, session_options)
downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath)

if not os.path.exists(cache_index_filepath):
return None

return ChunksConfig(
cache_dir, serializers, remote_dir, item_loader, subsampled_files, region_of_interest, storage_options
cache_dir,
serializers,
remote_dir,
item_loader,
subsampled_files,
region_of_interest,
storage_options,
session_options,
)

def __len__(self) -> int:
5 changes: 5 additions & 0 deletions src/litdata/streaming/dataset.py
@@ -58,6 +58,7 @@ def __init__(
subsample: float = 1.0,
encryption: Optional[Encryption] = None,
storage_options: Optional[Dict] = {},
session_options: Optional[Dict] = {},
max_pre_download: int = 2,
index_path: Optional[str] = None,
force_override_state_dict: bool = False,
@@ -81,6 +82,7 @@ def __init__(
subsample: Float representing fraction of the dataset to be randomly sampled (e.g., 0.1 => 10% of dataset).
encryption: The encryption object to use for decrypting the data.
storage_options: Additional connection options for accessing storage services.
session_options: Additional options passed to the boto3 S3 session.
max_pre_download: Maximum number of chunks that can be pre-downloaded by the StreamingDataset.
index_path: Path to `index.json` for the Parquet dataset.
If `index_path` is a directory, the function will look for `index.json` within it.
@@ -128,6 +130,7 @@ def __init__(
shuffle,
seed,
storage_options,
session_options,
index_path,
fnmatch_pattern,
)
@@ -190,6 +193,7 @@ def __init__(
self.batch_size: int = 1
self._encryption = encryption
self.storage_options = storage_options
self.session_options = session_options
self.max_pre_download = max_pre_download

def set_shuffle(self, shuffle: bool) -> None:
@@ -228,6 +232,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
max_cache_size=self.max_cache_size,
encryption=self._encryption,
storage_options=self.storage_options,
session_options=self.session_options,
max_pre_download=self.max_pre_download,
)
cache._reader._try_load_config()
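On the user-facing side, `StreamingDataset` now stores `session_options` and forwards it through `_create_cache` to the reader and downloader. One plausible use is streaming through a named profile that assumes a role; the profile and role below are hypothetical, shown only to illustrate the shape of the call:

```python
import litdata as ld

# Hypothetical entry in ~/.aws/config:
# [profile team-data-reader]
# role_arn = arn:aws:iam::123456789012:role/DataReader
# source_profile = default
# region = us-east-1

dataset = ld.StreamingDataset(
    "s3://my-bucket/my-data",  # placeholder bucket
    session_options={"profile_name": "team-data-reader"},
)
print(len(dataset))
```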
48 changes: 40 additions & 8 deletions src/litdata/streaming/downloader.py
@@ -39,7 +39,12 @@

class Downloader(ABC):
def __init__(
self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
self,
remote_dir: str,
cache_dir: str,
chunks: List[Dict[str, Any]],
storage_options: Optional[Dict] = {},
**kwargs: Any,
):
self._remote_dir = remote_dir
self._cache_dir = cache_dir
@@ -77,13 +82,20 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:

class S3Downloader(Downloader):
def __init__(
self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
self,
remote_dir: str,
cache_dir: str,
chunks: List[Dict[str, Any]],
storage_options: Optional[Dict] = {},
**kwargs: Any,
):
super().__init__(remote_dir, cache_dir, chunks, storage_options)
self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0
# check if kwargs contains session_options
self.session_options = kwargs.get("session_options", {})

if not self._s5cmd_available or _DISABLE_S5CMD:
self._client = S3Client(storage_options=self._storage_options)
self._client = S3Client(storage_options=self._storage_options, session_options=self.session_options)

def download_file(self, remote_filepath: str, local_filepath: str) -> None:
obj = parse.urlparse(remote_filepath)
@@ -156,7 +168,12 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:

class GCPDownloader(Downloader):
def __init__(
self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
self,
remote_dir: str,
cache_dir: str,
chunks: List[Dict[str, Any]],
storage_options: Optional[Dict] = {},
**kwargs: Any,
):
if not _GOOGLE_STORAGE_AVAILABLE:
raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE))
@@ -194,7 +211,12 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:

class AzureDownloader(Downloader):
def __init__(
self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
self,
remote_dir: str,
cache_dir: str,
chunks: List[Dict[str, Any]],
storage_options: Optional[Dict] = {},
**kwargs: Any,
):
if not _AZURE_STORAGE_AVAILABLE:
raise ModuleNotFoundError(str(_AZURE_STORAGE_AVAILABLE))
@@ -247,7 +269,12 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:

class HFDownloader(Downloader):
def __init__(
self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
self,
remote_dir: str,
cache_dir: str,
chunks: List[Dict[str, Any]],
storage_options: Optional[Dict] = {},
**kwargs: Any,
):
if not _HF_HUB_AVAILABLE:
raise ModuleNotFoundError(
@@ -331,7 +358,11 @@ def unregister_downloader(prefix: str) -> None:


def get_downloader(
remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
remote_dir: str,
cache_dir: str,
chunks: List[Dict[str, Any]],
storage_options: Optional[Dict] = {},
session_options: Optional[Dict] = {},
) -> Downloader:
"""Get the appropriate downloader instance based on the remote directory prefix.

@@ -340,13 +371,14 @@
cache_dir (str): The local cache directory.
chunks (List[Dict[str, Any]]): List of chunks to be managed by the downloader.
storage_options (Optional[Dict], optional): Additional storage options. Defaults to {}.
session_options (Optional[Dict], optional): Additional S3 session options. Defaults to {}.

Returns:
Downloader: An instance of the appropriate downloader class.
"""
for k, cls in _DOWNLOADERS.items():
if str(remote_dir).startswith(k):
return cls(remote_dir, cache_dir, chunks, storage_options)
return cls(remote_dir, cache_dir, chunks, storage_options, session_options=session_options)
else:
# Default to LocalDownloader if no prefix is matched
return LocalDownloader(remote_dir, cache_dir, chunks, storage_options)
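`get_downloader` now forwards `session_options` as a keyword argument, and every `Downloader` subclass accepts `**kwargs`, so non-S3 downloaders simply ignore it while `S3Downloader` reads it via `kwargs.get("session_options", {})`. A rough sketch of the dispatch with placeholder paths and options (a real bucket and profile would be needed for the download to succeed):

```python
from litdata.streaming.downloader import get_downloader

# All values below are placeholders for illustration.
downloader = get_downloader(
    "s3://my-bucket/my-data",                        # prefix match selects S3Downloader
    "/tmp/litdata-cache",                            # local cache directory
    [],                                              # no chunk metadata needed here
    {"endpoint_url": "https://s3.example.com"},      # storage (client) options
    session_options={"profile_name": "my-profile"},  # session options, new in this PR
)

# A local or unmatched path would fall back to LocalDownloader instead.
downloader.download_file("s3://my-bucket/my-data/index.json", "/tmp/litdata-cache/index.json")
```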
4 changes: 4 additions & 0 deletions src/litdata/streaming/reader.py
@@ -265,6 +265,7 @@ def __init__(
item_loader: Optional[BaseItemLoader] = None,
serializers: Optional[Dict[str, Serializer]] = None,
storage_options: Optional[dict] = {},
session_options: Optional[dict] = {},
max_pre_download: int = 2,
) -> None:
"""The BinaryReader enables to read chunked dataset in an efficient way.
@@ -281,6 +282,7 @@ def __init__(
max_cache_size: The maximum cache size used by the reader when fetching the chunks.
serializers: Provide your own serializers.
storage_options: Additional connection options for accessing storage services.
session_options: Additional options for the S3 session.
max_pre_download: Maximum number of chunks that can be pre-downloaded by the reader.

"""
@@ -308,6 +310,7 @@ def __init__(
self._chunks_queued_for_download = False
self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size or 0))
self._storage_options = storage_options
self._session_options = session_options
self._max_pre_download = max_pre_download

def _get_chunk_index_from_index(self, index: int) -> Tuple[int, int]:
@@ -327,6 +330,7 @@ def _try_load_config(self) -> Optional[ChunksConfig]:
self.subsampled_files,
self.region_of_interest,
self._storage_options,
self._session_options,
)
return self._config

10 changes: 7 additions & 3 deletions src/litdata/utilities/dataset_utilities.py
@@ -24,6 +24,7 @@ def subsample_streaming_dataset(
shuffle: bool = False,
seed: int = 42,
storage_options: Optional[Dict] = {},
session_options: Optional[Dict] = {},
index_path: Optional[str] = None,
fnmatch_pattern: Optional[str] = None,
) -> Tuple[List[str], List[Tuple[int, int]]]:
@@ -46,6 +47,7 @@
input_dir=input_dir.path if input_dir.path else input_dir.url,
cache_dir=cache_dir.path if cache_dir else None,
storage_options=storage_options,
session_options=session_options,
index_path=index_path,
)
if cache_path is not None:
@@ -61,7 +63,7 @@
if index_path is not None:
copy_index_to_cache_index_filepath(index_path, cache_index_filepath)
else:
downloader = get_downloader(input_dir.url, input_dir.path, [], storage_options)
downloader = get_downloader(input_dir.url, input_dir.path, [], storage_options, session_options)
downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), cache_index_filepath)

time.sleep(0.5) # Give some time for the file to be available
@@ -141,6 +143,7 @@ def _should_replace_path(path: Optional[str]) -> bool:
def _read_updated_at(
input_dir: Optional[Dir],
storage_options: Optional[Dict] = {},
session_options: Optional[Dict] = {},
index_path: Optional[str] = None,
) -> str:
"""Read last updated timestamp from index.json file."""
@@ -160,7 +163,7 @@
if index_path is not None:
copy_index_to_cache_index_filepath(index_path, temp_index_filepath)
else:
downloader = get_downloader(input_dir.url, tmp_directory, [], storage_options)
downloader = get_downloader(input_dir.url, tmp_directory, [], storage_options, session_options)
downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), temp_index_filepath)
index_json_content = load_index_file(tmp_directory)

@@ -213,11 +216,12 @@ def _try_create_cache_dir(
input_dir: Optional[str],
cache_dir: Optional[str] = None,
storage_options: Optional[Dict] = {},
session_options: Optional[Dict] = {},
index_path: Optional[str] = None,
) -> Optional[str]:
"""Prepare and return the cache directory for a dataset."""
resolved_input_dir = _resolve_dir(input_dir)
updated_at = _read_updated_at(resolved_input_dir, storage_options, index_path)
updated_at = _read_updated_at(resolved_input_dir, storage_options, session_options, index_path)

# Fallback to a hash of the input_dir if updated_at is "0"
if updated_at == "0" and input_dir is not None: