Commit a2cfe7b

csy1204 and tchaton authored
feat: add a custom storage options param (#246)
Co-authored-by: thomas chaton <thomas@grid.ai>
1 parent fd411e5 commit a2cfe7b

File tree

10 files changed: +100 -19 lines changed


README.md

Lines changed: 15 additions & 0 deletions
@@ -209,6 +209,21 @@ for batch in dataloader:

 ```

+
+Additionally, you can inject client connection settings for [S3](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html#boto3.session.Session.client) or GCP when initializing your dataset. This is useful for specifying custom endpoints and credentials per dataset.
+
+```python
+from litdata import StreamingDataset
+
+storage_options = {
+    "endpoint_url": "your_endpoint_url",
+    "aws_access_key_id": "your_access_key_id",
+    "aws_secret_access_key": "your_secret_access_key",
+}
+
+dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options)
+```
+
 </details>

 <details>
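
The same parameter covers GCS-backed datasets: `GCPDownloader` forwards the dict straight to `google.cloud.storage.Client(**storage_options)` (see the downloader diff below). A minimal sketch, assuming a bucket at `gs://my-bucket/my-data` and a placeholder project id:

```python
from litdata import StreamingDataset

# Keys here become keyword arguments of google.cloud.storage.Client;
# "my-gcp-project" is a placeholder, not a real project id.
storage_options = {"project": "my-gcp-project"}

dataset = StreamingDataset("gs://my-bucket/my-data", storage_options=storage_options)
```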

src/litdata/streaming/cache.py

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,7 @@ def __init__(
         max_cache_size: Union[int, str] = "100GB",
         serializers: Optional[Dict[str, Serializer]] = None,
         writer_chunk_index: Optional[int] = None,
+        storage_options: Optional[Dict] = {},
     ):
         """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
         together in order to accelerate fetching.
@@ -60,6 +61,7 @@ def __init__(
             max_cache_size: The maximum cache size used by the reader when fetching the chunks.
             serializers: Provide your own serializers.
             writer_chunk_index: The index of the chunk to start from when writing.
+            storage_options: Additional connection options for accessing storage services.

         """
         super().__init__()
@@ -85,6 +87,7 @@ def __init__(
             encryption=encryption,
             item_loader=item_loader,
             serializers=serializers,
+            storage_options=storage_options,
         )
         self._is_done = False
         self._distributed_env = _DistributedEnv.detect()

src/litdata/streaming/client.py

Lines changed: 8 additions & 3 deletions
@@ -13,7 +13,7 @@

 import os
 from time import time
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import boto3
 import botocore
@@ -26,10 +26,11 @@
 class S3Client:
     # TODO: Generalize to support more cloud providers.

-    def __init__(self, refetch_interval: int = 3300) -> None:
+    def __init__(self, refetch_interval: int = 3300, storage_options: Optional[Dict] = {}) -> None:
         self._refetch_interval = refetch_interval
         self._last_time: Optional[float] = None
         self._client: Optional[Any] = None
+        self._storage_options: dict = storage_options or {}

     def _create_client(self) -> None:
         has_shared_credentials_file = (
@@ -38,7 +39,11 @@ def _create_client(self) -> None:

         if has_shared_credentials_file or not _IS_IN_STUDIO:
             self._client = boto3.client(
-                "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
+                "s3",
+                **{
+                    "config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}),
+                    **self._storage_options,
+                },
             )
         else:
             provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5))
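
One consequence of the `**{...}` merge above: the default retry `config` comes first and the user-supplied `storage_options` last, so user keys win on conflict. A small sketch of the resulting kwargs, assuming a placeholder custom endpoint:

```python
import botocore.config

# Same merge as in S3Client._create_client: defaults first, storage_options last.
defaults = {"config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})}
storage_options = {
    "endpoint_url": "http://localhost:9000",  # placeholder custom endpoint
    "config": botocore.config.Config(retries={"max_attempts": 100}),  # overrides the default config
}

client_kwargs = {**defaults, **storage_options}
# boto3.client("s3", **client_kwargs) then receives the user "config" plus "endpoint_url".
```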

src/litdata/streaming/config.py

Lines changed: 9 additions & 3 deletions
@@ -33,6 +33,7 @@ def __init__(
         item_loader: Optional[BaseItemLoader] = None,
         subsampled_files: Optional[List[str]] = None,
         region_of_interest: Optional[List[Tuple[int, int]]] = None,
+        storage_options: Optional[Dict] = {},
     ) -> None:
         """The ChunksConfig reads the index files associated a chunked dataset and enables to map an index to its
         chunk.
@@ -44,6 +45,7 @@ def __init__(
                 The scheme needs to be added to the path.
             subsampled_files: List of subsampled chunk files loaded from `input_dir/index.json` file.
             region_of_interest: List of tuples of {start,end} of region of interest for each chunk.
+            storage_options: Additional connection options for accessing storage services.

         """
         self._cache_dir = cache_dir
@@ -52,6 +54,7 @@ def __init__(
         self._chunks = None
         self._remote_dir = remote_dir
         self._item_loader = item_loader or PyTreeLoader()
+        self._storage_options = storage_options

         # load data from `index.json` file
         data = load_index_file(self._cache_dir)
@@ -75,7 +78,7 @@ def __init__(
         self._downloader = None

         if remote_dir:
-            self._downloader = get_downloader_cls(remote_dir, cache_dir, self._chunks)
+            self._downloader = get_downloader_cls(remote_dir, cache_dir, self._chunks, self._storage_options)

         self._compressor_name = self._config["compression"]
         self._compressor: Optional[Compressor] = None
@@ -234,17 +237,20 @@ def load(
         item_loader: Optional[BaseItemLoader] = None,
         subsampled_files: Optional[List[str]] = None,
         region_of_interest: Optional[List[Tuple[int, int]]] = None,
+        storage_options: Optional[dict] = {},
     ) -> Optional["ChunksConfig"]:
         cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)

         if isinstance(remote_dir, str):
-            downloader = get_downloader_cls(remote_dir, cache_dir, [])
+            downloader = get_downloader_cls(remote_dir, cache_dir, [], storage_options)
             downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath)

         if not os.path.exists(cache_index_filepath):
             return None

-        return ChunksConfig(cache_dir, serializers, remote_dir, item_loader, subsampled_files, region_of_interest)
+        return ChunksConfig(
+            cache_dir, serializers, remote_dir, item_loader, subsampled_files, region_of_interest, storage_options
+        )

     def __len__(self) -> int:
         return self._length

src/litdata/streaming/dataset.py

Lines changed: 5 additions & 1 deletion
@@ -54,6 +54,7 @@ def __init__(
         max_cache_size: Union[int, str] = "100GB",
         subsample: float = 1.0,
         encryption: Optional[Encryption] = None,
+        storage_options: Optional[Dict] = {},
     ) -> None:
         """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.

@@ -70,6 +71,7 @@ def __init__(
             max_cache_size: The maximum cache size used by the StreamingDataset.
             subsample: Float representing fraction of the dataset to be randomly sampled (e.g., 0.1 => 10% of dataset).
             encryption: The encryption object to use for decrypting the data.
+            storage_options: Additional connection options for accessing storage services.

         """
         super().__init__()
@@ -85,7 +87,7 @@ def __init__(
         self.subsampled_files: List[str] = []
         self.region_of_interest: List[Tuple[int, int]] = []
         self.subsampled_files, self.region_of_interest = subsample_streaming_dataset(
-            self.input_dir, item_loader, subsample, shuffle, seed
+            self.input_dir, item_loader, subsample, shuffle, seed, storage_options
         )

         self.item_loader = item_loader
@@ -128,6 +130,7 @@ def __init__(
         self.num_workers: int = 1
         self.batch_size: int = 1
         self._encryption = encryption
+        self.storage_options = storage_options

     def set_shuffle(self, shuffle: bool) -> None:
         self.shuffle = shuffle
@@ -163,6 +166,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
             serializers=self.serializers,
             max_cache_size=self.max_cache_size,
             encryption=self._encryption,
+            storage_options=self.storage_options,
         )
         cache._reader._try_load_config()
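
The dataset itself only stores the dict (`self.storage_options`) and forwards it: once at construction time to `subsample_streaming_dataset` (to fetch `index.json`) and once per worker inside `_create_cache`, from where it flows to the reader, `ChunksConfig` and the downloader. A short sketch with a placeholder endpoint and bucket:

```python
from litdata import StreamingDataset

storage_options = {"endpoint_url": "http://localhost:9000"}  # placeholder endpoint
dataset = StreamingDataset("s3://my-bucket/my-data", storage_options=storage_options)

# Reading an item triggers a chunk download through a downloader built with the same options.
sample = dataset[0]
```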

src/litdata/streaming/downloader.py

Lines changed: 19 additions & 10 deletions
@@ -15,7 +15,7 @@
 import shutil
 import subprocess
 from abc import ABC
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 from urllib import parse

 from filelock import FileLock, Timeout
@@ -25,10 +25,13 @@


 class Downloader(ABC):
-    def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
+    def __init__(
+        self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+    ):
         self._remote_dir = remote_dir
         self._cache_dir = cache_dir
         self._chunks = chunks
+        self._storage_options = storage_options or {}

     def download_chunk_from_index(self, chunk_index: int) -> None:
         chunk_filename = self._chunks[chunk_index]["filename"]
@@ -41,12 +44,14 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:


 class S3Downloader(Downloader):
-    def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
-        super().__init__(remote_dir, cache_dir, chunks)
+    def __init__(
+        self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+    ):
+        super().__init__(remote_dir, cache_dir, chunks, storage_options)
         self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0

         if not self._s5cmd_available:
-            self._client = S3Client()
+            self._client = S3Client(storage_options=self._storage_options)

     def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         obj = parse.urlparse(remote_filepath)
@@ -88,11 +93,13 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:


 class GCPDownloader(Downloader):
-    def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
+    def __init__(
+        self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+    ):
         if not _GOOGLE_STORAGE_AVAILABLE:
             raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE))

-        super().__init__(remote_dir, cache_dir, chunks)
+        super().__init__(remote_dir, cache_dir, chunks, storage_options)

     def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         from google.cloud import storage
@@ -113,7 +120,7 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if key[0] == "/":
             key = key[1:]

-        client = storage.Client()
+        client = storage.Client(**self._storage_options)
         bucket = client.bucket(bucket_name)
         blob = bucket.blob(key)
         blob.download_to_filename(local_filepath)
@@ -140,8 +147,10 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
 _DOWNLOADERS = {"s3://": S3Downloader, "gs://": GCPDownloader, "local:": LocalDownloaderWithCache, "": LocalDownloader}


-def get_downloader_cls(remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]) -> Downloader:
+def get_downloader_cls(
+    remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+) -> Downloader:
     for k, cls in _DOWNLOADERS.items():
         if str(remote_dir).startswith(k):
-            return cls(remote_dir, cache_dir, chunks)
+            return cls(remote_dir, cache_dir, chunks, storage_options)
     raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have a downloader associated.")
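
`get_downloader_cls` keeps its prefix-based dispatch and simply threads `storage_options` through to whichever downloader matches. A small usage sketch, assuming placeholder bucket, cache path and endpoint:

```python
from litdata.streaming.downloader import get_downloader_cls

downloader = get_downloader_cls(
    "s3://my-bucket/my-data",                    # "s3://" prefix selects S3Downloader
    "/tmp/litdata-cache",                        # placeholder local cache directory
    [],                                          # chunk metadata; empty when only index.json is needed
    {"endpoint_url": "http://localhost:9000"},   # forwarded to S3Client / boto3
)
downloader.download_file("s3://my-bucket/my-data/index.json", "/tmp/litdata-cache/index.json")
```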

src/litdata/streaming/reader.py

Lines changed: 4 additions & 0 deletions
@@ -169,6 +169,7 @@ def __init__(
         encryption: Optional[Encryption] = None,
         item_loader: Optional[BaseItemLoader] = None,
         serializers: Optional[Dict[str, Serializer]] = None,
+        storage_options: Optional[dict] = {},
     ) -> None:
         """The BinaryReader enables to read chunked dataset in an efficient way.

@@ -183,6 +184,7 @@ def __init__(
             item_loader: The chunk sampler to create sub arrays from a chunk.
             max_cache_size: The maximum cache size used by the reader when fetching the chunks.
             serializers: Provide your own serializers.
+            storage_options: Additional connection options for accessing storage services.

         """
         super().__init__()
@@ -207,6 +209,7 @@ def __init__(
         self._item_loader = item_loader or PyTreeLoader()
         self._last_chunk_index: Optional[int] = None
         self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size or 0))
+        self._storage_options = storage_options

     def _get_chunk_index_from_index(self, index: int) -> Tuple[int, int]:
         # Load the config containing the index
@@ -224,6 +227,7 @@ def _try_load_config(self) -> Optional[ChunksConfig]:
             self._item_loader,
             self.subsampled_files,
             self.region_of_interest,
+            self._storage_options,
         )
         return self._config

src/litdata/utilities/dataset_utilities.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@ def subsample_streaming_dataset(
     subsample: float = 1.0,
     shuffle: bool = False,
     seed: int = 42,
+    storage_options: Optional[Dict] = {},
 ) -> Tuple[List[str], List[Tuple[int, int]]]:
     """Subsample streaming dataset.

@@ -46,7 +47,7 @@ def subsample_streaming_dataset(
     # Check if `index.json` file exists in cache path
     if not os.path.exists(cache_index_filepath) and isinstance(input_dir.url, str):
         assert input_dir.url is not None
-        downloader = get_downloader_cls(input_dir.url, input_dir.path, [])
+        downloader = get_downloader_cls(input_dir.url, input_dir.path, [], storage_options)
         downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), cache_index_filepath)

     if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)):

tests/streaming/test_client.py

Lines changed: 32 additions & 0 deletions
@@ -6,6 +6,38 @@
 from litdata.streaming import client


+def test_s3_client_with_storage_options(monkeypatch):
+    boto3 = mock.MagicMock()
+    monkeypatch.setattr(client, "boto3", boto3)
+
+    botocore = mock.MagicMock()
+    monkeypatch.setattr(client, "botocore", botocore)
+
+    storage_options = {
+        "region_name": "us-west-2",
+        "endpoint_url": "https://custom.endpoint",
+        "config": botocore.config.Config(retries={"max_attempts": 100}),
+    }
+    s3_client = client.S3Client(storage_options=storage_options)
+
+    assert s3_client.client
+
+    boto3.client.assert_called_with(
+        "s3",
+        region_name="us-west-2",
+        endpoint_url="https://custom.endpoint",
+        config=botocore.config.Config(retries={"max_attempts": 100}),
+    )
+
+    s3_client = client.S3Client()
+
+    assert s3_client.client
+
+    boto3.client.assert_called_with(
+        "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
+    )
+
+
 def test_s3_client_without_cloud_space_id(monkeypatch):
     boto3 = mock.MagicMock()
     monkeypatch.setattr(client, "boto3", boto3)

tests/streaming/test_downloader.py

Lines changed: 3 additions & 1 deletion
@@ -36,11 +36,13 @@ def test_gcp_downloader(tmpdir, monkeypatch, google_mock):
     mock_bucket.blob = MagicMock(return_value=mock_blob)

     # Initialize the downloader
-    downloader = GCPDownloader("gs://random_bucket", tmpdir, [])
+    storage_options = {"project": "DUMMY_PROJECT"}
+    downloader = GCPDownloader("gs://random_bucket", tmpdir, [], storage_options)
     local_filepath = os.path.join(tmpdir, "a.txt")
     downloader.download_file("gs://random_bucket/a.txt", local_filepath)

     # Assert that the correct methods were called
+    google_mock.cloud.storage.Client.assert_called_with(**storage_options)
     mock_client.bucket.assert_called_with("random_bucket")
     mock_bucket.blob.assert_called_with("a.txt")
     mock_blob.download_to_filename.assert_called_with(local_filepath)
