Add support for using beam in _split_generators()
PiperOrigin-RevId: 290874617
adarob authored and copybara-github committed Jan 22, 2020
1 parent 2ec8361 commit 851eddb
Showing 2 changed files with 71 additions and 15 deletions.
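In short, the change lets a BeamBasedBuilder subclass declare a `pipeline` argument in `_split_generators()`. When the argument is present, the shared Beam pipeline is passed in before any split is prepared, so a PCollection can be built once (global preprocessing) and reused across splits. The sketch below mirrors the `CommonPipelineDummyBeamDataset` test added by this commit, rendered with the public `tfds` API; the dataset name, feature spec, and example counts are illustrative, not part of the commit.

import apache_beam as beam
import tensorflow as tf
import tensorflow_datasets as tfds


class MyBeamDataset(tfds.core.BeamBasedBuilder):
  """Hypothetical builder using the new `pipeline` argument."""

  VERSION = tfds.core.Version("1.0.0")

  def _info(self):
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict({"value": tf.int64}),
    )

  def _split_generators(self, dl_manager, pipeline):
    del dl_manager
    # Because `pipeline` is in the signature, BeamBasedBuilder injects the
    # shared Beam pipeline here, so this PCollection is built only once.
    examples = (
        pipeline
        | beam.Create(range(1000))
        | beam.Map(lambda i: (i, {"value": i}))
    )
    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            gen_kwargs=dict(examples=examples, num_examples=1000)),
        tfds.core.SplitGenerator(
            name=tfds.Split.TEST,
            gen_kwargs=dict(examples=examples, num_examples=100)),
    ]

  def _build_pcollection(self, pipeline, examples, num_examples):
    del pipeline  # The shared PCollection was already built above.
    return examples | beam.Filter(lambda kv: kv[0] < num_examples)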
23 changes: 22 additions & 1 deletion tensorflow_datasets/core/dataset_builder.py
@@ -21,6 +21,7 @@

import abc
import functools
import inspect
import itertools
import os
import sys
@@ -853,13 +854,21 @@ def _prepare_split(self, split_generator, **kwargs):
"""
raise NotImplementedError()

def _make_split_generators_kwargs(self, prepare_split_kwargs):
"""Get kwargs for `self._split_generators()` from `prepare_split_kwargs`."""
del prepare_split_kwargs
return {}

def _download_and_prepare(self, dl_manager, **prepare_split_kwargs):
if not tf.io.gfile.exists(self._data_dir):
tf.io.gfile.makedirs(self._data_dir)

# Generating data for all splits
split_dict = splits_lib.SplitDict()
for split_generator in self._split_generators(dl_manager):
split_generators_kwargs = self._make_split_generators_kwargs(
prepare_split_kwargs)
for split_generator in self._split_generators(
dl_manager, **split_generators_kwargs):
if splits_lib.Split.ALL == split_generator.split_info.name:
raise ValueError(
"tfds.Split.ALL is a special split keyword corresponding to the "
@@ -1057,6 +1066,18 @@ def __init__(self, *args, **kwargs):
    super(BeamBasedBuilder, self).__init__(*args, **kwargs)
    self._beam_writers = {}  # {split: beam_writer} mapping.

  def _make_split_generators_kwargs(self, prepare_split_kwargs):
    # Pass `pipeline` into `_split_generators()` from `prepare_split_kwargs` if
    # it's in the call signature of `_split_generators()`.
    # This allows for global preprocessing in beam.
    split_generators_kwargs = {}
    split_generators_arg_names = (
        inspect.getargspec(self._split_generators).args if six.PY2 else
        inspect.signature(self._split_generators).parameters.keys())
    if "pipeline" in split_generators_arg_names:
      split_generators_kwargs["pipeline"] = prepare_split_kwargs["pipeline"]
    return split_generators_kwargs

  @abc.abstractmethod
  def _build_pcollection(self, pipeline, **kwargs):
    """Build the beam pipeline examples for each `SplitGenerator`.
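The override above only forwards `pipeline` when `_split_generators()` declares it, so existing builders that accept just `dl_manager` keep working unchanged. A standalone sketch of that signature check (Python 3 branch only; the function names below are illustrative):

import inspect


def wants_pipeline(split_generators_fn):
  """True if the callable declares a `pipeline` argument."""
  # The Python 2 branch in the diff uses inspect.getargspec(...).args instead.
  return "pipeline" in inspect.signature(split_generators_fn).parameters


def legacy_split_generators(dl_manager):
  return []


def pipeline_aware_split_generators(dl_manager, pipeline):
  return []


assert not wants_pipeline(legacy_split_generators)
assert wants_pipeline(pipeline_aware_split_generators)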
63 changes: 49 additions & 14 deletions tensorflow_datasets/core/dataset_builder_beam_test.py
@@ -66,14 +66,7 @@ def _split_generators(self, dl_manager):
        ),
    ]

  def _build_pcollection(self, pipeline, num_examples):
    """Generate examples as dicts."""
    examples = (
        pipeline
        | beam.Create(range(num_examples))
        | beam.Map(_gen_example)
    )

  def _compute_metadata(self, examples, num_examples):
    self.info.metadata["label_sum_%d" % num_examples] = (
        examples
        | beam.Map(lambda x: x[1]["label"])
@@ -83,6 +76,14 @@ def _build_pcollection(self, pipeline, num_examples):
        | beam.Map(lambda x: x[1]["id"])
        | beam.CombineGlobally(beam.combiners.MeanCombineFn()))

  def _build_pcollection(self, pipeline, num_examples):
    """Generate examples as dicts."""
    examples = (
        pipeline
        | beam.Create(range(num_examples))
        | beam.Map(_gen_example)
    )
    self._compute_metadata(examples, num_examples)
    return examples


@@ -94,6 +95,36 @@ def _gen_example(x):
})


class CommonPipelineDummyBeamDataset(DummyBeamDataset):

  def _split_generators(self, dl_manager, pipeline):
    del dl_manager

    examples = (
        pipeline
        | beam.Create(range(1000))
        | beam.Map(_gen_example)
    )

    return [
        splits_lib.SplitGenerator(
            name=splits_lib.Split.TRAIN,
            gen_kwargs=dict(examples=examples, num_examples=1000),
        ),
        splits_lib.SplitGenerator(
            name=splits_lib.Split.TEST,
            gen_kwargs=dict(examples=examples, num_examples=725),
        ),
    ]

  def _build_pcollection(self, pipeline, examples, num_examples):
    """Generate examples as dicts."""
    del pipeline
    examples |= beam.Filter(lambda x: x[0] < num_examples)
    self._compute_metadata(examples, num_examples)
    return examples


class FaultyS3DummyBeamDataset(DummyBeamDataset):

  VERSION = utils.Version("1.0.0")
@@ -107,24 +138,24 @@ def test_download_prepare_raise(self):
    with self.assertRaisesWithPredicateMatch(ValueError, "no Beam Runner"):
      builder.download_and_prepare()

  def _assertBeamGeneration(self, dl_config):
  def _assertBeamGeneration(self, dl_config, dataset_cls, dataset_name):
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
      builder = DummyBeamDataset(data_dir=tmp_dir)
      builder = dataset_cls(data_dir=tmp_dir)
      builder.download_and_prepare(download_config=dl_config)

      data_dir = os.path.join(tmp_dir, "dummy_beam_dataset", "1.0.0")
      data_dir = os.path.join(tmp_dir, dataset_name, "1.0.0")
      self.assertEqual(data_dir, builder._data_dir)

      # Check number of shards
      self._assertShards(
          data_dir,
          pattern="dummy_beam_dataset-test.tfrecord-{:05}-of-{:05}",
          pattern="%s-test.tfrecord-{:05}-of-{:05}" % dataset_name,
          # Liquid sharding is not guaranteed to always use the same number.
          num_shards=builder.info.splits["test"].num_shards,
      )
      self._assertShards(
          data_dir,
          pattern="dummy_beam_dataset-train.tfrecord-{:05}-of-{:05}",
          pattern="%s-train.tfrecord-{:05}-of-{:05}" % dataset_name,
          num_shards=1,
      )

@@ -177,7 +208,11 @@ def test_download_prepare(self):
    dl_config = self._get_dl_config_if_need_to_run()
    if not dl_config:
      return
    self._assertBeamGeneration(dl_config)
    self._assertBeamGeneration(
        dl_config, DummyBeamDataset, "dummy_beam_dataset")
    self._assertBeamGeneration(
        dl_config, CommonPipelineDummyBeamDataset,
        "common_pipeline_dummy_beam_dataset")


if __name__ == "__main__":
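To exercise either builder end to end, `download_and_prepare()` needs a Beam runner supplied through the download config, which is what the "no Beam Runner" assertion above checks. A minimal sketch, assuming `DownloadConfig` exposes `beam_runner` and `beam_options` as it did in TFDS at the time, and reusing the hypothetical `MyBeamDataset` from the sketch near the top:

import apache_beam as beam
import tensorflow_datasets as tfds

dl_config = tfds.download.DownloadConfig(
    beam_runner=beam.runners.DirectRunner(),  # any runner works; DirectRunner is for local runs
    beam_options=beam.options.pipeline_options.PipelineOptions(),
)
builder = MyBeamDataset(data_dir="/tmp/my_beam_dataset")  # hypothetical builder from the earlier sketch
builder.download_and_prepare(download_config=dl_config)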
