Add tfds.dataset_as_numpy enabling full tf.data pipelines with NumPy at the end.

Remove tfds.load_numpy.

PiperOrigin-RevId: 226423011
Ryan Sepassi authored and Copybara-Service committed Dec 21, 2018
1 parent fbb8156 commit ed9397c
Showing 11 changed files with 259 additions and 160 deletions.
README.md (34 changes: 16 additions & 18 deletions)
@@ -80,39 +80,37 @@ print(info)
 )
 ```
 
-### NumPy Usage with `as_numpy()`
+### NumPy Usage with `tfds.dataset_as_numpy`
 
-As a convenience for users that have limited familiarity with TensorFlow,
-`DatasetBuilder` has an `as_numpy()` method that yields batched NumPy arrays.
+As a convenience for users that want simple NumPy arrays in their programs, you
+can use `tfds.dataset_as_numpy` to return a generator that yields NumPy array
+records out of a `tf.data.Dataset`. This allows you to build high-performance
+input pipelines with `tf.data` but use whatever you'd like for your model
+components.
 
 ```
-mnist_builder = tfds.builder("mnist")
-mnist_builder.download_and_prepare()
-for example in mnist_builder.as_numpy(split=tfds.Split.TRAIN, batch_size=128):
+train_ds = tfds.load("mnist", split=tfds.Split.TRAIN)
+train_ds = train_ds.shuffle(1024).batch(128).repeat(5).prefetch(10)
+for example in tfds.dataset_as_numpy(train_ds):
   numpy_images, numpy_labels = example["image"], example["label"]
 ```
 
-You can also get the entire dataset at once (if it fits in your machine's
-memory) by using `batch_size=-1`:
+You can also use `tfds.dataset_as_numpy` in conjunction with `batch_size=-1` to
+get the full dataset in NumPy arrays from the returned `tf.Tensor` object:
 
 ```
-mnist_builder = tfds.builder("mnist")
-mnist_builder.download_and_prepare()
-numpy_dataset = mnist_builder.as_numpy(split=tfds.Split.TRAIN, batch_size=-1)
+train_data = tfds.load("mnist", split=tfds.Split.TRAIN, batch_size=-1)
+numpy_data = tfds.dataset_as_numpy(train_data)
 numpy_images, numpy_labels = numpy_dataset["image"], numpy_dataset["label"]
 ```
 
-Note that `tf.data.Dataset` objects are iterable when running in Eager mode
-(`tf.enable_eager_execution`), so you can use `builder.as_dataset`, build an
-input pipeline, and then iterate through the dataset to get NumPy arrays as
-well.
-
 Note that the library still requires `tensorflow` as an internal dependency.
 
 ## Contributing a dataset
 
-Thanks for considering a contribution. See the
-[doc on adding a new dataset](https://github.com/tensorflow/datasets/tree/master/docs/add_dataset.md)
+Thanks for considering a contribution! We're eager to grow the available set of
+datasets. See the
+[doc on adding a new dataset](https://github.com/tensorflow/datasets/tree/master/docs/add_dataset.md).
 
 #### Disclaimers
 
tensorflow_datasets/core/dataset_builder.py (43 changes: 0 additions & 43 deletions)
@@ -366,49 +366,6 @@ def _build_single_dataset(self, split, shuffle_files, batch_size,
     else:
       return dataset
 
-  @api_utils.disallow_positional_args
-  def as_numpy(self, **as_dataset_kwargs):
-    # pylint: disable=g-doc-return-or-yield
-    """Generates batches of NumPy arrays from the given `tfds.Split`.
-
-    Args:
-      **as_dataset_kwargs: Keyword arguments passed on to
-        `tfds.core.DatasetBuilder.as_dataset`.
-
-    Yields:
-      Feature dictionaries
-      `dict<str feature_name, numpy.array feature_val>`, or if `split=None`,
-      `dict` from `tfds.Split` to the feature dictionaries.
-
-      If `batch_size` is -1, will return a single dictionary containing
-      the entire dataset instead of yielding batches.
-    """
-    # pylint: enable=g-doc-return-or-yield
-    # TF does not like it when you nest Graph/Session contexts, and because we
-    # may be operating on multiple splits here, we more directly control the
-    # graph and session creation/contexts to stay in the contexts for as little
-    # time as possible (basically keeping the contexts local and not persisting
-    # them).
-    with utils.maybe_with_graph() as graph:
-      dataset = self.as_dataset(**as_dataset_kwargs)
-
-    def ds_iter(ds):
-      for el in dataset_utils.iterate_over_dataset(ds, graph=graph):
-        yield el
-
-    wants_full_dataset = as_dataset_kwargs.get("batch_size") == -1
-    if wants_full_dataset:
-      # as_dataset returned Tensors, possibly tupleized with
-      # as_supervised=True
-      if tf.executing_eagerly():
-        return utils.map_nested(lambda t: t.numpy(), dataset, map_tuple=True)
-      else:
-        with utils.nogpu_session(graph=graph) as sess:
-          return sess.run(dataset)
-    else:
-      # as_dataset returned tf.data.Datasets
-      return utils.map_nested(ds_iter, dataset, dict_only=True)
-
   def _get_data_dir(self, version=None):
     """Return the data directory of one dataset version.
tensorflow_datasets/core/dataset_builder_test.py (33 changes: 13 additions & 20 deletions)
@@ -25,6 +25,7 @@
 import tensorflow as tf
 from tensorflow_datasets.core import dataset_builder
 from tensorflow_datasets.core import dataset_info
+from tensorflow_datasets.core import dataset_utils
 from tensorflow_datasets.core import features
 from tensorflow_datasets.core import registered
 from tensorflow_datasets.core import splits as splits_lib
@@ -104,7 +105,8 @@ def test_shared_generator(self):
           splits_lib.Split.TRAIN, splits_lib.Split.TEST
       ]
       train_data, test_data = [
-          [el["x"] for el in builder.as_numpy(split=split)]
+          [el["x"] for el in
+           dataset_utils.dataset_as_numpy(builder.as_dataset(split=split))]
           for split in splits_list
       ]
 
@@ -123,12 +125,12 @@ def test_shared_generator(self):
   @tf.contrib.eager.run_test_in_graph_and_eager_modes()
   def test_load(self):
     with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
-      dataset = registered.load_numpy(
+      dataset = registered.load(
           name="dummy_dataset_shared_generator",
           data_dir=tmp_dir,
           download=True,
           split=splits_lib.Split.TRAIN)
-      data = list(dataset)
+      data = list(dataset_utils.dataset_as_numpy(dataset))
       self.assertEqual(20, len(data))
       self.assertLess(data[0]["x"], 30)
 
@@ -197,7 +199,8 @@ def test_with_configs(self):
       splits_list = [splits_lib.Split.TRAIN, splits_lib.Split.TEST]
       for builder, incr in [(builder1, 1), (builder2, 2)]:
         train_data, test_data = [
-            [el["x"] for el in builder.as_numpy(split=split)]
+            [el["x"] for el in
+             dataset_utils.dataset_as_numpy(builder.as_dataset(split=split))]
             for split in splits_list
         ]
 
@@ -224,7 +227,8 @@ def setUp(self):

   @tf.contrib.eager.run_test_in_graph_and_eager_modes()
   def test_all_splits(self):
-    splits = self.builder.as_numpy(batch_size=-1)
+    splits = dataset_utils.dataset_as_numpy(
+        self.builder.as_dataset(batch_size=-1))
     self.assertSetEqual(set(splits.keys()),
                         set([splits_lib.Split.TRAIN, splits_lib.Split.TEST]))
 
@@ -240,8 +244,8 @@ def test_all_splits(self):

   @tf.contrib.eager.run_test_in_graph_and_eager_modes()
   def test_with_batch_size(self):
-    items = list(self.builder.as_numpy(
-        split=splits_lib.Split.TRAIN + splits_lib.Split.TEST, batch_size=10))
+    items = list(dataset_utils.dataset_as_numpy(self.builder.as_dataset(
+        split=splits_lib.Split.TRAIN + splits_lib.Split.TEST, batch_size=10)))
     # 3 batches of 10
     self.assertEqual(3, len(items))
     x1, x2, x3 = items[0]["x"], items[1]["x"], items[2]["x"]
@@ -250,21 +254,10 @@ def test_with_batch_size(self):
     self.assertEqual(10, x3.shape[0])
     self.assertEqual(sum(range(30)), int(x1.sum() + x2.sum() + x3.sum()))
 
-  @tf.contrib.eager.run_test_in_graph_and_eager_modes()
-  def test_as_numpy(self):
-    items = self.builder.as_numpy(split=splits_lib.Split.TRAIN, batch_size=-1)
-    self.assertEqual(items["x"].shape[0], 20)
-    self.assertLess(items["x"][0], 30)
-
-    count = 0
-    for _ in self.builder.as_numpy(split=splits_lib.Split.TRAIN):
-      count += 1
-    self.assertEqual(count, 20)
-
   @tf.contrib.eager.run_test_in_graph_and_eager_modes()
   def test_supervised_keys(self):
-    x, _ = self.builder.as_numpy(
-        split=splits_lib.Split.TRAIN, as_supervised=True, batch_size=-1)
+    x, _ = dataset_utils.dataset_as_numpy(self.builder.as_dataset(
+        split=splits_lib.Split.TRAIN, as_supervised=True, batch_size=-1))
     self.assertEqual(x.shape[0], 20)
 
 
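The tests above run under `tf.contrib.eager.run_test_in_graph_and_eager_modes`, so the same `dataset_as_numpy` call is exercised both ways. A rough sketch of what that means for callers, assuming TF 1.x where eager execution is opt-in:

```python
import tensorflow as tf
from tensorflow_datasets.core import dataset_utils

# Graph mode (the TF 1.x default): dataset_as_numpy builds one-shot iterator
# ops and runs them in an internal session, so the caller only sees NumPy.
ds = tf.data.Dataset.range(3)
print(list(dataset_utils.dataset_as_numpy(ds)))  # [0, 1, 2]

# Eager mode behaves the same from the caller's point of view; the values are
# produced by iterating the dataset directly and calling .numpy() on each
# element. (In TF 1.x, tf.enable_eager_execution() must run at program start.)
```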
tensorflow_datasets/core/dataset_info.py (2 changes: 1 addition & 1 deletion)
@@ -428,7 +428,7 @@ def get_dataset_feature_statistics(builder, split):
   feature_to_min = {}
   feature_to_max = {}
 
-  for example in dataset_utils.iterate_over_dataset(dataset):
+  for example in dataset_utils.dataset_as_numpy(dataset):
     statistics.num_examples += 1
 
     assert isinstance(example, dict)
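The statistics pass shown here simply iterates NumPy examples from `dataset_as_numpy` and accumulates counters. A simplified sketch of that accumulation loop, under the assumption of a single numeric feature named `"x"` (the helper below is illustrative, not part of the module):

```python
import numpy as np
from tensorflow_datasets.core import dataset_utils

def rough_feature_stats(dataset, feature_name="x"):
  """Illustrative only: count examples and track one feature's min/max."""
  num_examples = 0
  feature_min, feature_max = np.inf, -np.inf
  for example in dataset_utils.dataset_as_numpy(dataset):
    assert isinstance(example, dict)
    num_examples += 1
    values = np.asarray(example[feature_name])
    feature_min = min(feature_min, values.min())
    feature_max = max(feature_max, values.max())
  return num_examples, feature_min, feature_max
```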
tensorflow_datasets/core/dataset_utils.py (90 changes: 78 additions & 12 deletions)
@@ -20,6 +20,7 @@
 from __future__ import print_function
 
 import tensorflow as tf
+from tensorflow_datasets.core import api_utils
 from tensorflow_datasets.core import utils
 
 
@@ -87,19 +88,84 @@ def instruction_ds_to_file_ds(instruction):
   return dataset
 
 
-def iterate_over_dataset(dataset, graph=None):
-  """Yields numpy elements of `tf.data.Dataset`."""
+def _eager_dataset_iterator(dataset):
+  for item in dataset:
+    flat = tf.contrib.framework.nest.flatten(item)
+    flat = [el.numpy() for el in flat]
+    yield tf.contrib.framework.nest.pack_sequence_as(item, flat)
+
+
+def _graph_dataset_iterator(ds_item, graph=None):
+  with utils.nogpu_session(graph) as sess:
+    while True:
+      try:
+        yield sess.run(ds_item)
+      except tf.errors.OutOfRangeError:
+        break
+
+
+@api_utils.disallow_positional_args(allowed=["dataset"])
+def dataset_as_numpy(dataset, graph=None):
+  """Converts a `tf.data.Dataset` to an iterable of NumPy arrays.
+
+  `dataset_as_numpy` converts a possibly nested structure of `tf.data.Dataset`s
+  and `tf.Tensor`s to iterables of NumPy arrays and NumPy arrays, respectively.
+
+  Args:
+    dataset: a possibly nested structure of `tf.data.Dataset`s and/or
+      `tf.Tensor`s.
+    graph: `tf.Graph`, optional, explicitly set the graph to use.
+
+  Returns:
+    A structure matching `dataset` where `tf.data.Dataset`s are converted to
+    generators of NumPy arrays and `tf.Tensor`s are converted to NumPy arrays.
+  """
+  nested_ds = dataset
+  del dataset
+
+  # Flatten
+  flat_ds = tf.contrib.framework.nest.flatten(nested_ds)
+  flat_np = []
+
+  # Type check for Tensors and Datasets
+  for ds_el in flat_ds:
+    types = [type(el) for el in flat_ds]
+    types = tf.contrib.framework.nest.pack_sequence_as(nested_ds, types)
+    if not isinstance(ds_el, (tf.Tensor, tf.data.Dataset)):
+      raise ValueError("Arguments to dataset_as_numpy must be tf.Tensors or "
+                       "tf.data.Datasets. Got: %s" % types)
+
   if tf.executing_eagerly():
-    for item in dataset:
-      flat = tf.contrib.framework.nest.flatten(item)
-      flat = [el.numpy() for el in flat]
-      yield tf.contrib.framework.nest.pack_sequence_as(item, flat)
+    # Eager mode
+    for ds_el in flat_ds:
+      if isinstance(ds_el, tf.Tensor):
+        np_el = ds_el.numpy()
+      elif isinstance(ds_el, tf.data.Dataset):
+        np_el = _eager_dataset_iterator(ds_el)
+      else:
+        assert False
+      flat_np.append(np_el)
   else:
+    # Graph mode
+
+    # First create necessary graph ops
+    ds_iters = [None] * len(flat_ds)
     with utils.maybe_with_graph(graph, create_if_none=False):
-      item = dataset.make_one_shot_iterator().get_next()
+      for i, ds_el in enumerate(flat_ds):
+        if isinstance(ds_el, tf.data.Dataset):
+          ds_iters[i] = ds_el.make_one_shot_iterator().get_next()
+
+    # Then create NumPy items
+    # Shared session for tf.Tensor runs
     with utils.nogpu_session(graph) as sess:
-      while True:
-        try:
-          yield sess.run(item)
-        except tf.errors.OutOfRangeError:
-          break
+      for ds_iter, ds_el in zip(ds_iters, flat_ds):
+        if ds_iter is None:
+          # Tensor
+          np_el = sess.run(ds_el)
+        else:
+          # Dataset
+          np_el = _graph_dataset_iterator(ds_iter, graph)
+        flat_np.append(np_el)
+
+  # Nest
+  return tf.contrib.framework.nest.pack_sequence_as(nested_ds, flat_np)
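As the docstring above states, the converter accepts a nested structure mixing `tf.data.Dataset`s and `tf.Tensor`s. A minimal sketch of that behavior; the dict keys and values are made up for illustration, with expected outputs shown as comments:

```python
import tensorflow as tf
from tensorflow_datasets.core import dataset_utils

structure = {
    "examples": tf.data.Dataset.range(4),  # becomes a generator of NumPy values
    "vocab_size": tf.constant(10000),      # becomes a single NumPy value
}
np_structure = dataset_utils.dataset_as_numpy(structure)

print(np_structure["vocab_size"])       # 10000
print(list(np_structure["examples"]))   # [0, 1, 2, 3]
```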