Commit a6331a4

Merge branch 'main' into workers-skip-speedup
2 parents 932fd2d + c16d00e commit a6331a4

File tree: 10 files changed (+154, −15 lines)

README.md

Lines changed: 25 additions & 1 deletion
@@ -265,14 +265,38 @@ dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_stor
 gcp_storage_options={
     "project": os.environ['PROJECT_ID'],
 }
-dataset = ld.StreamingDataset("gcp://my-bucket/my-data", storage_options=gcp_storage_options)
+dataset = ld.StreamingDataset("gs://my-bucket/my-data", storage_options=gcp_storage_options)
 
 # Read data from Azure
 azure_storage_options={
     "account_url": f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net",
     "credential": os.environ['AZURE_ACCOUNT_ACCESS_KEY']
 }
 dataset = ld.StreamingDataset("azure://my-bucket/my-data", storage_options=azure_storage_options)
+
+# Read data from Hugging Face
+hf_storage_options={
+    "use_auth_token": os.environ['HF_TOKEN']
+}
+dataset = StreamingDataset("hf://datasets/my-org/my-repo", storage_options=hf_storage_options)
+# Read from a nested directory
+dataset = StreamingDataset("hf://datasets/my-org/my-repo/dataset-1", storage_options=hf_storage_options)
+```
+
+### Upload Data to Hugging Face
+
+To upload data to Hugging Face, you can use the `huggingface-cli` command. Below is the command format:
+> For more information, checkout the [Hugging Face documentation](https://huggingface.co/docs/datasets/main/en/share#huggingface-cli-upload).
+
+```sh
+$ huggingface-cli upload [dataset_repo_id] [local_path] [path_in_repo] --repo-type dataset --token=[HF_TOKEN]
+# Example: Uploading to the root of the repository
+# huggingface-cli upload my-org/my-repo ./my-data --repo-type dataset --token=hf_****
+
+# Example: Uploading to a nested directory within the repository
+# huggingface-cli upload my-org/my-repo ./my-data dataset-1 --repo-type dataset --token=hf_****
+
+# Note: If already logged in, you can skip the token
 ```
 
 </details>
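Not part of this diff, but worth noting next to the CLI instructions above: the same upload can be done from Python with `huggingface_hub`'s `HfApi`. A minimal sketch, assuming the `my-org/my-repo` dataset repo already exists and `./my-data` holds the optimized chunks:

```python
# Hypothetical programmatic equivalent of the `huggingface-cli upload` command above.
from huggingface_hub import HfApi

api = HfApi(token="hf_****")  # or omit the token after `huggingface-cli login`
api.upload_folder(
    folder_path="./my-data",   # local directory with the optimized dataset
    repo_id="my-org/my-repo",  # assumed to be an existing dataset repo
    repo_type="dataset",
    path_in_repo="dataset-1",  # drop this argument to upload to the repo root
)
```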

requirements/test.txt

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 coverage ==7.5.3
 cryptography==42.0.8
+huggingface-hub==0.24.5
 mosaicml-streaming==0.8.0
 pytest ==8.3.*
 pytest-cov ==5.0.0

src/litdata/constants.py

Lines changed: 8 additions & 7 deletions
@@ -26,17 +26,18 @@
 _DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks")
 
 # This is required for full pytree serialization / deserialization support
-_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
-_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
-_LIGHTNING_CLOUD_AVAILABLE = RequirementCache("lightning-cloud")
+_AZURE_STORAGE_AVAILABLE = RequirementCache("azure.storage.blob")
 _BOTO3_AVAILABLE = RequirementCache("boto3")
-_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
-_ZSTD_AVAILABLE = RequirementCache("zstd")
 _CRYPTOGRAPHY_AVAILABLE = RequirementCache("cryptography")
 _GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")
-_AZURE_STORAGE_AVAILABLE = RequirementCache("azure.storage.blob")
-_TQDM_AVAILABLE = RequirementCache("tqdm")
+_HUGGINGFACE_HUB_AVAILABLE = RequirementCache("huggingface-hub")
+_LIGHTNING_CLOUD_AVAILABLE = RequirementCache("lightning-cloud")
 _LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
+_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
+_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
+_TQDM_AVAILABLE = RequirementCache("tqdm")
+_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
+_ZSTD_AVAILABLE = RequirementCache("zstd")
 
 # DON'T CHANGE ORDER
 _TORCH_DTYPES_MAPPING = {
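The block above is mostly an alphabetical reordering; the substantive change is the new `_HUGGINGFACE_HUB_AVAILABLE` flag. For readers unfamiliar with the pattern, a `RequirementCache` is truthy when the requirement is importable and its string form carries an install hint, which is how `HFDownloader` below guards the optional dependency. A minimal sketch, assuming `lightning_utilities` is installed (the library these flags come from):

```python
# Minimal sketch of the RequirementCache pattern used in constants.py.
from lightning_utilities.core.imports import RequirementCache

_HUGGINGFACE_HUB_AVAILABLE = RequirementCache("huggingface-hub")

if not _HUGGINGFACE_HUB_AVAILABLE:
    # str() of the cache yields a readable message, e.g. a pip install hint.
    raise ModuleNotFoundError(str(_HUGGINGFACE_HUB_AVAILABLE))
```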

src/litdata/streaming/dataset.py

Lines changed: 3 additions & 3 deletions
@@ -197,13 +197,13 @@ def set_num_workers(self, num_workers: int) -> None:
         self.num_workers = num_workers or 1
 
     def get_len(self, num_workers: int, batch_size: int) -> int:
-        self.num_workers = num_workers
-        self.batch_size = batch_size
+        self.set_num_workers(num_workers)
+        self.set_batch_size(batch_size)
         worker_env = _WorkerEnv.detect()
         if self.shuffler is None:
             cache = self._create_cache(worker_env=worker_env)
             self.shuffler = self._create_shuffler(cache)
-        return self.shuffler.get_len(self.distributed_env, num_workers, batch_size, self.current_epoch)
+        return self.shuffler.get_len(self.distributed_env, self.num_workers, self.batch_size, self.current_epoch)
 
     def __iter__(self) -> "StreamingDataset":
         # When the StreamingDataset is used within map or optimize, let's refetch the distributed env.
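The point of routing through the setters is visible in the context line above: `set_num_workers` normalizes a falsy value to 1, so `get_len` now sizes the dataset with the same worker count iteration will actually use. A rough sketch of the behaviour this enables (mirrored by the new `test_dataloader_no_workers` test further down):

```python
# Rough sketch; "./optimized-data" is a hypothetical local dataset directory.
from litdata import StreamingDataLoader, StreamingDataset

dataset = StreamingDataset("./optimized-data", shuffle=True)
dataloader = StreamingDataLoader(dataset)  # num_workers defaults to 0, normalized to 1 internally
assert len(dataloader) == len(dataset)     # lengths now agree without explicit workers
```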

src/litdata/streaming/downloader.py

Lines changed: 57 additions & 1 deletion
@@ -20,7 +20,12 @@
 
 from filelock import FileLock, Timeout
 
-from litdata.constants import _AZURE_STORAGE_AVAILABLE, _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME
+from litdata.constants import (
+    _AZURE_STORAGE_AVAILABLE,
+    _GOOGLE_STORAGE_AVAILABLE,
+    _HUGGINGFACE_HUB_AVAILABLE,
+    _INDEX_FILENAME,
+)
 from litdata.streaming.client import S3Client
 
 
@@ -164,6 +169,56 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         pass
 
 
+class HFDownloader(Downloader):
+    def __init__(
+        self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+    ):
+        if not _HUGGINGFACE_HUB_AVAILABLE:
+            raise ModuleNotFoundError(str(_HUGGINGFACE_HUB_AVAILABLE))
+
+        super().__init__(remote_dir, cache_dir, chunks, storage_options)
+
+    def download_file(self, remote_filepath: str, local_filepath: str) -> None:
+        """Download a file from the Hugging Face Hub.
+
+        The remote_filepath should be in the format `hf://<repo_type>/<repo_org>/<repo_name>/path`. For more
+        information, see
+        https://huggingface.co/docs/huggingface_hub/en/guides/hf_file_system#integrations.
+
+        """
+        from huggingface_hub import hf_hub_download
+
+        obj = parse.urlparse(remote_filepath)
+
+        if obj.scheme != "hf":
+            raise ValueError(f"Expected obj.scheme to be `hf`, instead, got {obj.scheme} for remote={remote_filepath}")
+
+        if os.path.exists(local_filepath):
+            return
+
+        try:
+            with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0):
+                # Adapted from https://github.com/mosaicml/streaming/blob/main/streaming/base/storage/download.py#L292
+                # expected URL format: hf://datasets/<repo_org>/<repo_name>/path
+                _, _, _, repo_org, repo_name, path = remote_filepath.split("/", 5)
+                downloaded_path = hf_hub_download(
+                    repo_id=f"{repo_org}/{repo_name}",
+                    filename=path,
+                    local_dir=self._cache_dir,
+                    repo_type="dataset",
+                    **self._storage_options,
+                )
+
+                # Move the downloaded file to the expected location if it's not already there.
+                if downloaded_path != local_filepath and os.path.exists(downloaded_path):
+                    os.rename(downloaded_path, local_filepath)
+                    os.rmdir(os.path.dirname(downloaded_path))
+
+        except Timeout:
+            # another process is responsible to download that file, continue
+            pass
+
+
 class LocalDownloader(Downloader):
     def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if not os.path.exists(remote_filepath):
 
@@ -183,6 +238,7 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
     "s3://": S3Downloader,
     "gs://": GCPDownloader,
     "azure://": AzureDownloader,
+    "hf://": HFDownloader,
     "local:": LocalDownloaderWithCache,
     "": LocalDownloader,
 }
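To make the URL handling in `HFDownloader.download_file` concrete, here is what the `split("/", 5)` produces for a hypothetical chunk file (the file names are made up):

```python
# Hypothetical URLs, parsed exactly as in HFDownloader.download_file above.
remote_filepath = "hf://datasets/my-org/my-repo/chunk-0-0.bin"
_, _, _, repo_org, repo_name, path = remote_filepath.split("/", 5)
assert (repo_org, repo_name, path) == ("my-org", "my-repo", "chunk-0-0.bin")

# A nested dataset keeps the sub-path in `path`, which hf_hub_download accepts as `filename`.
nested = "hf://datasets/my-org/my-repo/dataset-1/chunk-0-0.bin"
assert nested.split("/", 5)[-1] == "dataset-1/chunk-0-0.bin"
```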

src/litdata/streaming/resolver.py

Lines changed: 2 additions & 1 deletion
@@ -52,7 +52,8 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir:
 
     assert isinstance(dir_path, str)
 
-    if dir_path.startswith("s3://") or dir_path.startswith("gs://") or dir_path.startswith("azure://"):
+    cloud_prefixes = ("s3://", "gs://", "azure://", "hf://")
+    if dir_path.startswith(cloud_prefixes):
         return Dir(path=None, url=dir_path)
 
     if dir_path.startswith("local:"):
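The two-line version works because `str.startswith` accepts a tuple of prefixes, so a single call replaces the chained `or` checks:

```python
# str.startswith with a tuple returns True if any prefix matches.
cloud_prefixes = ("s3://", "gs://", "azure://", "hf://")
assert "hf://datasets/my-org/my-repo".startswith(cloud_prefixes)
assert not "local:/cache/my-data".startswith(cloud_prefixes)
```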

tests/conftest.py

Lines changed: 10 additions & 0 deletions
@@ -64,6 +64,16 @@ def azure_mock(monkeypatch):
     return azure
 
 
+@pytest.fixture()
+def huggingface_mock(monkeypatch):
+    huggingface_hub = ModuleType("huggingface_hub")
+    monkeypatch.setitem(sys.modules, "huggingface_hub", huggingface_hub)
+    hf_hub_download = ModuleType("hf_hub_download")
+    monkeypatch.setitem(sys.modules, "huggingface_hub.hf_hub_download", hf_hub_download)
+    huggingface_hub.hf_hub_download = hf_hub_download
+    return huggingface_hub
+
+
 @pytest.fixture()
 def lightning_cloud_mock(monkeypatch):
     lightning_cloud = ModuleType("lightning_cloud")

tests/processing/test_functions.py

Lines changed: 5 additions & 1 deletion
@@ -176,7 +176,7 @@ def test_optimize_append_overwrite(tmpdir):
     assert ds[:] == [(i, i**2, i**3) for i in range(0, 5)]
 
 
-@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow")
+@pytest.mark.skipif(sys.platform == "win32", reason="too slow")
 def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
     output_dir = str(tmpdir / "output_dir")
 
@@ -188,6 +188,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
         chunk_size=1,
         num_workers=2,
         use_checkpoint=True,
+        start_method="fork",
     )
 
     # check that the checkpoints are created
 
@@ -201,6 +202,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
         chunk_size=1,
         num_workers=2,
         use_checkpoint=True,
+        start_method="fork",
     )
 
     ds = StreamingDataset(output_dir)
 
@@ -221,6 +223,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
         num_workers=2,
         use_checkpoint=True,
         mode="append",
+        start_method="fork",
     )
 
     # check that the checkpoints are created
 
@@ -240,6 +243,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
         num_workers=2,
         use_checkpoint=True,
         mode="append",
+        start_method="fork",
    )
 
     ds = StreamingDataset(output_dir)
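The `start_method="fork"` additions (and dropping the macOS skip) presumably rely on forked workers starting much faster than spawned ones, since spawn launches a fresh interpreter and re-imports the test module in every worker process. Fork is unavailable on Windows, which is why the `win32` skip stays. A small illustration of what each platform offers:

```python
import multiprocessing as mp

# On Linux this typically prints ['fork', 'spawn', 'forkserver'];
# on macOS 'fork' is still available even though the default has been 'spawn' since Python 3.8;
# on Windows only 'spawn' is supported.
print(mp.get_all_start_methods())
```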

tests/streaming/test_dataloader.py

Lines changed: 16 additions & 1 deletion
@@ -3,7 +3,7 @@
 import pytest
 import torch
 from litdata.constants import _VIZ_TRACKER_AVAILABLE
-from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader
+from litdata.streaming import Cache, CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
 from litdata.streaming import dataloader as streaming_dataloader_module
 from torch import tensor
 
@@ -187,3 +187,18 @@ def test_custom_collate_multiworker():
 
     # Try calling the state_dict. No error should follow
     _state_dict = dataloader.state_dict()
+
+
+def test_dataloader_no_workers(tmpdir):
+    cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB")
+    for i in range(1000):
+        cache[i] = i
+
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(str(tmpdir), shuffle=True)
+    dataloader = StreamingDataLoader(dataset)
+    assert len(dataset) == 1000
+    assert len(dataloader) == 1000
+    assert len(dataset) == 1000

tests/streaming/test_downloader.py

Lines changed: 27 additions & 0 deletions
@@ -1,10 +1,12 @@
+import contextlib
 import os
 from unittest import mock
 from unittest.mock import MagicMock
 
 from litdata.streaming.downloader import (
     AzureDownloader,
     GCPDownloader,
+    HFDownloader,
     LocalDownloaderWithCache,
     S3Downloader,
     shutil,
 
@@ -72,6 +74,31 @@ def test_azure_downloader(tmpdir, monkeypatch, azure_mock):
     mock_blob_data.readinto.assert_called()
 
 
+@mock.patch("litdata.streaming.downloader._HUGGINGFACE_HUB_AVAILABLE", True)
+def test_hf_downloader(tmpdir, monkeypatch, huggingface_mock):
+    mock_hf_hub_download = MagicMock()
+    huggingface_mock.hf_hub_download = mock_hf_hub_download
+
+    # Initialize the downloader
+    storage_options = {}
+    downloader = HFDownloader("hf://datasets/random_org/random_repo", tmpdir, [], storage_options)
+    local_filepath = os.path.join(tmpdir, "a.txt")
+
+    # ignore filenotfound error for this test TODO: write a better test
+    with contextlib.suppress(FileNotFoundError):
+        downloader.download_file("hf://datasets/random_org/random_repo/a.txt", local_filepath)
+        # Assert that the correct methods were called
+        huggingface_mock.hf_hub_download.assert_called_once()
+        huggingface_mock.hf_hub_download.assert_called_with(
+            repo_id="random_org/random_repo", filename="a.txt", local_dir=tmpdir, repo_type="dataset"
+        )
+
+    # Test that the file is not downloaded if it already exists
+    with contextlib.suppress(FileNotFoundError):
+        downloader.download_file("hf://datasets/random_org/random_repo/a.txt", local_filepath)
+        huggingface_mock.hf_hub_download.assert_not_called()
+
+
 def test_download_with_cache(tmpdir, monkeypatch):
     # Create a file to download/cache
     with open("a.txt", "w") as f:
