Revert "Feat: Add support for reading LitData dataset published to HF" #320

Merged · 1 commit · Aug 9, 2024
README.md (2 changes: 1 addition & 1 deletion)
@@ -265,7 +265,7 @@ dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options)
gcp_storage_options={
"project": os.environ['PROJECT_ID'],
}
dataset = ld.StreamingDataset("gs://my-bucket/my-data", storage_options=gcp_storage_options)
dataset = ld.StreamingDataset("gcp://my-bucket/my-data", storage_options=gcp_storage_options)

# Read data from Azure
azure_storage_options={
requirements/test.txt (1 change: 0 additions & 1 deletion)
@@ -1,6 +1,5 @@
coverage ==7.5.3
cryptography==42.0.8
huggingface-hub==0.24.5
mosaicml-streaming==0.8.0
pytest ==8.3.*
pytest-cov ==5.0.0
src/litdata/constants.py (15 changes: 7 additions & 8 deletions)
@@ -26,18 +26,17 @@
_DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks")

# This is required for full pytree serialization / deserialization support
_AZURE_STORAGE_AVAILABLE = RequirementCache("azure.storage.blob")
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
_LIGHTNING_CLOUD_AVAILABLE = RequirementCache("lightning-cloud")
_BOTO3_AVAILABLE = RequirementCache("boto3")
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
_ZSTD_AVAILABLE = RequirementCache("zstd")
_CRYPTOGRAPHY_AVAILABLE = RequirementCache("cryptography")
_GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")
_HUGGINGFACE_HUB_AVAILABLE = RequirementCache("huggingface-hub")
_LIGHTNING_CLOUD_AVAILABLE = RequirementCache("lightning-cloud")
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
_AZURE_STORAGE_AVAILABLE = RequirementCache("azure.storage.blob")
_TQDM_AVAILABLE = RequirementCache("tqdm")
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
_ZSTD_AVAILABLE = RequirementCache("zstd")
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")

# DON'T CHANGE ORDER
_TORCH_DTYPES_MAPPING = {
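For context, each of these constants is a `lightning_utilities` `RequirementCache`: it is truthy when the named package is importable, and its `str()` carries an install hint. A minimal sketch of the gating pattern these flags support (the `compress` helper is hypothetical, not litdata code):

```python
from lightning_utilities.core.imports import RequirementCache

# Truthy only if `zstd` can be imported in the current environment.
_ZSTD_AVAILABLE = RequirementCache("zstd")


def compress(data: bytes) -> bytes:
    # Fail fast with an actionable message when the optional dependency is missing.
    if not _ZSTD_AVAILABLE:
        raise ModuleNotFoundError(str(_ZSTD_AVAILABLE))
    import zstd

    return zstd.compress(data)
```

The removed `HFDownloader` below used exactly this pattern with `_HUGGINGFACE_HUB_AVAILABLE`.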
src/litdata/streaming/downloader.py (58 changes: 1 addition & 57 deletions)
@@ -20,12 +20,7 @@

from filelock import FileLock, Timeout

from litdata.constants import (
_AZURE_STORAGE_AVAILABLE,
_GOOGLE_STORAGE_AVAILABLE,
_HUGGINGFACE_HUB_AVAILABLE,
_INDEX_FILENAME,
)
from litdata.constants import _AZURE_STORAGE_AVAILABLE, _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME
from litdata.streaming.client import S3Client


@@ -169,56 +164,6 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
pass


class HFDownloader(Downloader):
def __init__(
self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
):
if not _HUGGINGFACE_HUB_AVAILABLE:
raise ModuleNotFoundError(str(_HUGGINGFACE_HUB_AVAILABLE))

super().__init__(remote_dir, cache_dir, chunks, storage_options)

def download_file(self, remote_filepath: str, local_filepath: str) -> None:
"""Download a file from the Hugging Face Hub.

The remote_filepath should be in the format `hf://<repo_type>/<repo_org>/<repo_name>/path`. For more
information, see
https://huggingface.co/docs/huggingface_hub/en/guides/hf_file_system#integrations.

"""
from huggingface_hub import hf_hub_download

obj = parse.urlparse(remote_filepath)

if obj.scheme != "hf":
raise ValueError(f"Expected obj.scheme to be `hf`, instead, got {obj.scheme} for remote={remote_filepath}")

if os.path.exists(local_filepath):
return

try:
with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0):
# Adapted from https://github.com/mosaicml/streaming/blob/main/streaming/base/storage/download.py#L292
# expected URL format: hf://datasets/<repo_org>/<repo_name>/path
_, _, _, repo_org, repo_name, path = remote_filepath.split("/", 5)
downloaded_path = hf_hub_download(
repo_id=f"{repo_org}/{repo_name}",
filename=path,
local_dir=self._cache_dir,
repo_type="dataset",
**self._storage_options,
)

# Move the downloaded file to the expected location if it's not already there.
if downloaded_path != local_filepath and os.path.exists(downloaded_path):
os.rename(downloaded_path, local_filepath)
os.rmdir(os.path.dirname(downloaded_path))

except Timeout:
# another process is responsible to download that file, continue
pass
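
For reference, a small sketch of how the `remote_filepath.split("/", 5)` call above decomposes an `hf://` URL into the arguments handed to `hf_hub_download` (the org, repo, and file names are invented for illustration):

```python
# Mirrors the unpacking in HFDownloader.download_file above.
url = "hf://datasets/my_org/my_repo/data/chunk-0.bin"
_, _, repo_type, repo_org, repo_name, path = url.split("/", 5)
# -> ("hf:", "", "datasets", "my_org", "my_repo", "data/chunk-0.bin")

assert (repo_org, repo_name, path) == ("my_org", "my_repo", "data/chunk-0.bin")
# hf_hub_download is then called with repo_id="my_org/my_repo",
# filename="data/chunk-0.bin", repo_type="dataset".
```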


class LocalDownloader(Downloader):
def download_file(self, remote_filepath: str, local_filepath: str) -> None:
if not os.path.exists(remote_filepath):
@@ -238,7 +183,6 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
"s3://": S3Downloader,
"gs://": GCPDownloader,
"azure://": AzureDownloader,
"hf://": HFDownloader,
"local:": LocalDownloaderWithCache,
"": LocalDownloader,
}
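
The mapping above is keyed by URL prefix, with the empty string as a catch-all for plain local paths. A sketch of how a prefix dispatcher can consume it; litdata's actual helper may differ in name and signature, so treat this as an assumption:

```python
def get_downloader_cls(
    remote_dir: str,
    cache_dir: str,
    chunks: List[Dict[str, Any]],
    storage_options: Optional[Dict] = {},
) -> Downloader:
    # First matching prefix wins; "" matches everything, so LocalDownloader
    # is the fallback when no cloud scheme is recognized.
    for prefix, downloader_cls in _DOWNLOADERS.items():
        if remote_dir.startswith(prefix):
            return downloader_cls(remote_dir, cache_dir, chunks, storage_options)
    raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have a downloader associated.")
```

With `hf://` removed from the table, such a lookup for an `hf://` URL now falls through to the `""` entry (LocalDownloader).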
src/litdata/streaming/resolver.py (3 changes: 1 addition & 2 deletions)
@@ -52,8 +52,7 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir:

assert isinstance(dir_path, str)

cloud_prefixes = ("s3://", "gs://", "azure://", "hf://")
if dir_path.startswith(cloud_prefixes):
if dir_path.startswith("s3://") or dir_path.startswith("gs://") or dir_path.startswith("azure://"):
return Dir(path=None, url=dir_path)

if dir_path.startswith("local:"):
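A quick illustration of the resulting behavior, assuming `Dir` is the small path/url dataclass defined in this module (a sketch, not a test from the repo):

```python
from litdata.streaming.resolver import Dir, _resolve_dir

# Recognized cloud URLs pass through: no local path, URL preserved.
assert _resolve_dir("s3://my-bucket/my-data") == Dir(path=None, url="s3://my-bucket/my-data")

# After this revert, "hf://..." no longer matches a cloud prefix and is
# handled by the remaining branches of _resolve_dir like any other path.
```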
tests/conftest.py (10 changes: 0 additions & 10 deletions)
@@ -64,16 +64,6 @@ def azure_mock(monkeypatch):
return azure


@pytest.fixture()
def huggingface_mock(monkeypatch):
huggingface_hub = ModuleType("huggingface_hub")
monkeypatch.setitem(sys.modules, "huggingface_hub", huggingface_hub)
hf_hub_download = ModuleType("hf_hub_download")
monkeypatch.setitem(sys.modules, "huggingface_hub.hf_hub_download", hf_hub_download)
huggingface_hub.hf_hub_download = hf_hub_download
return huggingface_hub


@pytest.fixture()
def lightning_cloud_mock(monkeypatch):
lightning_cloud = ModuleType("lightning_cloud")
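The removed `huggingface_mock` fixture followed the same module-stubbing pattern as its neighbors: plant a fake module in `sys.modules` so imports inside the code under test resolve to the mock. A self-contained sketch of the technique (names chosen for illustration):

```python
import sys
from types import ModuleType
from unittest.mock import MagicMock


def test_stubbed_module(monkeypatch):
    fake = ModuleType("huggingface_hub")
    fake.hf_hub_download = MagicMock(return_value="/tmp/a.txt")
    monkeypatch.setitem(sys.modules, "huggingface_hub", fake)

    # Any `from huggingface_hub import hf_hub_download` executed after this
    # point (e.g. inside a downloader) now receives the MagicMock above.
    from huggingface_hub import hf_hub_download

    assert hf_hub_download(repo_id="org/repo", filename="a.txt") == "/tmp/a.txt"
```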
tests/streaming/test_downloader.py (27 changes: 0 additions & 27 deletions)
@@ -1,12 +1,10 @@
import contextlib
import os
from unittest import mock
from unittest.mock import MagicMock

from litdata.streaming.downloader import (
AzureDownloader,
GCPDownloader,
HFDownloader,
LocalDownloaderWithCache,
S3Downloader,
shutil,
@@ -74,31 +72,6 @@ def test_azure_downloader(tmpdir, monkeypatch, azure_mock):
mock_blob_data.readinto.assert_called()


@mock.patch("litdata.streaming.downloader._HUGGINGFACE_HUB_AVAILABLE", True)
def test_hf_downloader(tmpdir, monkeypatch, huggingface_mock):
mock_hf_hub_download = MagicMock()
huggingface_mock.hf_hub_download = mock_hf_hub_download

# Initialize the downloader
storage_options = {}
downloader = HFDownloader("hf://datasets/random_org/random_repo", tmpdir, [], storage_options)
local_filepath = os.path.join(tmpdir, "a.txt")

# ignore filenotfound error for this test TODO: write a better test
with contextlib.suppress(FileNotFoundError):
downloader.download_file("hf://datasets/random_org/random_repo/a.txt", local_filepath)
# Assert that the correct methods were called
huggingface_mock.hf_hub_download.assert_called_once()
huggingface_mock.hf_hub_download.assert_called_with(
repo_id="random_org/random_repo", filename="a.txt", local_dir=tmpdir, repo_type="dataset"
)

# Test that the file is not downloaded if it already exists
with contextlib.suppress(FileNotFoundError):
downloader.download_file("hf://datasets/random_org/random_repo/a.txt", local_filepath)
huggingface_mock.hf_hub_download.assert_not_called()


def test_download_with_cache(tmpdir, monkeypatch):
# Create a file to download/cache
with open("a.txt", "w") as f: