-
Notifications
You must be signed in to change notification settings - Fork 3k
Allow downloading/processing/caching only specific splits #2249
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7544616
43724e5
dff9fed
e0e1aa5
6a0c3fc
75ea3d4
d4d7f28
07c2c02
e231072
2f06ae4
84ac280
90be20e
3405398
3a137c5
f65be54
67e86a9
f1b6e3d
f733347
111e32a
c398de5
a001084
0678193
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
|
||
import enum | ||
import os | ||
from collections import defaultdict | ||
from datetime import datetime | ||
from functools import partial | ||
from typing import Dict, Optional, Union | ||
|
@@ -184,6 +185,10 @@ def download(self, url_or_urls): | |
downloaded_path(s): `str`, The downloaded paths matching the given input | ||
url_or_urls. | ||
""" | ||
if self._download_config.splits: | ||
if isinstance(url_or_urls, dict) and all(split in url_or_urls for split in self._download_config.splits): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would add an additional check just to avoid unwanted behaviors.
then this trick here would use Maybe you could check that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Edit: I meant a dict like this
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @lhoestq. Sorry, I'm not sure I understand what you mean... 😅 What I am checking here is that all the keys in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Edit: I meant a dict like this
then only the train or test keys will be kept, which I feel not intuitive. For example if the users asks to load the "train" split, then the main and secondary metadata won't be downloaded. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @lhoestq, I understand it now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See my comment below. |
||
url_or_urls = {split: url_or_urls[split] for split in self._download_config.splits} | ||
|
||
download_config = self._download_config.copy() | ||
download_config.extract_compressed_file = False | ||
# Default to using 16 parallel thread for downloading | ||
|
@@ -268,7 +273,11 @@ def extract(self, path_or_paths, num_proc=None): | |
path_or_paths = NestedDataStructure(path_or_paths) | ||
extracted_paths = NestedDataStructure(extracted_paths) | ||
self.extracted_paths.update(dict(zip(path_or_paths.flatten(), extracted_paths.flatten()))) | ||
return extracted_paths.data | ||
return ( | ||
extracted_paths.data | ||
if not isinstance(extracted_paths.data, dict) | ||
else defaultdict(lambda: "<NOT_DOWNLOADED>", extracted_paths.data) | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do you need this ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because the user implementation of the method If a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense ! Maybe use a more explicit placeholder than an empty string, in case users experience errors using this ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've set |
||
|
||
def download_and_extract(self, url_or_urls): | ||
"""Download and extract given url_or_urls. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -228,17 +228,28 @@ class DownloadConfig: | |
""" | ||
|
||
cache_dir: Optional[Union[str, Path]] = None | ||
force_download: bool = False | ||
force_download: Optional[bool] = None # default False | ||
_is_force_download_set_by_user: bool = True | ||
resume_download: bool = False | ||
local_files_only: bool = False | ||
proxies: Optional[Dict] = None | ||
user_agent: Optional[str] = None | ||
extract_compressed_file: bool = False | ||
force_extract: bool = False | ||
use_etag: bool = True | ||
use_etag: Optional[bool] = None # default True | ||
_is_use_etag_set_by_user: bool = True | ||
num_proc: Optional[int] = None | ||
max_retries: int = 1 | ||
use_auth_token: Optional[Union[str, bool]] = None | ||
splits: Optional[list] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The DownloadConfig class was used until now as the class that defines the parameters that we pass to the Maybe this could be an argument of the DownloadManager itself. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not have a clear opinion on this yet (I'm going to think about it), but I can tell you that for readability, a data class called In my opinion I find it quite sensible to include a download configuration setting (whether to download all or only some of the files) in a As a side note, I am planning to refactor also There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The separation between the two configs makes sense to me, and I totally agree with you on the naming. Regarding cached_path, note that this is a function that we have in common with the other libraries (transformers and huggingface_hub) so it could be worth discussing with the other maintainers about the changes we want to do. |
||
|
||
def __post_init__(self): | ||
if self.use_etag is None: | ||
self.use_etag = True | ||
self._is_use_etag_set_by_user = False | ||
if self.force_download is None: | ||
self.force_download = False | ||
self._is_force_download_set_by_user = False | ||
|
||
def copy(self) -> "DownloadConfig": | ||
return self.__class__(**{k: copy.deepcopy(v) for k, v in self.__dict__.items()}) | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -5,6 +5,7 @@ | |||||||||||||||||
import time | ||||||||||||||||||
from functools import partial | ||||||||||||||||||
from hashlib import sha256 | ||||||||||||||||||
from pathlib import Path | ||||||||||||||||||
from unittest import TestCase | ||||||||||||||||||
from unittest.mock import patch | ||||||||||||||||||
|
||||||||||||||||||
|
@@ -17,6 +18,7 @@ | |||||||||||||||||
from datasets.dataset_dict import DatasetDict, IterableDatasetDict | ||||||||||||||||||
from datasets.iterable_dataset import IterableDataset | ||||||||||||||||||
from datasets.load import prepare_module | ||||||||||||||||||
from datasets.utils.file_utils import DownloadConfig | ||||||||||||||||||
|
||||||||||||||||||
from .utils import ( | ||||||||||||||||||
OfflineSimulationMode, | ||||||||||||||||||
|
@@ -317,3 +319,39 @@ def test_load_from_disk_with_default_in_memory( | |||||||||||||||||
|
||||||||||||||||||
with assert_arrow_memory_increases() if expected_in_memory else assert_arrow_memory_doesnt_increase(): | ||||||||||||||||||
_ = load_from_disk(dataset_path) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
class TestLoadDatasetOnlySplits: | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should also check this scenario:
The caching mechanism should notice that the "test" set is missing and download and prepare it. If I'm not wrong currently this would fail because of this line: datasets/src/datasets/builder.py Lines 504 to 511 in 097129d
Instead of checking for the cache_dir, we should check for the actual arrow split file to exist inside cache_dir There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Current implementation of I assumed that if the user passes a parameter to download only one split, then they are aware that only that split is in the cache and only that split can be loaded. In order to force a subsequent download of another split, they should pass There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not a big fan of that. In my opinion it should detect that the requested split is missing and generate it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is also a breaking change we can't afford IMO |
||||||||||||||||||
def test_load_dataset_local_only_splits_processed_files(self, dataset_loading_script_dir, data_dir, tmp_path): | ||||||||||||||||||
download_config = DownloadConfig(splits=["test"]) | ||||||||||||||||||
cache_dir = str(tmp_path / "cache") | ||||||||||||||||||
datasetdict = datasets.load_dataset( | ||||||||||||||||||
dataset_loading_script_dir, | ||||||||||||||||||
data_dir=data_dir, | ||||||||||||||||||
cache_dir=cache_dir, | ||||||||||||||||||
download_config=download_config, | ||||||||||||||||||
) | ||||||||||||||||||
albertvillanova marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||
assert isinstance(datasetdict, DatasetDict) | ||||||||||||||||||
assert "train" not in datasetdict | ||||||||||||||||||
assert "test" in datasetdict | ||||||||||||||||||
dataset = datasetdict["test"] | ||||||||||||||||||
assert dataset.split == "test" | ||||||||||||||||||
assert dataset.shape == (10, 1) | ||||||||||||||||||
# pattern = "*/0.0.0/74c0095031cf868e2486de6e08bb3ca4a9f9de3a81b10af67a42aed21393e640/*.arrow" | ||||||||||||||||||
generated_arrow_files = sorted(Path(cache_dir, dataset.builder_name).glob("**/*.arrow")) | ||||||||||||||||||
assert len(generated_arrow_files) == 1 | ||||||||||||||||||
|
||||||||||||||||||
def test_load_dataset_from_hub_only_splits_downloaded_files(self, tmp_path): | ||||||||||||||||||
download_config = DownloadConfig(splits=["train"]) | ||||||||||||||||||
cache_dir = str(tmp_path / "cache") | ||||||||||||||||||
datasetdict = load_dataset(SAMPLE_DATASET_IDENTIFIER, cache_dir=cache_dir, download_config=download_config) | ||||||||||||||||||
assert isinstance(datasetdict, DatasetDict) | ||||||||||||||||||
assert "train" in datasetdict | ||||||||||||||||||
assert "validation" not in datasetdict | ||||||||||||||||||
dataset = datasetdict["train"] | ||||||||||||||||||
assert dataset.split == "train" | ||||||||||||||||||
assert dataset.shape == (2, 1) | ||||||||||||||||||
downloaded_files = set(str(path.stem) for path in Path(cache_dir, "downloads").glob("**/*")) | ||||||||||||||||||
assert len(downloaded_files) == 1 | ||||||||||||||||||
generated_arrow_files = sorted(Path(cache_dir, dataset.builder_name).glob("**/*.arrow")) | ||||||||||||||||||
assert len(generated_arrow_files) == 1 |
Uh oh!
There was an error while loading. Please reload this page.