22 commits
7544616
Add option only_splits to DownloadConfig
albertvillanova Apr 22, 2021
43724e5
Test load_datasets only_splits
albertvillanova Apr 22, 2021
dff9fed
Filter split_generators in only_splits
albertvillanova Apr 22, 2021
e0e1aa5
Mark DownloadConfig params with default values
albertvillanova Apr 22, 2021
6a0c3fc
Handle passed DownloadConfig with only_splits only
albertvillanova Apr 22, 2021
75ea3d4
Fix test
albertvillanova Apr 22, 2021
d4d7f28
Fix DownloadConfig
albertvillanova Apr 22, 2021
07c2c02
Fix MockDownloadManager
albertvillanova Apr 22, 2021
e231072
Refactorize test
albertvillanova Apr 23, 2021
2f06ae4
Test downloaded files by load_datasets only_splits
albertvillanova Apr 23, 2021
84ac280
Download only splits in only_splits
albertvillanova Apr 23, 2021
90be20e
Fix returned extracted paths to handle missing keys
albertvillanova Apr 23, 2021
3405398
Fix returned extracted paths if dict
albertvillanova Apr 23, 2021
3a137c5
Rename DownloadConfig.only_splits to DownloadConfig.splits
albertvillanova Apr 28, 2021
f65be54
Pass splits to _split_generators
albertvillanova Apr 28, 2021
67e86a9
Merge remote-tracking branch 'upstream/master' into load-only-splits
albertvillanova Apr 28, 2021
f1b6e3d
Fix undefined variable
albertvillanova Apr 28, 2021
f733347
Merge remote-tracking branch 'upstream/master' into load-only-splits
albertvillanova May 3, 2021
111e32a
Merge remote-tracking branch 'upstream/master' into load-only-splits
albertvillanova Jun 23, 2021
c398de5
Use a explicit not_downloaded placeholder in case users get errors
albertvillanova Jun 23, 2021
a001084
Rename DownloadConfig attributes
albertvillanova Jun 23, 2021
0678193
Merge remote-tracking branch 'upstream/master' into load-only-splits
albertvillanova Jun 25, 2021
23 changes: 22 additions & 1 deletion src/datasets/builder.py
@@ -491,6 +491,15 @@ def download_and_prepare(
use_etag=False,
use_auth_token=use_auth_token,
) # We don't use etag for data files to speed up the process
else:
if not download_config.cache_dir:
download_config.cache_dir = os.path.join(self._cache_dir_root, "downloads")
if not download_config._is_force_download_set_by_user:
download_config.force_download = bool(download_mode == GenerateMode.FORCE_REDOWNLOAD)
if not download_config._is_use_etag_set_by_user:
download_config.use_etag = False
if download_config.use_auth_token is None:
download_config.use_auth_token = use_auth_token

dl_manager = DownloadManager(
dataset_name=self.name,
@@ -631,7 +640,19 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs
# Generating data for all splits
split_dict = SplitDict(dataset_name=self.name)
split_generators_kwargs = self._make_split_generators_kwargs(prepare_split_kwargs)
split_generators = self._split_generators(dl_manager, **split_generators_kwargs)
try:
split_generators = self._split_generators(
dl_manager, splits=dl_manager._download_config.splits, **split_generators_kwargs
)
except TypeError:
split_generators = self._split_generators(dl_manager, **split_generators_kwargs)
# If self._split_generators did not already filter by the requested splits, filter now so that, at least, non-requested splits are not cached
if dl_manager._download_config.splits:
split_generators = [
split_generator
for split_generator in split_generators
if split_generator.name in dl_manager._download_config.splits
]

# Checksums verification
if verify_infos:
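
Editor's note: for dataset scripts that opt in, _split_generators can accept the new optional splits argument; scripts that don't define it keep working through the TypeError fallback above. A minimal sketch of such an opted-in script (the URLs, feature schema, and dataset class are hypothetical, not code from this PR):

import datasets

_URLS = {
    "train": "https://example.com/train.jsonl",
    "test": "https://example.com/test.jsonl",
}

class MyDataset(datasets.GeneratorBasedBuilder):
    def _info(self):
        return datasets.DatasetInfo(features=datasets.Features({"text": datasets.Value("string")}))

    def _split_generators(self, dl_manager, splits=None):
        # Download only the requested splits; fall back to everything if no filter was given.
        urls = {name: _URLS[name] for name in splits} if splits else _URLS
        paths = dl_manager.download_and_extract(urls)
        return [
            datasets.SplitGenerator(name=name, gen_kwargs={"filepath": path})
            for name, path in paths.items()
        ]

    def _generate_examples(self, filepath):
        with open(filepath, encoding="utf-8") as f:
            for idx, line in enumerate(f):
                yield idx, {"text": line.strip()}
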
11 changes: 10 additions & 1 deletion src/datasets/utils/download_manager.py
@@ -18,6 +18,7 @@

import enum
import os
from collections import defaultdict
from datetime import datetime
from functools import partial
from typing import Dict, Optional, Union
@@ -184,6 +185,10 @@ def download(self, url_or_urls):
downloaded_path(s): `str`, The downloaded paths matching the given input
url_or_urls.
"""
if self._download_config.splits:
if isinstance(url_or_urls, dict) and all(split in url_or_urls for split in self._download_config.splits):
lhoestq (Member) commented:

I would add an additional check just to avoid unwanted behaviors.
For example, right now, if a dataset script passes the dl_manager a dict like this:

{"main_data": url_to_main_data, "secondary_data": url_to_sec_data}

then this trick here would use url_or_urls = {} since the keys of the dict are not split names.

Maybe you could check that sorted(url_or_urls.keys()) == sorted(self._download_config.splits) before filtering?

lhoestq (Member) commented on May 3, 2021:

Edit:

I meant a dict like this

{"main_metadata": url_to_main_data, "secondary_metadata": url_to_sec_data, "train": url_train_data, "test": url_test_data}

albertvillanova (Member, Author) commented on Jun 23, 2021:

Hi @lhoestq. Sorry, I'm not sure I understand what you mean... 😅

What I am checking here is that the keys in self._download_config.splits are a subset of the keys in url_or_urls.

lhoestq (Member) commented on Jun 23, 2021:

If you pass a dictionary like this:

{"main_metadata": url_to_main_data,
"secondary_metadata": url_to_sec_data,
"train": url_train_data,
"test": url_test_data}

then only the train or test keys will be kept, which I feel is not intuitive.

For example, if the user asks to load the "train" split, then the main and secondary metadata won't be downloaded.
You can fix that by keeping all the keys except the splits to ignore.

albertvillanova (Member, Author) commented:

Thanks @lhoestq, I understand it now.

albertvillanova (Member, Author) commented:

See my comment below.

url_or_urls = {split: url_or_urls[split] for split in self._download_config.splits}

download_config = self._download_config.copy()
download_config.extract_compressed_file = False
# Default to using 16 parallel thread for downloading
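
Editor's note: one way to implement lhoestq's suggestion above, keeping every non-split key (such as shared metadata) and dropping only the split keys the user did not request, might be the following sketch; the helper name and the known_splits default are hypothetical, not part of this PR:

def _filter_urls_by_splits(url_or_urls, requested_splits, known_splits=("train", "validation", "test")):
    # Keep keys that are not split names at all (e.g. "main_metadata"),
    # and among split keys keep only the requested ones.
    if not isinstance(url_or_urls, dict) or not requested_splits:
        return url_or_urls
    return {
        key: value
        for key, value in url_or_urls.items()
        if key not in known_splits or key in requested_splits
    }

With {"main_metadata": ..., "train": ..., "test": ...} and requested_splits=["train"], this keeps main_metadata and train, dropping only test.
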
@@ -268,7 +273,11 @@ def extract(self, path_or_paths, num_proc=None):
path_or_paths = NestedDataStructure(path_or_paths)
extracted_paths = NestedDataStructure(extracted_paths)
self.extracted_paths.update(dict(zip(path_or_paths.flatten(), extracted_paths.flatten())))
return extracted_paths.data
return (
extracted_paths.data
if not isinstance(extracted_paths.data, dict)
else defaultdict(lambda: "<NOT_DOWNLOADED>", extracted_paths.data)
)
lhoestq (Member) commented:

Why do you need this?

albertvillanova (Member, Author) commented on May 3, 2021:

Because the user implementation of the method _split_generators will access all possible keys in the extracted paths dict, most likely through __getitem__. If our downloaded paths don't contain all possible splits (but only a subset), accessing the missing ones through __getitem__ will raise a KeyError exception.

If a defaultdict is returned instead, trying to access a missing (not downloaded) split will be silently ignored.

lhoestq (Member) commented:

Makes sense! Maybe use a more explicit placeholder than an empty string, in case users experience errors using this?

albertvillanova (Member, Author) commented:

I've set "<NOT_DOWNLOADED>" instead.
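
Editor's note: the behavior under discussion can be shown in isolation; a minimal sketch (the paths are made up):

from collections import defaultdict

extracted = {"train": "/cache/extracted/train.jsonl"}  # only "train" was downloaded
paths = defaultdict(lambda: "<NOT_DOWNLOADED>", extracted)

assert paths["train"] == "/cache/extracted/train.jsonl"
# Indexing a split that was not downloaded yields the placeholder instead of a KeyError,
# so a script's _split_generators can still index every split it knows about.
assert paths["test"] == "<NOT_DOWNLOADED>"
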


def download_and_extract(self, url_or_urls):
"""Download and extract given url_or_urls.
15 changes: 13 additions & 2 deletions src/datasets/utils/file_utils.py
@@ -228,17 +228,28 @@ class DownloadConfig:
"""

cache_dir: Optional[Union[str, Path]] = None
force_download: bool = False
force_download: Optional[bool] = None # default False
_is_force_download_set_by_user: bool = True
resume_download: bool = False
local_files_only: bool = False
proxies: Optional[Dict] = None
user_agent: Optional[str] = None
extract_compressed_file: bool = False
force_extract: bool = False
use_etag: bool = True
use_etag: Optional[bool] = None # default True
_is_use_etag_set_by_user: bool = True
num_proc: Optional[int] = None
max_retries: int = 1
use_auth_token: Optional[Union[str, bool]] = None
splits: Optional[list] = None
lhoestq (Member) commented:

The DownloadConfig class was used until now as the class that defines the parameters we pass to the cached_path function, and it had no logic related to the DownloadManager or to datasets specifically.
Therefore I'm not sure this is the best place to put this argument. Let me know what you think.

Maybe this could be an argument of the DownloadManager itself.

albertvillanova (Member, Author) commented on May 3, 2021:

I don't have a clear opinion on this yet (I'm going to think about it), but I can tell you that, for readability, a data class called DownloadConfig should clearly be used by a DownloadManager class.

In my opinion, it is quite sensible to include a download configuration setting (whether to download all or only some of the files) in a DownloadConfig object. And in order to avoid passing lots of parameters to load_dataset, it makes sense to pass the download-related parameters in a DownloadConfig class.

As a side note, I am also planning to refactor cached_path because it contains too many different coupled functionalities: download, extract, load from cache... For the moment, I have extracted all the Extract functionalities, but I am planning further refactoring...

lhoestq (Member) commented:

The separation between the two configs makes sense to me, and I totally agree with you on the naming.
Maybe we could have DownloadConfig as a subclass of CachedPathConfig or something like that (or another relation between the two).

Regarding cached_path, note that this is a function we have in common with the other libraries (transformers and huggingface_hub), so it could be worth discussing the changes we want to make with the other maintainers.


def __post_init__(self):
if self.use_etag is None:
self.use_etag = True
self._is_use_etag_set_by_user = False
if self.force_download is None:
self.force_download = False
self._is_force_download_set_by_user = False

def copy(self) -> "DownloadConfig":
return self.__class__(**{k: copy.deepcopy(v) for k, v in self.__dict__.items()})
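
Editor's note: the __post_init__ above implements a sentinel-default pattern: None means "not set by the user", so download_and_prepare can safely override the effective default later. The same idiom in isolation, as a standalone sketch rather than the actual class:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Config:
    force_download: Optional[bool] = None  # effective default: False

    def __post_init__(self):
        # None can only mean the user did not pass a value.
        self._is_force_download_set_by_user = self.force_download is not None
        if self.force_download is None:
            self.force_download = False

assert Config()._is_force_download_set_by_user is False
assert Config(force_download=False)._is_force_download_set_by_user is True
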
5 changes: 4 additions & 1 deletion src/datasets/utils/mock_download_manager.py
@@ -22,7 +22,7 @@
from pathlib import Path
from typing import Callable, List, Optional, Union

from .file_utils import cached_path, hf_github_url
from .file_utils import DownloadConfig, cached_path, hf_github_url
from .logging import get_logger
from .version import Version

@@ -61,6 +61,9 @@ def __init__(
self._dummy_file = None
self._bucket_url = None

# As DownloadManager
self._download_config = DownloadConfig()

@property
def dummy_file(self):
if self._dummy_file is None:
38 changes: 38 additions & 0 deletions tests/test_load.py
@@ -5,6 +5,7 @@
import time
from functools import partial
from hashlib import sha256
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch

@@ -17,6 +18,7 @@
from datasets.dataset_dict import DatasetDict, IterableDatasetDict
from datasets.iterable_dataset import IterableDataset
from datasets.load import prepare_module
from datasets.utils.file_utils import DownloadConfig

from .utils import (
OfflineSimulationMode,
@@ -317,3 +319,39 @@ def test_load_from_disk_with_default_in_memory(

with assert_arrow_memory_increases() if expected_in_memory else assert_arrow_memory_doesnt_increase():
_ = load_from_disk(dataset_path)


class TestLoadDatasetOnlySplits:
lhoestq (Member) commented:

We should also check this scenario:

  1. the user loads a dataset with the "train" split only
  2. the user then reloads this dataset, asking for the "test" split only

The caching mechanism should notice that the "test" split is missing and download and prepare it. If I'm not wrong, currently this would fail because of these lines:

data_exists = os.path.exists(self._cache_dir)
if data_exists and download_mode == GenerateMode.REUSE_DATASET_IF_EXISTS:
logger.warning("Reusing dataset %s (%s)", self.name, self._cache_dir)
# We need to update the info in case some splits were added in the meantime
# for example when calling load_dataset from multiple workers.
self.info = self._load_info()
self.download_post_processing_resources(dl_manager)
return

Instead of checking for the cache_dir, we should check that the actual arrow split file exists inside cache_dir.

albertvillanova (Member, Author) commented:

The current implementation of load_dataset assumes that if there is something in the cache, then it is all you can get, unless you pass force_download.

I assumed that if the user passes a parameter to download only one split, then they are aware that only that split is in the cache and only that split can be loaded. In order to force a subsequent download of another split, they should pass force_download.

lhoestq (Member) commented:

Not a big fan of that. In my opinion, it should detect that the requested split is missing and generate it.
This will avoid some confusion for users.

lhoestq (Member) commented:

This is also a breaking change we can't afford, IMO.
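
Editor's note: a sketch of the split-aware cache check lhoestq proposes; the helper name is hypothetical, and the "{dataset_name}-{split}.arrow" naming is an assumption about the cache layout, not code from this PR:

import os

def requested_splits_are_cached(cache_dir, dataset_name, requested_splits):
    # Reuse the cache only if every requested split already has its arrow file.
    return all(
        os.path.exists(os.path.join(cache_dir, f"{dataset_name}-{split}.arrow"))
        for split in requested_splits
    )

download_and_prepare could then regenerate only the missing splits instead of returning early.
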

def test_load_dataset_local_only_splits_processed_files(self, dataset_loading_script_dir, data_dir, tmp_path):
download_config = DownloadConfig(splits=["test"])
cache_dir = str(tmp_path / "cache")
datasetdict = datasets.load_dataset(
dataset_loading_script_dir,
data_dir=data_dir,
cache_dir=cache_dir,
download_config=download_config,
)
assert isinstance(datasetdict, DatasetDict)
assert "train" not in datasetdict
assert "test" in datasetdict
dataset = datasetdict["test"]
assert dataset.split == "test"
assert dataset.shape == (10, 1)
# pattern = "*/0.0.0/74c0095031cf868e2486de6e08bb3ca4a9f9de3a81b10af67a42aed21393e640/*.arrow"
generated_arrow_files = sorted(Path(cache_dir, dataset.builder_name).glob("**/*.arrow"))
assert len(generated_arrow_files) == 1

def test_load_dataset_from_hub_only_splits_downloaded_files(self, tmp_path):
download_config = DownloadConfig(splits=["train"])
cache_dir = str(tmp_path / "cache")
datasetdict = load_dataset(SAMPLE_DATASET_IDENTIFIER, cache_dir=cache_dir, download_config=download_config)
assert isinstance(datasetdict, DatasetDict)
assert "train" in datasetdict
assert "validation" not in datasetdict
dataset = datasetdict["train"]
assert dataset.split == "train"
assert dataset.shape == (2, 1)
downloaded_files = set(str(path.stem) for path in Path(cache_dir, "downloads").glob("**/*"))
assert len(downloaded_files) == 1
generated_arrow_files = sorted(Path(cache_dir, dataset.builder_name).glob("**/*.arrow"))
assert len(generated_arrow_files) == 1