* Add data generation configuration with BuilderConfig
* Clean up memoization in py_utils
* Add some __repr__ properties for easier debugging

PiperOrigin-RevId: 224262282
Ryan Sepassi authored and Copybara-Service committed Dec 6, 2018
1 parent 63daf63 commit 8fb7199
Showing 10 changed files with 426 additions and 199 deletions.
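A minimal usage sketch of the configuration API added in this commit. The `SizeConfig` and `ToyDataset` names, the `num_examples` knob, and the version strings are hypothetical; `BuilderConfig`, `DATA_CONFIGS`, and the `config=` constructor argument are the pieces this commit introduces:

import tensorflow_datasets as tfds

class SizeConfig(tfds.core.BuilderConfig):
  """Hypothetical config adding a num_examples knob."""

  def __init__(self, num_examples=100, **kwargs):
    super(SizeConfig, self).__init__(**kwargs)
    self.num_examples = num_examples

class ToyDataset(tfds.core.GeneratorBasedDatasetBuilder):
  # Named configurations; each must have a name, version, and description.
  DATA_CONFIGS = [
      SizeConfig(name="small", version="0.1.0",
                 description="100 examples", num_examples=100),
      SizeConfig(name="large", version="0.1.0",
                 description="10000 examples", num_examples=10000),
  ]
  # A real builder would also implement _info(), _split_generators(), and
  # _generate_examples(); they are omitted from this sketch.

# Pass a registered config by name, or a BuilderConfig instance directly.
builder = ToyDataset(config="small")
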
2 changes: 2 additions & 0 deletions tensorflow_datasets/core/__init__.py
@@ -15,6 +15,7 @@

"""tensorflow_datasets.core."""

from tensorflow_datasets.core.dataset_builder import BuilderConfig
from tensorflow_datasets.core.dataset_builder import DatasetBuilder
from tensorflow_datasets.core.dataset_builder import GeneratorBasedDatasetBuilder

@@ -26,6 +27,7 @@
from tensorflow_datasets.core.splits import SplitInfo

__all__ = [
"BuilderConfig",
"DatasetBuilder",
"GeneratorBasedDatasetBuilder",
"DatasetInfo",
247 changes: 159 additions & 88 deletions tensorflow_datasets/core/dataset_builder.py
@@ -40,16 +40,42 @@
import termcolor


__all__ = [
"DatasetBuilder",
"GeneratorBasedDatasetBuilder",
]

FORCE_REDOWNLOAD = download.GenerateMode.FORCE_REDOWNLOAD
REUSE_CACHE_IF_EXISTS = download.GenerateMode.REUSE_CACHE_IF_EXISTS
REUSE_DATASET_IF_EXISTS = download.GenerateMode.REUSE_DATASET_IF_EXISTS


class BuilderConfig(object):
"""Base class for data configuration.
DatasetBuilder subclasses with data configuration options should subclass
`BuilderConfig` and add their own properties.
"""

def __init__(self, name, version=None, description=None):
self._name = name
self._version = version
self._description = description

@property
def name(self):
return self._name

@property
def version(self):
return self._version

@property
def description(self):
return self._description

def __repr__(self):
return "<{cls_name} name={name}, version={version}>".format(
cls_name=type(self).__name__,
name=self.name,
version=self.version or "None")


@six.add_metaclass(registered.RegisteredDataset)
class DatasetBuilder(object):
"""Abstract base class for datasets.
@@ -73,17 +99,24 @@ class DatasetBuilder(object):
# Name of the dataset, filled by metaclass based on class name.
name = None

# Named configurations that modify the data generated by download_and_prepare.
DATA_CONFIGS = []

@api_utils.disallow_positional_args
def __init__(self, data_dir=None):
def __init__(self, data_dir=None, config=None):
"""Construct a DatasetBuilder.
Callers must pass arguments as keyword arguments.
Args:
data_dir: (str) directory to read/write data. Defaults to
"~/tensorflow_datasets".
config: (`tfds.core.BuilderConfig` or `str` name) optional configuration
for the dataset that affects the data generated on disk. Different
`builder_config`s will have their own subdirectories and versions.
"""
self._data_dir_root = os.path.expanduser(data_dir or constants.DATA_DIR)
self._builder_config = self._create_builder_config(config)
# Get the last dataset if it exists (or None otherwise)
self._data_dir = self._get_data_dir()

@@ -117,75 +150,51 @@ def download_and_prepare(
data is stored. Defaults to
"~/tensorflow-datasets/manual/{dataset_name}".
mode: `tfds.GenerateMode`: Mode to FORCE_REDOWNLOAD,
or REUSE_DATASET_IF_EXISTS. Default to REUSE_DATASET_IF_EXISTS.
or REUSE_DATASET_IF_EXISTS. Defaults to REUSE_DATASET_IF_EXISTS.
compute_stats: `boolean` If True, compute statistics over the generated
data and write the `tfds.core.DatasetInfo` protobuf to disk.
Raises:
ValueError: If the user defines both cache_dir and dl_manager
"""

mode = mode and download.GenerateMode(mode) or REUSE_DATASET_IF_EXISTS
mode = (mode and download.GenerateMode(mode)) or REUSE_DATASET_IF_EXISTS
if (self._data_dir and mode == REUSE_DATASET_IF_EXISTS):
tf.logging.info("Reusing dataset %s (%s)", self.name, self._data_dir)
return

download_dir = download_dir or os.path.join(self._data_dir_root,
"downloads")
extract_dir = extract_dir or os.path.join(self._data_dir_root, "extracted")
manual_dir = manual_dir or os.path.join(self._data_dir_root, "manual")
manual_dir = os.path.join(manual_dir, self.name)

# Create the download manager
dl_manager = download.DownloadManager(
dataset_name=self.name,
checksums=self.info.download_checksums,
dl_manager = self._make_download_manager(
download_dir=download_dir,
extract_dir=extract_dir,
manual_dir=manual_dir,
force_download=(mode == FORCE_REDOWNLOAD),
force_extraction=(mode == FORCE_REDOWNLOAD),
)
mode=mode)

# Otherwise, create a new version in a new data_dir.
data_dir = self._get_data_dir(version=self.info.version)
if tf.gfile.Exists(data_dir):
# If generation is determinism, the dataset can be re-generated and raise
# an error only if generated files are different
# Create a new version in a new data_dir.
self._data_dir = self._get_data_dir(version=self.info.version)
if tf.gfile.Exists(self._data_dir):
# TODO(tfds): If generation is deterministic, the dataset can be
# re-generated and raise an error only if generated files are different
raise ValueError(
"Trying to overwrite an existing dataset {} at {}. A dataset with "
"the same version {} already exists. If the dataset has changed, "
"please update the version number.".format(
self.name, data_dir, self.info.version))
tf.logging.info("Generating dataset %s (%s)", self.name, data_dir)

self._check_available_size(data_dir)

# Wrap the Dataset generation in a .incomplete directory
with file_format_adapter.incomplete_dir(data_dir) as data_dir_tmp:
# Modify the data_dir here to avoid having to forward it to every sub
# function
self._data_dir = data_dir_tmp

# Download data, generate the tf.train.Example
self._download_and_prepare(dl_manager=dl_manager)

# Update the DatasetInfo metadata by computing statistics from the data.
if compute_stats:
# Update the info object with the statistics and schema.
# Note: self.info already contains static information about the dataset
self.info.compute_dynamic_properties(self)

# Write DatasetInfo to disk, even if we haven't computed the statistics.
self.info.write_to_directory(self._data_dir)

# Once the data has been fully generated in the temporary directory,
# we restore set data_dir to it's final location by renaming the
# tmp_dir => data_dir (when exiting the context manager).
# Using a temporary directory ensure that the loaded data is not corrupted
# as only data having completed generation without crash (during the data
# generation, stat computation, data info writing,...) can be loaded.
self._data_dir = data_dir
"please update the version number.".format(self.name, self._data_dir,
self.info.version))
tf.logging.info("Generating dataset %s (%s)", self.name, self._data_dir)
self._log_download_bytes()

# Create a tmp dir and rename to self._data_dir on successful exit.
with file_format_adapter.incomplete_dir(self._data_dir) as tmp_data_dir:
# Temporarily assign _data_dir to tmp_data_dir to avoid having to forward
# it to every sub function.
with utils.temporary_assignment(self, "_data_dir", tmp_data_dir):
self._download_and_prepare(dl_manager=dl_manager)

# Update the DatasetInfo metadata by computing statistics from the data.
if compute_stats:
self.info.compute_dynamic_properties(self)

# Write DatasetInfo to disk, even if we haven't computed the statistics.
self.info.write_to_directory(self._data_dir)
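The temporary-directory dance above relies on `utils.temporary_assignment` (touched by this commit's py_utils cleanup). A plausible sketch of such a helper, assuming it is a plain attribute-swapping context manager; the actual implementation may differ:

import contextlib

@contextlib.contextmanager
def temporary_assignment(obj, attr, new_value):
  """Temporarily sets obj.attr to new_value, restoring the original value on exit."""
  original = getattr(obj, attr)
  setattr(obj, attr, new_value)
  try:
    yield
  finally:
    setattr(obj, attr, original)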

@api_utils.disallow_positional_args
def as_dataset(self,
@@ -221,8 +230,8 @@ def as_dataset(self,
if isinstance(split, six.string_types):
split = splits.NamedSplit(split)

# Automatically activate shuffling if training
if shuffle_files is None:
# Shuffle files if training
shuffle_files = split == splits.Split.TRAIN

dataset = self._as_dataset(split=split, shuffle_files=shuffle_files)
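A usage sketch of the default above: when `shuffle_files` is not given, files are shuffled only for the TRAIN split (the `builder` object and the split names are the hypothetical ones from the earlier sketch):

train_ds = builder.as_dataset(split="train")                     # shuffle_files -> True
test_ds = builder.as_dataset(split="test", shuffle_files=False)  # explicit override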
@@ -272,44 +281,49 @@ def _get_data_dir(self, version=None):
Args:
version: (str) If specified, return the data_dir associated with the
given version
given version.
Returns:
data_dir: (str)
If version is given, return the data_dir associated with this version.
Otherwise, automatically extract the last version from the directory.
If no previous version is found, return None.
"""
data_root_dir = os.path.join(self._data_dir_root, self.name)
if version is not None:
return os.path.join(data_root_dir, version)

# Get the most recent directory
if tf.gfile.Exists(data_root_dir):
version_dirnames = {}
for filename in tf.gfile.ListDirectory(data_root_dir):
try:
version_dirnames[filename] = utils.str_to_version(filename)
except ValueError: # Invalid version (ex: incomplete data dir)
pass
# If found valid data directories, take the biggest version
if version_dirnames:
version_dirnames = [
k for k, _ in sorted(version_dirnames.items(), key=lambda x: x[-1])
]
return os.path.join(data_root_dir, version_dirnames[-1])
builder_config = self._builder_config
builder_data_dir = os.path.join(self._data_dir_root, self.name)
if builder_config:
builder_data_dir = os.path.join(builder_data_dir, builder_config.name)
if version:
return os.path.join(builder_data_dir, version)

if not tf.gfile.Exists(builder_data_dir):
return None

# Get the highest version directory
version_dirnames = []
for dir_name in tf.gfile.ListDirectory(builder_data_dir):
try:
version_dirnames.append((utils.str_to_version(dir_name), dir_name))
except ValueError: # Invalid version (ex: incomplete data dir)
pass
# If valid data directories were found, take the highest version
if version_dirnames:
version_dirnames.sort(reverse=True)
highest_version_dir = version_dirnames[0][1]
return os.path.join(builder_data_dir, highest_version_dir)

# No directory found
return None
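The resulting on-disk layout nests one extra level per config. A sketch of the path construction, with hypothetical dataset/config names and version:

import os

data_dir_root = os.path.expanduser("~/tensorflow_datasets")
# With a config: <data_dir_root>/<dataset name>/<config name>/<version>
config_data_dir = os.path.join(data_dir_root, "toy_dataset", "small", "0.1.0")
# Without a config, the config level is simply absent:
plain_data_dir = os.path.join(data_dir_root, "toy_dataset", "0.1.0")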

def _check_available_size(self, data_dir):
"""Estimate the available size of the dataset."""
def _log_download_bytes(self):
# Print is intentional: we want this to always go to stdout so user has
# information needed to cancel download/preparation if needed.
# This comes right before the progress bar.
size_text = units.size_str(self.info.size_in_bytes)
termcolor.cprint("Downloading / extracting dataset %s (%s) to %s..." %
(self.name, size_text, data_dir), attrs=["bold"])
termcolor.cprint(
"Downloading / extracting dataset %s (%s) to %s..." %
(self.name, size_text, self._data_dir),
attrs=["bold"])
# TODO(tfds): Should try to estimate the available free disk space (if
# possible) and raise an error if not.

@@ -357,6 +371,66 @@ def _as_dataset(self, split, shuffle_files=None):
"""
raise NotImplementedError

def _make_download_manager(self, download_dir, extract_dir, manual_dir, mode):
download_dir = download_dir or os.path.join(self._data_dir_root,
"downloads")
extract_dir = extract_dir or os.path.join(self._data_dir_root, "extracted")
manual_dir = manual_dir or os.path.join(self._data_dir_root, "manual")
manual_dir = os.path.join(manual_dir, self.name)

return download.DownloadManager(
dataset_name=self.name,
checksums=self.info.download_checksums,
download_dir=download_dir,
extract_dir=extract_dir,
manual_dir=manual_dir,
force_download=(mode == FORCE_REDOWNLOAD),
force_extraction=(mode == FORCE_REDOWNLOAD),
)

@property
def builder_config(self):
return self._builder_config

def _create_builder_config(self, builder_config):
"""Create and validate BuilderConfig object."""
if not builder_config:
return
if isinstance(builder_config, six.string_types):
name = builder_config
builder_config = self.builder_configs.get(name)
if builder_config is None:
raise ValueError("BuilderConfig %s not found. Available: %s" %
(name, list(self.builder_configs.keys())))
name = builder_config.name
if not name:
raise ValueError("BuilderConfig must have a name, got %s" % name)
is_custom = name not in self.builder_configs
if is_custom:
tf.logging.warning("Using custom data configuration %s", name)
else:
if builder_config is not self.builder_configs[name]:
raise ValueError(
"Cannot name a custom BuilderConfig the same as an available "
"BuilderConfig. Change the name. Available BuilderConfigs: %s" %
(list(self.builder_configs.keys())))
if not builder_config.version:
raise ValueError("BuilderConfig %s must have a version" % name)
if not builder_config.description:
raise ValueError("BuilderConfig %s must have a description" % name)
return builder_config

@utils.classproperty
@classmethod
@utils.memoize()
def builder_configs(cls):
config_dict = {config.name: config for config in cls.DATA_CONFIGS}
if len(config_dict) != len(cls.DATA_CONFIGS):
names = [config.name for config in cls.DATA_CONFIGS]
raise ValueError(
"Names in DATA_CONFIGS must not be duplicated. Got %s" % names)
return config_dict
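How config resolution behaves, per `_create_builder_config` and `builder_configs` above (continuing the hypothetical `ToyDataset`/`SizeConfig` sketch):

# All registered configs, keyed by name.
ToyDataset.builder_configs  # {"small": SizeConfig(...), "large": SizeConfig(...)}

# A string selects a registered config; an unknown name raises ValueError.
builder = ToyDataset(config="small")

# A custom BuilderConfig instance is accepted with a warning, as long as its name
# does not clash with a registered config and it has a version and a description.
custom = SizeConfig(name="tiny", version="0.1.0",
                    description="10 examples", num_examples=10)
builder = ToyDataset(config=custom)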


class GeneratorBasedDatasetBuilder(DatasetBuilder):
"""Base class for datasets with data generation based on dict generators.
@@ -469,17 +543,17 @@ def _generate_examples(self, **kwargs):
**kwargs: (dict) Arguments forwarded from the SplitGenerator.gen_kwargs
Yields:
example: (dict) Sample dict<str feature_name, feature_value>. The example
should usually be encoded with
`self.info.features.encode_example({...})`
example: (`dict<str feature_name, feature_value>`), a feature dictionary
ready to be written to disk. The example should usually be encoded with
`self.info.features.encode_example({...})`.
"""
raise NotImplementedError()
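A sketch of a `_generate_examples` override matching the contract above, continuing the hypothetical `ToyDataset` (the tab-separated file format and the feature names are assumptions; `data_path` would arrive via the SplitGenerator's gen_kwargs, and `tf` is TensorFlow as imported by this module):

  def _generate_examples(self, data_path):
    """Yields encoded feature dicts, one per line of a TSV file (sketch)."""
    with tf.gfile.Open(data_path) as f:
      for line in f:
        text, label = line.rstrip("\n").split("\t")
        yield self.info.features.encode_example({
            "text": text,
            "label": int(label),
        })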

def _download_and_prepare(self, dl_manager):
if not tf.gfile.Exists(self._data_dir):
tf.gfile.MakeDirs(self._data_dir)

# Generating datata for all splits
# Generating data for all splits
split_dict = splits.SplitDict()
for split_generator in self._split_generators(dl_manager):
# Keep track of all split_info
@@ -497,9 +571,6 @@ def _download_and_prepare(self, dl_manager):
output_files,
)

# TODO(afrozm): Make it so that basic split information is known without
# having to call download_and_prepare. Maybe dataset definitions should
# include it.
# Update the info object with the splits.
self.info.splits = split_dict
