Changes from all commits (39 commits)
ecc322f
Create CacheManager and move method _as_dataset
albertvillanova Apr 28, 2021
7d07159
Extract method _run_post_process
albertvillanova Apr 28, 2021
4e7610d
Fix style
albertvillanova Apr 28, 2021
d2dccca
Change signature of _run_post_process
albertvillanova Apr 28, 2021
6bff9e4
Pull up _run_post_process call
albertvillanova Apr 28, 2021
7a8ca40
Rename methods
albertvillanova Apr 28, 2021
c08116a
Push down methods
albertvillanova Apr 28, 2021
eac5926
Change CacheManager instantiation and _as_dataset signature
albertvillanova Apr 28, 2021
c6647af
Move method _build_a_dataset
albertvillanova Apr 28, 2021
a048deb
Move method _build_dataset
albertvillanova Apr 28, 2021
074cc55
Refactor inline _as_dataset
albertvillanova Apr 28, 2021
0d8ad78
Rename CacheManager._build_dataset to DatasetCacheManager.load
albertvillanova Apr 29, 2021
c78d25c
Merge remote-tracking branch 'upstream/master' into refactoring-2
albertvillanova Apr 29, 2021
f001e48
Rename cache to caching
albertvillanova Apr 29, 2021
505f8be
Extract DatasetCacheManager instantiation as attribute
albertvillanova Apr 29, 2021
6b03484
Fix style
albertvillanova Apr 29, 2021
7ab4aa3
Create docstring
albertvillanova Apr 29, 2021
2e3a337
Extract method save
albertvillanova Apr 29, 2021
ee95735
Move method save
albertvillanova Apr 29, 2021
f3d0d91
Extract DatasetCacheManager instantiation as attribute
albertvillanova Apr 29, 2021
210283b
Refactor inline _writer_batch_size
albertvillanova Apr 29, 2021
5dbcd5a
Fix with temporary .incomplete cache_dir
albertvillanova Apr 29, 2021
aa60949
Fix test
albertvillanova Apr 29, 2021
4f69811
Remove DEFAULT_WRITER_BATCH_SIZE and use config.DEFAULT_MAX_BATCH_SIZE
albertvillanova Apr 29, 2021
7ba643c
Pass tmp_cache_dir to save as parameter
albertvillanova Apr 30, 2021
cb3e29b
Merge remote-tracking branch 'upstream/master' into refactoring-2
albertvillanova Apr 30, 2021
ccd05e1
Fix refactor inline _writer_batch_size
albertvillanova Apr 30, 2021
af562b6
Pass features instead of info to save
albertvillanova Apr 30, 2021
332447e
Extract method _save_tables
albertvillanova Apr 30, 2021
52787e6
Move method _save_tables
albertvillanova Apr 30, 2021
68cba2a
Split save into save_examples and save_tables
albertvillanova Apr 30, 2021
4e89508
Fix ArrowBasedBuilder._prepare_split
albertvillanova Apr 30, 2021
ba73ac9
Remove instantiation of DatasetCacheManager from DatasetBuilder
albertvillanova Apr 30, 2021
7f9ba4f
Revert "Remove instantiation of DatasetCacheManager from DatasetBuilder"
albertvillanova Apr 30, 2021
0a27d25
Remove unused output_prefix
albertvillanova Apr 30, 2021
c566dfd
Fix removed unused output_prefix
albertvillanova Apr 30, 2021
1683219
Refactor BeamBasedBuilder constructor
albertvillanova Apr 30, 2021
ed7b096
Change order of returned by save
albertvillanova Apr 30, 2021
dc6003b
Add docstring to save
albertvillanova Apr 30, 2021
9 changes: 2 additions & 7 deletions datasets/wino_bias/wino_bias.py
@@ -72,13 +72,8 @@ class WinoBias(datasets.GeneratorBasedBuilder):
# You will be able to load one or the other configurations in the following list with
# data = datasets.load_dataset('my_dataset', 'first_domain')
# data = datasets.load_dataset('my_dataset', 'second_domain')
def __init__(self, *args, writer_batch_size=None, **kwargs):
super(WinoBias, self).__init__(*args, **kwargs)
# Batch size used by the ArrowWriter
# It defines the number of samples that are kept in memory before writing them
# and also the length of the arrow chunks
# None means that the ArrowWriter will use its default value
self._writer_batch_size = writer_batch_size or 100
def __init__(self, *args, writer_batch_size=100, **kwargs):
super(WinoBias, self).__init__(*args, writer_batch_size=writer_batch_size, **kwargs)

BUILDER_CONFIGS = [
WinoBiasConfig(
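With this change, a dataset script only forwards `writer_batch_size` to the base constructor instead of assigning `self._writer_batch_size` after `super().__init__()`. A minimal hypothetical script following the new pattern (`ToyDataset` is illustrative only, not part of this PR):

```python
# Hypothetical minimal dataset script (not part of this PR) showing the new
# pattern: the subclass only forwards writer_batch_size to the base
# constructor and no longer assigns self._writer_batch_size itself.
import datasets


class ToyDataset(datasets.GeneratorBasedBuilder):
    def __init__(self, *args, writer_batch_size=100, **kwargs):
        super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)

    def _info(self):
        return datasets.DatasetInfo(
            features=datasets.Features({"text": datasets.Value("string")})
        )

    def _split_generators(self, dl_manager):
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={})]

    def _generate_examples(self):
        for idx in range(10):
            yield idx, {"text": "example %d" % idx}
```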
230 changes: 94 additions & 136 deletions src/datasets/builder.py
@@ -32,8 +32,9 @@

from . import config, utils
from .arrow_dataset import Dataset
from .arrow_reader import HF_GCP_BASE_URL, ArrowReader, DatasetNotOnHfGcs, MissingFilesOnHfGcs, ReadInstruction
from .arrow_writer import ArrowWriter, BeamWriter
from .arrow_reader import HF_GCP_BASE_URL, ArrowReader, DatasetNotOnHfGcs, MissingFilesOnHfGcs
from .arrow_writer import BeamWriter
from .caching import DatasetCacheManager
from .dataset_dict import DatasetDict
from .fingerprint import Hasher
from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo
@@ -43,7 +44,7 @@
from .utils.file_utils import DownloadConfig, is_remote_url
from .utils.filelock import FileLock
from .utils.info_utils import get_size_checksum_dict, verify_checksums, verify_splits
from .utils.logging import WARNING, get_logger
from .utils.logging import get_logger


logger = get_logger(__name__)
@@ -270,6 +271,7 @@ def __init__(

# Set download manager
self.dl_manager = None
self.dataset_cache_manager = DatasetCacheManager(cache_dir=self._cache_dir)

# Must be set for datasets that use 'data_dir' functionality - the ones
# that require users to do additional steps to download the data
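The new `DatasetCacheManager` lives in `src/datasets/caching.py`, which is not included in the hunks shown here. Judging only from the call sites in this diff and from the commit messages, its interface is roughly the sketch below (signatures inferred, not the merged code):

```python
# Interface of DatasetCacheManager as inferred from the call sites in this
# diff and the commit messages; the real implementation is in
# src/datasets/caching.py and may differ in details.
from typing import Optional


class DatasetCacheManager:
    """Reads and writes a dataset's Arrow cache files."""

    def __init__(self, cache_dir: str, writer_batch_size: Optional[int] = None):
        self.cache_dir = cache_dir
        self.writer_batch_size = writer_batch_size

    def load(self, split, in_memory=False, info=None, name=None):
        """Read the cached Arrow files and return the Dataset(s) for `split`."""
        ...

    def save(self, generator, split_name, total=None, units="examples",
             name=None, features=None, tmp_cache_dir=None):
        """Write generated examples (or pyarrow tables when units="tables")
        to Arrow files and return (num_examples, num_bytes, features)."""
        ...
```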
@@ -729,116 +731,77 @@ def as_dataset(
"Constructing Dataset for split %s, from %s", split or ", ".join(self.info.splits), self._cache_dir
)

# By default, return all splits
if split is None:
split = {s: s for s in self.info.splits}
datasets = self.dataset_cache_manager.load(split, in_memory=in_memory, info=self.info, name=self.name)

# Create a dataset for each of the given splits
if run_post_process:
datasets = self._run_post_process(datasets, ignore_verifications)

if isinstance(datasets, dict):
datasets = DatasetDict(datasets)
return datasets

def _run_post_process(self, datasets, ignore_verifications: bool = False):
datasets = utils.map_nested(
partial(
self._build_single_dataset,
run_post_process=run_post_process,
self._run_a_post_process,
ignore_verifications=ignore_verifications,
in_memory=in_memory,
),
split,
datasets,
map_tuple=True,
)
if isinstance(datasets, dict):
datasets = DatasetDict(datasets)
return datasets

def _build_single_dataset(
self,
split: Union[str, ReadInstruction, Split],
run_post_process: bool,
ignore_verifications: bool,
in_memory: bool = False,
):
"""as_dataset for a single split."""
def _run_a_post_process(self, ds, ignore_verifications: bool = False):
verify_infos = not ignore_verifications
if isinstance(split, str):
split = Split(split)

# Build base dataset
ds = self._as_dataset(
split=split,
in_memory=in_memory,
)
if run_post_process:
for resource_file_name in self._post_processing_resources(split).values():
if os.sep in resource_file_name:
raise ValueError("Resources shouldn't be in a sub-directory: {}".format(resource_file_name))
resources_paths = {
resource_name: os.path.join(self._cache_dir, resource_file_name)
for resource_name, resource_file_name in self._post_processing_resources(split).items()
}
post_processed = self._post_process(ds, resources_paths)
if post_processed is not None:
ds = post_processed
recorded_checksums = {}
for resource_name, resource_path in resources_paths.items():
size_checksum = get_size_checksum_dict(resource_path)
recorded_checksums[resource_name] = size_checksum
if verify_infos:
if self.info.post_processed is None or self.info.post_processed.resources_checksums is None:
expected_checksums = None
else:
expected_checksums = self.info.post_processed.resources_checksums.get(split)
verify_checksums(expected_checksums, recorded_checksums, "post processing resources")
if self.info.post_processed is None:
self.info.post_processed = PostProcessedInfo()
if self.info.post_processed.resources_checksums is None:
self.info.post_processed.resources_checksums = {}
self.info.post_processed.resources_checksums[str(split)] = recorded_checksums
self.info.post_processing_size = sum(
checksums_dict["num_bytes"]
for split_checksums_dicts in self.info.post_processed.resources_checksums.values()
for checksums_dict in split_checksums_dicts.values()
for resource_file_name in self._post_processing_resources(ds.split).values():
if os.sep in resource_file_name:
raise ValueError("Resources shouldn't be in a sub-directory: {}".format(resource_file_name))
resources_paths = {
resource_name: os.path.join(self._cache_dir, resource_file_name)
for resource_name, resource_file_name in self._post_processing_resources(ds.split).items()
}
post_processed = self._post_process(ds, resources_paths)
if post_processed is not None:
ds = post_processed
recorded_checksums = {}
for resource_name, resource_path in resources_paths.items():
size_checksum = get_size_checksum_dict(resource_path)
recorded_checksums[resource_name] = size_checksum
if verify_infos:
if self.info.post_processed is None or self.info.post_processed.resources_checksums is None:
expected_checksums = None
else:
expected_checksums = self.info.post_processed.resources_checksums.get(ds.split)
verify_checksums(expected_checksums, recorded_checksums, "post processing resources")
if self.info.post_processed is None:
self.info.post_processed = PostProcessedInfo()
if self.info.post_processed.resources_checksums is None:
self.info.post_processed.resources_checksums = {}
self.info.post_processed.resources_checksums[str(ds.split)] = recorded_checksums
self.info.post_processing_size = sum(
checksums_dict["num_bytes"]
for split_checksums_dicts in self.info.post_processed.resources_checksums.values()
for checksums_dict in split_checksums_dicts.values()
)
if self.info.dataset_size is not None and self.info.download_size is not None:
self.info.size_in_bytes = (
self.info.dataset_size + self.info.download_size + self.info.post_processing_size
)
if self.info.dataset_size is not None and self.info.download_size is not None:
self.info.size_in_bytes = (
self.info.dataset_size + self.info.download_size + self.info.post_processing_size
)
self._save_info()
ds._info.post_processed = self.info.post_processed
ds._info.post_processing_size = self.info.post_processing_size
ds._info.size_in_bytes = self.info.size_in_bytes
if self.info.post_processed.features is not None:
if self.info.post_processed.features.type != ds.features.type:
raise ValueError(
"Post-processed features info don't match the dataset:\nGot\n{}\nbut expected something like\n{}".format(
self.info.post_processed.features, ds.features
)
self._save_info()
ds._info.post_processed = self.info.post_processed
ds._info.post_processing_size = self.info.post_processing_size
ds._info.size_in_bytes = self.info.size_in_bytes
if self.info.post_processed.features is not None:
if self.info.post_processed.features.type != ds.features.type:
raise ValueError(
"Post-processed features info don't match the dataset:\nGot\n{}\nbut expected something like\n{}".format(
self.info.post_processed.features, ds.features
)
else:
ds.info.features = self.info.post_processed.features

)
else:
ds.info.features = self.info.post_processed.features
return ds

def _as_dataset(self, split: Union[ReadInstruction, Split] = Split.TRAIN, in_memory: bool = False) -> Dataset:
"""Constructs a `Dataset`.

This is the internal implementation to overwrite called when user calls
`as_dataset`. It should read the pre-processed datasets files and generate
the `Dataset` object.

Args:
split: `datasets.Split` which subset of the data to read.
in_memory (bool, default False): Whether to copy the data in-memory.

Returns:
`Dataset`
"""

dataset_kwargs = ArrowReader(self._cache_dir, self.info).read(
name=self.name,
instructions=split,
split_infos=self.info.splits.values(),
in_memory=in_memory,
)
return Dataset(**dataset_kwargs)

def _post_process(self, dataset: Dataset, resources_paths: Dict[str, str]) -> Optional[Dataset]:
"""Run dataset transforms or add indexes"""
return None
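The removed `_as_dataset` above is essentially what `DatasetCacheManager.load` now has to do for a single split. A plausible sketch, assuming the same `ArrowReader` machinery and ignoring the dict-of-splits case that `load` presumably handles via `utils.map_nested`:

```python
# Sketch of DatasetCacheManager.load for a single split, mirroring the removed
# DatasetBuilder._as_dataset. The merged version presumably also maps over a
# dict of splits (utils.map_nested); that part is omitted here.
from datasets.arrow_dataset import Dataset
from datasets.arrow_reader import ArrowReader


class DatasetCacheManager:
    def __init__(self, cache_dir, writer_batch_size=None):
        self.cache_dir = cache_dir
        self.writer_batch_size = writer_batch_size

    def load(self, split, in_memory=False, info=None, name=None):
        dataset_kwargs = ArrowReader(self.cache_dir, info).read(
            name=name,
            instructions=split,
            split_infos=info.splits.values(),
            in_memory=in_memory,
        )
        return Dataset(**dataset_kwargs)
```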
@@ -922,19 +885,15 @@ class GeneratorBasedBuilder(DatasetBuilder):
# GeneratorBasedBuilder should have dummy data for tests by default
test_dummy_data = True

# Default batch size used by the ArrowWriter
# It defines the number of samples that are kept in memory before writing them
# and also the length of the arrow chunks
# None means that the ArrowWriter will use its default value
DEFAULT_WRITER_BATCH_SIZE = None

def __init__(self, *args, writer_batch_size=None, **kwargs):
super(GeneratorBasedBuilder, self).__init__(*args, **kwargs)
# Batch size used by the ArrowWriter
# It defines the number of samples that are kept in memory before writing them
# and also the length of the arrow chunks
# None means that the ArrowWriter will use its default value
self._writer_batch_size = writer_batch_size or self.DEFAULT_WRITER_BATCH_SIZE
self.dataset_cache_manager = DatasetCacheManager(
cache_dir=self._cache_dir, writer_batch_size=writer_batch_size
)

@abc.abstractmethod
def _generate_examples(self, **kwargs):
@@ -967,22 +926,20 @@ def _generate_examples(self, **kwargs):
raise NotImplementedError()

def _prepare_split(self, split_generator):
split_info = split_generator.split_info

fname = "{}-{}.arrow".format(self.name, split_generator.name)
fpath = os.path.join(self._cache_dir, fname)

total = split_generator.split_info.num_examples
generator = self._generate_examples(**split_generator.gen_kwargs)
not_verbose = bool(logger.getEffectiveLevel() > WARNING)
with ArrowWriter(features=self.info.features, path=fpath, writer_batch_size=self._writer_batch_size) as writer:
try:
for key, record in utils.tqdm(
generator, unit=" examples", total=split_info.num_examples, leave=False, disable=not_verbose
):
example = self.info.features.encode_example(record)
writer.write(example)
finally:
num_examples, num_bytes = writer.finalize()
split_generator_name = split_generator.name

# TODO: tmp_cache_dir instead of self.cache_dir because of:
# with utils.temporary_assignment(self, "_cache_dir", tmp_data_dir)
num_examples, num_bytes, _ = self.dataset_cache_manager.save(
generator,
split_generator_name,
total=total,
name=self.name,
features=self.info.features,
tmp_cache_dir=self.cache_dir,
)

split_generator.split_info.num_examples = num_examples
split_generator.split_info.num_bytes = num_bytes
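The `ArrowWriter` loop that used to live in `_prepare_split` has moved behind `DatasetCacheManager.save`. A sketch of the example-writing path, assuming it mirrors the removed code (the merged version also writes into the temporary `.incomplete` cache dir and handles `units="tables"`, per the commit log):

```python
# Sketch of the example-writing path of DatasetCacheManager.save, mirroring
# the ArrowWriter loop removed from GeneratorBasedBuilder._prepare_split.
# The merged method also handles units="tables" and the temporary
# ".incomplete" cache_dir mentioned in the commit log.
import os

from datasets import utils
from datasets.arrow_writer import ArrowWriter


class DatasetCacheManager:
    def __init__(self, cache_dir, writer_batch_size=None):
        self.cache_dir = cache_dir
        self.writer_batch_size = writer_batch_size

    def save(self, generator, split_name, total=None, units="examples",
             name=None, features=None, tmp_cache_dir=None):
        fname = "{}-{}.arrow".format(name, split_name)
        fpath = os.path.join(tmp_cache_dir or self.cache_dir, fname)
        with ArrowWriter(features=features, path=fpath,
                         writer_batch_size=self.writer_batch_size) as writer:
            try:
                for key, record in utils.tqdm(
                    generator, unit=" " + units, total=total, leave=False
                ):
                    writer.write(features.encode_example(record))
            finally:
                num_examples, num_bytes = writer.finalize()
        return num_examples, num_bytes, writer._features
```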
@@ -1025,20 +982,22 @@ def _generate_examples(self, **kwargs):
raise NotImplementedError()

def _prepare_split(self, split_generator):
fname = "{}-{}.arrow".format(self.name, split_generator.name)
fpath = os.path.join(self._cache_dir, fname)

generator = self._generate_tables(**split_generator.gen_kwargs)
not_verbose = bool(logger.getEffectiveLevel() > WARNING)
with ArrowWriter(features=self.info.features, path=fpath) as writer:
for key, table in utils.tqdm(generator, unit=" tables", leave=False, disable=not_verbose):
writer.write_table(table)
num_examples, num_bytes = writer.finalize()
split_generator_name = split_generator.name

num_examples, num_bytes, writer_features = DatasetCacheManager(cache_dir=self._cache_dir).save(
generator,
split_generator_name,
units="tables",
name=self.name,
features=self.info.features,
tmp_cache_dir=self._cache_dir,
)

split_generator.split_info.num_examples = num_examples
split_generator.split_info.num_bytes = num_bytes
if self.info.features is None:
self.info.features = writer._features
self.info.features = writer_features


class MissingBeamOptions(ValueError):
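`ArrowBasedBuilder` asks the manager to write pyarrow tables rather than encoded examples, and uses the third element of the returned tuple to fill in `self.info.features` when they were not declared. Per the commit log, `save` dispatches internally to `save_examples`/`save_tables`; a sketch of the table branch, mirroring the removed writer loop (an assumption about the internals, not the merged code):

```python
# Sketch of the table-writing branch (save_tables in the commit log). The
# returned features come from the writer, which infers them from the written
# tables; ArrowBasedBuilder uses them when self.info.features is None.
import os

from datasets import utils
from datasets.arrow_writer import ArrowWriter


def save_tables(generator, split_name, cache_dir, name=None, features=None):
    fpath = os.path.join(cache_dir, "{}-{}.arrow".format(name, split_name))
    with ArrowWriter(features=features, path=fpath) as writer:
        for key, table in utils.tqdm(generator, unit=" tables", leave=False):
            writer.write_table(table)
        num_examples, num_bytes = writer.finalize()
    return num_examples, num_bytes, writer._features
```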
@@ -1051,10 +1010,10 @@ class BeamBasedBuilder(DatasetBuilder):
# BeamBasedBuilder does not have dummy data for tests yet
test_dummy_data = False

def __init__(self, *args, **kwargs):
self._beam_runner = kwargs.pop("beam_runner", None)
self._beam_options = kwargs.pop("beam_options", None)
super(BeamBasedBuilder, self).__init__(*args, **kwargs)
def __init__(self, *args, beam_runner=None, beam_options=None, **kwargs):
super().__init__(*args, **kwargs)
self._beam_runner = beam_runner
self._beam_options = beam_options
self._beam_writers = {} # {split: beam_writer} mapping.

def _make_split_generators_kwargs(self, prepare_split_kwargs):
@@ -1170,8 +1129,7 @@ def _prepare_split(self, split_generator, pipeline):
import apache_beam as beam

split_name = split_generator.split_info.name
output_prefix = filename_prefix_for_split(self.name, split_name)
output_prefix = os.path.join(self._cache_dir, output_prefix)
_ = filename_prefix_for_split(self.name, split_name)

# To write examples to disk:
fname = "{}-{}.arrow".format(self.name, split_name)