
Commit e426311

Rename DatasetBuilder.numpy_iterator to DatasetBuilder.as_numpy and add batch_size=1 kwarg

PiperOrigin-RevId: 225289214

Ryan Sepassi authored and Copybara-Service committed Dec 13, 2018
1 parent b7ec521 · commit e426311

Showing 6 changed files with 67 additions and 36 deletions.
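For callers, the change amounts to a rename plus a default batch dimension. A minimal migration sketch, assuming a prepared MNIST builder as used in the README change below (the printed shape is illustrative):

```python
import tensorflow_datasets as tfds

builder = tfds.builder("mnist")
builder.download_and_prepare()

# Before this commit:
#   for example in builder.numpy_iterator(split=tfds.Split.TRAIN):
#     ...
# After this commit, batch_size defaults to 1, so every feature gains a
# leading batch axis:
for example in builder.as_numpy(split=tfds.Split.TRAIN):
  print(example["image"].shape)  # (1, 28, 28, 1) for MNIST
  break
```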
README.md (27 changes: 20 additions & 7 deletions)

@@ -54,20 +54,33 @@ mnist_builder.download_and_prepare()
 dataset = mnist_builder.as_dataset(split=tfds.Split.TRAIN)
 ```
 
-### Non-TensorFlow Usage
+### NumPy Usage with `as_numpy()`
 
-All datasets are usable outside of TensorFlow with the `numpy_iterator`
-method, which takes the same arguments as `as_dataset`.
+As a convenience for users that have limited familiarity with TensorFlow,
+`DatasetBuilder` has an `as_numpy()` method that yields batched NumPy arrays.
 
-```python
-import tensorflow_datasets as tfds
+```
 mnist_builder = tfds.builder("mnist")
 mnist_builder.download_and_prepare()
-for element in mnist_builder.numpy_iterator(split=tfds.Split.TRAIN):
-  numpy_image, numpy_label = element["image"], element["label"]
+for example in mnist_builder.as_numpy(split=tfds.Split.TRAIN, batch_size=128):
+  numpy_images, numpy_labels = example["image"], example["label"]
+```
+
+You can also get the entire dataset at once (if it fits in your machine's
+memory) by using `batch_size=-1`:
+
+```
+mnist_builder = tfds.builder("mnist")
+mnist_builder.download_and_prepare()
+numpy_dataset = mnist_builder.as_numpy(split=tfds.Split.TRAIN, batch_size=-1)
+numpy_images, numpy_labels = numpy_dataset["image"], numpy_dataset["label"]
 ```
 
+Note that `tf.data.Dataset` objects are iterable when running in Eager mode
+(`tf.enable_eager_execution`), so you can use `builder.as_dataset`, build an
+input pipeline, and then iterate through the dataset to get NumPy arrays as
+well.
+
 Note that the library still requires `tensorflow` as an internal dependency.
 
 ## Contributing a dataset
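The note added above points out the Eager-mode alternative. A hedged sketch of that route, assuming the same prepared MNIST builder; `.numpy()` on eager tensors does the conversion explicitly:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

tf.enable_eager_execution()

mnist_builder = tfds.builder("mnist")
mnist_builder.download_and_prepare()

# Build a normal tf.data input pipeline, then just iterate over it.
dataset = mnist_builder.as_dataset(split=tfds.Split.TRAIN)
dataset = dataset.batch(32).prefetch(tf.data.experimental.AUTOTUNE)
for example in dataset:              # dicts of eager Tensors
  images = example["image"].numpy()  # NumPy arrays on demand
  break
```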
tensorflow_datasets/core/dataset_builder.py (44 changes: 33 additions & 11 deletions)

@@ -258,29 +258,51 @@ def as_dataset(self,
     dataset = dataset.with_options(options)
     return dataset
 
-  def numpy_iterator(self, **as_dataset_kwargs):
-    """Generates numpy elements from the given `tfds.Split`.
-
-    This generator can be useful for non-TensorFlow programs.
+  def as_numpy(self, batch_size=1, **as_dataset_kwargs):
+    # pylint: disable=g-doc-return-or-yield
+    """Generates batches of NumPy arrays from the given `tfds.Split`.
 
     Args:
+      batch_size: `int`, batch size for the NumPy arrays. If -1 or None,
+        `as_numpy` will return the full dataset at once, each feature having its
+        own array.
       **as_dataset_kwargs: Keyword arguments passed on to
        `tfds.core.DatasetBuilder.as_dataset`.
 
-    Returns:
-      Generator yielding feature dictionaries
+    Yields:
+      Feature dictionaries
       `dict<str feature_name, numpy.array feature_val>`.
+
+      If `batch_size` is -1 or None, will return a single dictionary containing
+      the entire dataset instead of yielding batches.
     """
-    def iterate():
+    # pylint: enable=g-doc-return-or-yield
+    def _as_numpy(batch_size):
+      """Internal as_numpy."""
       dataset = self.as_dataset(**as_dataset_kwargs)
       dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
-      return dataset_utils.iterate_over_dataset(dataset)
+      # Use padded_batch so that features with unknown shape are supported.
+      padded_shapes = self.info.features.shape
+      if as_dataset_kwargs.get("as_supervised", False):
+        input_f, target_f = self.info.supervised_keys
+        padded_shapes = (padded_shapes[input_f], padded_shapes[target_f])
+      if (not batch_size or batch_size <= 0) and batch_size != -1:
+        raise ValueError("batch_size must be a positive number or -1 to return "
+                         "the full dataset, got %s" % str(batch_size))
+      wants_full_dataset = batch_size == -1
+      if wants_full_dataset:
+        batch_size = self.info.num_examples or int(1e10)
+      dataset = dataset.padded_batch(batch_size, padded_shapes)
+      gen = dataset_utils.iterate_over_dataset(dataset)
+      if wants_full_dataset:
+        return next(gen)
+      else:
+        return gen
 
     if tf.executing_eagerly():
-      return iterate()
+      return _as_numpy(batch_size)
     else:
       with tf.Graph().as_default():
-        return iterate()
+        return _as_numpy(batch_size)
 
   def _get_data_dir(self, version=None):
     """Return the data directory of one dataset version.
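A short usage sketch of the `batch_size` contract implemented above, assuming a prepared MNIST builder (shapes illustrative):

```python
import tensorflow_datasets as tfds

builder = tfds.builder("mnist")
builder.download_and_prepare()

# batch_size=-1 short-circuits the generator: the whole split is padded-batched
# into a single element and _as_numpy returns next(gen), one feature dict.
full = builder.as_numpy(split=tfds.Split.TRAIN, batch_size=-1)
print(full["image"].shape)  # (60000, 28, 28, 1)

# Zero or any other negative value fails the validation above.
try:
  builder.as_numpy(split=tfds.Split.TRAIN, batch_size=0)
except ValueError as err:
  print(err)  # batch_size must be a positive number or -1 ...
```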
tensorflow_datasets/core/dataset_builder_test.py (23 changes: 12 additions & 11 deletions)

@@ -236,21 +236,22 @@ def setUpClass(cls):
   def tearDownClass(cls):
     test_utils.rm_tmp_dir(cls._tfds_tmp_dir)
 
-  def test_numpy_iterator(self):
+  def test_as_numpy(self):
     builder = DummyDatasetSharedGenerator(data_dir=self._tfds_tmp_dir)
-    items = []
-    for item in builder.numpy_iterator(split=splits.Split.TRAIN):
-      items.append(item)
-    self.assertEqual(20, len(items))
-    self.assertLess(items[0]["x"], 30)
+    items = builder.as_numpy(split=splits.Split.TRAIN, batch_size=-1)
+    self.assertEqual(items["x"].shape[0], 20)
+    self.assertLess(items["x"][0], 30)
+
+    count = 0
+    for _ in builder.as_numpy(split=splits.Split.TRAIN):
+      count += 1
+    self.assertEqual(count, 20)
 
   def test_supervised_keys(self):
     builder = DummyDatasetSharedGenerator(data_dir=self._tfds_tmp_dir)
-    for item in builder.numpy_iterator(
-        split=splits.Split.TRAIN, as_supervised=True):
-      self.assertIsInstance(item, tuple)
-      self.assertEqual(len(item), 2)
-      break
+    x, _ = builder.as_numpy(
+        split=splits.Split.TRAIN, as_supervised=True, batch_size=-1)
+    self.assertEqual(x.shape[0], 20)
 
 
 if __name__ == "__main__":
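The updated `test_supervised_keys` relies on `as_numpy` honoring `as_supervised`: with it set, `padded_shapes` becomes an `(input, target)` tuple and each element is a tuple rather than a dict. A hedged sketch of that behavior, assuming MNIST's `supervised_keys` of `("image", "label")`:

```python
import tensorflow_datasets as tfds

builder = tfds.builder("mnist")
builder.download_and_prepare()

# The whole split as two aligned arrays, ordered by supervised_keys.
images, labels = builder.as_numpy(
    split=tfds.Split.TRAIN, as_supervised=True, batch_size=-1)
assert images.shape[0] == labels.shape[0]
```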
tensorflow_datasets/core/dataset_utils.py (5 changes: 0 additions & 5 deletions)

@@ -22,11 +22,6 @@
 import tensorflow as tf
 from tensorflow_datasets.core import utils
 
-__all__ = [
-    "build_dataset",
-    "iterate_over_dataset",
-]
-
 
 def build_dataset(instruction_dicts,
                   dataset_from_file_fn,
tensorflow_datasets/core/splits_test.py (2 changes: 1 addition & 1 deletion)

@@ -65,7 +65,7 @@ def _generate_examples(self, data):
     })
 
   def values(self, split):
-    return [v["value"] for v in self.numpy_iterator(split=split)]
+    return [int(v["value"]) for v in self.as_numpy(split=split)]
 
 
 class SplitsUnitTest(tf.test.TestCase):
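The added `int(...)` cast compensates for the new default `batch_size=1`: each yielded `v["value"]` is now a length-1 NumPy array rather than a scalar. A tiny NumPy illustration of the unwrapping:

```python
import numpy as np

batched = np.array([3])   # what a batch_size=1 element looks like per feature
assert int(batched) == 3  # int() collapses the singleton batch axis
```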
tensorflow_datasets/testing/dataset_builder_testing.py (2 changes: 1 addition & 1 deletion)

@@ -203,7 +203,7 @@ def _assertAsDataset(self, builder):
     dataset = builder.as_dataset(split=split_name)
     compare_shapes_and_types(builder.info.features.get_tensor_info(),
                              dataset.output_types, dataset.output_shapes)
-    examples = list(builder.numpy_iterator(split=split_name))
+    examples = list(builder.as_numpy(split=split_name))
     split_to_checksums[split_name] = set(checksum(rec) for rec in examples)
     self.assertLen(examples, expected_examples_number)
     for (split1, hashes1), (split2, hashes2) in itertools.combinations(
