From ed9397c72665d7cca7961a93e3311f8dca5ed692 Mon Sep 17 00:00:00 2001
From: Ryan Sepassi
Date: Thu, 20 Dec 2018 17:54:40 -0800
Subject: [PATCH] Add tfds.dataset_as_numpy enabling full tf.data pipelines
 with NumPy at the end. Remove tfds.load_numpy.

PiperOrigin-RevId: 226423011
---
 README.md                                   |  34 +++--
 tensorflow_datasets/core/dataset_builder.py |  43 ------
 .../core/dataset_builder_test.py            |  33 ++---
 tensorflow_datasets/core/dataset_info.py    |   2 +-
 tensorflow_datasets/core/dataset_utils.py   |  90 +++++++++++--
 .../core/dataset_utils_test.py              | 124 ++++++++++++++++++
 tensorflow_datasets/core/registered.py      |  82 +++---------
 tensorflow_datasets/core/splits_test.py     |   3 +-
 tensorflow_datasets/core/test_utils.py      |   2 +-
 tensorflow_datasets/public_api.py           |   2 +
 .../testing/dataset_builder_testing.py      |   4 +-
 11 files changed, 259 insertions(+), 160 deletions(-)
 create mode 100644 tensorflow_datasets/core/dataset_utils_test.py

diff --git a/README.md b/README.md
index d146c6a8cc8..f268faaf765 100644
--- a/README.md
+++ b/README.md
@@ -80,39 +80,37 @@ print(info)
 )
 ```
 
-### NumPy Usage with `as_numpy()`
+### NumPy Usage with `tfds.dataset_as_numpy`
 
-As a convenience for users that have limited familiarity with TensorFlow,
-`DatasetBuilder` has an `as_numpy()` method that yields batched NumPy arrays.
+As a convenience for users who want simple NumPy arrays in their programs, you
+can use `tfds.dataset_as_numpy` to return a generator that yields NumPy array
+records out of a `tf.data.Dataset`. This allows you to build high-performance
+input pipelines with `tf.data` but use whatever you'd like for your model
+components.
 
 ```
-mnist_builder = tfds.builder("mnist")
-mnist_builder.download_and_prepare()
-for example in mnist_builder.as_numpy(split=tfds.Split.TRAIN, batch_size=128):
+train_ds = tfds.load("mnist", split=tfds.Split.TRAIN)
+train_ds = train_ds.shuffle(1024).batch(128).repeat(5).prefetch(10)
+for example in tfds.dataset_as_numpy(train_ds):
   numpy_images, numpy_labels = example["image"], example["label"]
 ```
 
-You can also get the entire dataset at once (if it fits in your machine's
-memory) by using `batch_size=-1`:
+You can also use `tfds.dataset_as_numpy` in conjunction with `batch_size=-1` to
+get the full dataset in NumPy arrays from the returned `tf.Tensor` objects:
 
 ```
-mnist_builder = tfds.builder("mnist")
-mnist_builder.download_and_prepare()
-numpy_dataset = mnist_builder.as_numpy(split=tfds.Split.TRAIN, batch_size=-1)
+train_data = tfds.load("mnist", split=tfds.Split.TRAIN, batch_size=-1)
+numpy_dataset = tfds.dataset_as_numpy(train_data)
 numpy_images, numpy_labels = numpy_dataset["image"], numpy_dataset["label"]
 ```
 
-Note that `tf.data.Dataset` objects are iterable when running in Eager mode
-(`tf.enable_eager_execution`), so you can use `builder.as_dataset`, build an
-input pipeline, and then iterate through the dataset to get NumPy arrays as
-well.
-
 Note that the library still requires `tensorflow` as an internal dependency.
 
 ## Contributing a dataset
 
-Thanks for considering a contribution. See the
-[doc on adding a new dataset](https://github.com/tensorflow/datasets/tree/master/docs/add_dataset.md)
+Thanks for considering a contribution! We're eager to grow the available set of
+datasets. See the
+[doc on adding a new dataset](https://github.com/tensorflow/datasets/tree/master/docs/add_dataset.md).
 
 #### Disclaimers
 
diff --git a/tensorflow_datasets/core/dataset_builder.py b/tensorflow_datasets/core/dataset_builder.py
index 2772846fc77..ce8ccc7e156 100644
--- a/tensorflow_datasets/core/dataset_builder.py
+++ b/tensorflow_datasets/core/dataset_builder.py
@@ -366,49 +366,6 @@ def _build_single_dataset(self, split, shuffle_files, batch_size,
     else:
       return dataset
 
-  @api_utils.disallow_positional_args
-  def as_numpy(self, **as_dataset_kwargs):
-    # pylint: disable=g-doc-return-or-yield
-    """Generates batches of NumPy arrays from the given `tfds.Split`.
-
-    Args:
-      **as_dataset_kwargs: Keyword arguments passed on to
-        `tfds.core.DatasetBuilder.as_dataset`.
-
-    Yields:
-      Feature dictionaries
-      `dict`, or if `split=None`,
-      `dict` from `tfds.Split` to the feature dictionaries.
-
-      If `batch_size` is -1, will return a single dictionary containing
-      the entire dataset instead of yielding batches.
-    """
-    # pylint: enable=g-doc-return-or-yield
-    # TF does not like it when you nest Graph/Session contexts, and because we
-    # may be operating on multiple splits here, we more directly control the
-    # graph and session creation/contexts to stay in the contexts for as little
-    # time as possible (basically keeping the contexts local and not persisting
-    # them).
-    with utils.maybe_with_graph() as graph:
-      dataset = self.as_dataset(**as_dataset_kwargs)
-
-    def ds_iter(ds):
-      for el in dataset_utils.iterate_over_dataset(ds, graph=graph):
-        yield el
-
-    wants_full_dataset = as_dataset_kwargs.get("batch_size") == -1
-    if wants_full_dataset:
-      # as_dataset returned Tensors, possibly tupleized with
-      # as_supervised=True
-      if tf.executing_eagerly():
-        return utils.map_nested(lambda t: t.numpy(), dataset, map_tuple=True)
-      else:
-        with utils.nogpu_session(graph=graph) as sess:
-          return sess.run(dataset)
-    else:
-      # as_dataset returned tf.data.Datasets
-      return utils.map_nested(ds_iter, dataset, dict_only=True)
-
   def _get_data_dir(self, version=None):
     """Return the data directory of one dataset version.
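
The removed `DatasetBuilder.as_numpy` has a direct replacement: build a
`tf.data.Dataset` with `as_dataset` and convert it with the new
`tfds.dataset_as_numpy`, as the test updates below illustrate. A minimal
migration sketch, assuming a prepared builder (the `"mnist"` name here is
purely illustrative):

```
import tensorflow_datasets as tfds

builder = tfds.builder("mnist")
builder.download_and_prepare()

# Before this patch: for ex in builder.as_numpy(split=tfds.Split.TRAIN): ...
# After this patch: build the tf.data.Dataset first, then convert it.
ds = builder.as_dataset(split=tfds.Split.TRAIN)
for ex in tfds.dataset_as_numpy(ds):
  numpy_images, numpy_labels = ex["image"], ex["label"]
```
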
diff --git a/tensorflow_datasets/core/dataset_builder_test.py b/tensorflow_datasets/core/dataset_builder_test.py
index 20e2dd1a272..4e1d97abf18 100644
--- a/tensorflow_datasets/core/dataset_builder_test.py
+++ b/tensorflow_datasets/core/dataset_builder_test.py
@@ -25,6 +25,7 @@
 import tensorflow as tf
 from tensorflow_datasets.core import dataset_builder
 from tensorflow_datasets.core import dataset_info
+from tensorflow_datasets.core import dataset_utils
 from tensorflow_datasets.core import features
 from tensorflow_datasets.core import registered
 from tensorflow_datasets.core import splits as splits_lib
@@ -104,7 +105,8 @@ def test_shared_generator(self):
         splits_lib.Split.TRAIN, splits_lib.Split.TEST
     ]
     train_data, test_data = [
-        [el["x"] for el in builder.as_numpy(split=split)]
+        [el["x"] for el in
+         dataset_utils.dataset_as_numpy(builder.as_dataset(split=split))]
         for split in splits_list
     ]
 
@@ -123,12 +125,12 @@ def test_shared_generator(self):
   @tf.contrib.eager.run_test_in_graph_and_eager_modes()
   def test_load(self):
     with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
-      dataset = registered.load_numpy(
+      dataset = registered.load(
           name="dummy_dataset_shared_generator",
           data_dir=tmp_dir,
           download=True,
           split=splits_lib.Split.TRAIN)
-      data = list(dataset)
+      data = list(dataset_utils.dataset_as_numpy(dataset))
 
       self.assertEqual(20, len(data))
       self.assertLess(data[0]["x"], 30)
@@ -197,7 +199,8 @@ def test_with_configs(self):
     splits_list = [splits_lib.Split.TRAIN, splits_lib.Split.TEST]
     for builder, incr in [(builder1, 1), (builder2, 2)]:
       train_data, test_data = [
-          [el["x"] for el in builder.as_numpy(split=split)]
+          [el["x"] for el in
+           dataset_utils.dataset_as_numpy(builder.as_dataset(split=split))]
          for split in splits_list
       ]
 
@@ -224,7 +227,8 @@ def setUp(self):
 
   @tf.contrib.eager.run_test_in_graph_and_eager_modes()
   def test_all_splits(self):
-    splits = self.builder.as_numpy(batch_size=-1)
+    splits = dataset_utils.dataset_as_numpy(
+        self.builder.as_dataset(batch_size=-1))
     self.assertSetEqual(set(splits.keys()),
                         set([splits_lib.Split.TRAIN, splits_lib.Split.TEST]))
 
@@ -240,8 +244,8 @@ def test_all_splits(self):
 
   @tf.contrib.eager.run_test_in_graph_and_eager_modes()
   def test_with_batch_size(self):
-    items = list(self.builder.as_numpy(
-        split=splits_lib.Split.TRAIN + splits_lib.Split.TEST, batch_size=10))
+    items = list(dataset_utils.dataset_as_numpy(self.builder.as_dataset(
+        split=splits_lib.Split.TRAIN + splits_lib.Split.TEST, batch_size=10)))
     # 3 batches of 10
     self.assertEqual(3, len(items))
     x1, x2, x3 = items[0]["x"], items[1]["x"], items[2]["x"]
@@ -250,21 +254,10 @@ def test_with_batch_size(self):
     self.assertEqual(10, x3.shape[0])
     self.assertEqual(sum(range(30)), int(x1.sum() + x2.sum() + x3.sum()))
 
-  @tf.contrib.eager.run_test_in_graph_and_eager_modes()
-  def test_as_numpy(self):
-    items = self.builder.as_numpy(split=splits_lib.Split.TRAIN, batch_size=-1)
-    self.assertEqual(items["x"].shape[0], 20)
-    self.assertLess(items["x"][0], 30)
-
-    count = 0
-    for _ in self.builder.as_numpy(split=splits_lib.Split.TRAIN):
-      count += 1
-    self.assertEqual(count, 20)
-
   @tf.contrib.eager.run_test_in_graph_and_eager_modes()
   def test_supervised_keys(self):
-    x, _ = self.builder.as_numpy(
-        split=splits_lib.Split.TRAIN, as_supervised=True, batch_size=-1)
+    x, _ = dataset_utils.dataset_as_numpy(self.builder.as_dataset(
+        split=splits_lib.Split.TRAIN, as_supervised=True, batch_size=-1))
     self.assertEqual(x.shape[0], 20)
 
 
diff --git a/tensorflow_datasets/core/dataset_info.py b/tensorflow_datasets/core/dataset_info.py
index 31ff924ba90..7df2345dccb 100644
--- a/tensorflow_datasets/core/dataset_info.py
+++ b/tensorflow_datasets/core/dataset_info.py
@@ -428,7 +428,7 @@ def get_dataset_feature_statistics(builder, split):
   feature_to_min = {}
   feature_to_max = {}
 
-  for example in dataset_utils.iterate_over_dataset(dataset):
+  for example in dataset_utils.dataset_as_numpy(dataset):
     statistics.num_examples += 1
 
     assert isinstance(example, dict)
diff --git a/tensorflow_datasets/core/dataset_utils.py b/tensorflow_datasets/core/dataset_utils.py
index 887b63b5fed..a3463f15b36 100644
--- a/tensorflow_datasets/core/dataset_utils.py
+++ b/tensorflow_datasets/core/dataset_utils.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import tensorflow as tf
 
+from tensorflow_datasets.core import api_utils
 from tensorflow_datasets.core import utils
 
 
@@ -87,19 +88,84 @@ def instruction_ds_to_file_ds(instruction):
   return dataset
 
 
-def iterate_over_dataset(dataset, graph=None):
-  """Yields numpy elements of `tf.data.Dataset`."""
+def _eager_dataset_iterator(dataset):
+  for item in dataset:
+    flat = tf.contrib.framework.nest.flatten(item)
+    flat = [el.numpy() for el in flat]
+    yield tf.contrib.framework.nest.pack_sequence_as(item, flat)
+
+
+def _graph_dataset_iterator(ds_item, graph=None):
+  with utils.nogpu_session(graph) as sess:
+    while True:
+      try:
+        yield sess.run(ds_item)
+      except tf.errors.OutOfRangeError:
+        break
+
+
+@api_utils.disallow_positional_args(allowed=["dataset"])
+def dataset_as_numpy(dataset, graph=None):
+  """Converts a `tf.data.Dataset` to an iterable of NumPy arrays.
+
+  `dataset_as_numpy` converts a possibly nested structure of `tf.data.Dataset`s
+  and `tf.Tensor`s to iterables of NumPy arrays and NumPy arrays, respectively.
+
+  Args:
+    dataset: a possibly nested structure of `tf.data.Dataset`s and/or
+      `tf.Tensor`s.
+    graph: `tf.Graph`, optional, explicitly set the graph to use.
+
+  Returns:
+    A structure matching `dataset` where `tf.data.Dataset`s are converted to
+    generators of NumPy arrays and `tf.Tensor`s are converted to NumPy arrays.
+  """
+  nested_ds = dataset
+  del dataset
+
+  # Flatten
+  flat_ds = tf.contrib.framework.nest.flatten(nested_ds)
+  flat_np = []
+
+  # Type check for Tensors and Datasets; only build the structure of types
+  # when reporting an error.
+  for ds_el in flat_ds:
+    if not isinstance(ds_el, (tf.Tensor, tf.data.Dataset)):
+      types = tf.contrib.framework.nest.pack_sequence_as(
+          nested_ds, [type(el) for el in flat_ds])
+      raise ValueError("Arguments to dataset_as_numpy must be tf.Tensors or "
+                       "tf.data.Datasets. Got: %s" % types)
Got: %s" % types) + if tf.executing_eagerly(): - for item in dataset: - flat = tf.contrib.framework.nest.flatten(item) - flat = [el.numpy() for el in flat] - yield tf.contrib.framework.nest.pack_sequence_as(item, flat) + # Eager mode + for ds_el in flat_ds: + if isinstance(ds_el, tf.Tensor): + np_el = ds_el.numpy() + elif isinstance(ds_el, tf.data.Dataset): + np_el = _eager_dataset_iterator(ds_el) + else: + assert False + flat_np.append(np_el) else: + # Graph mode + + # First create necessary graph ops + ds_iters = [None] * len(flat_ds) with utils.maybe_with_graph(graph, create_if_none=False): - item = dataset.make_one_shot_iterator().get_next() + for i, ds_el in enumerate(flat_ds): + if isinstance(ds_el, tf.data.Dataset): + ds_iters[i] = ds_el.make_one_shot_iterator().get_next() + + # Then create NumPy items + # Shared session for tf.Tensor runs with utils.nogpu_session(graph) as sess: - while True: - try: - yield sess.run(item) - except tf.errors.OutOfRangeError: - break + for ds_iter, ds_el in zip(ds_iters, flat_ds): + if ds_iter is None: + # Tensor + np_el = sess.run(ds_el) + else: + # Dataset + np_el = _graph_dataset_iterator(ds_iter, graph) + flat_np.append(np_el) + + # Nest + return tf.contrib.framework.nest.pack_sequence_as(nested_ds, flat_np) diff --git a/tensorflow_datasets/core/dataset_utils_test.py b/tensorflow_datasets/core/dataset_utils_test.py new file mode 100644 index 00000000000..a7e8ae318da --- /dev/null +++ b/tensorflow_datasets/core/dataset_utils_test.py @@ -0,0 +1,124 @@ +# coding=utf-8 +# Copyright 2018 The TensorFlow Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for tensorflow_datasets.core.dataset_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow_datasets.core import dataset_utils + + +def _create_dataset(rng): + return tf.data.Dataset.from_tensor_slices(list(rng)) + + +class DatasetAsNumPyTest(tf.test.TestCase): + + @tf.contrib.eager.run_test_in_graph_and_eager_modes() + def test_singleton_tensor(self): + t = tf.random_normal((10, 10)) + np_t = dataset_utils.dataset_as_numpy(t) + self.assertEqual((10, 10), np_t.shape) + self.assertEqual(np.float32, np_t.dtype) + + @tf.contrib.eager.run_test_in_graph_and_eager_modes() + def test_nested_tensors(self): + t1 = tf.random_normal((10, 10)) + t2 = tf.random_normal((10, 20)) + nest_tup = (t1, t2) + np_t1, np_t2 = dataset_utils.dataset_as_numpy(nest_tup) + self.assertEqual((10, 10), np_t1.shape) + self.assertEqual(np.float32, np_t1.dtype) + self.assertEqual((10, 20), np_t2.shape) + self.assertEqual(np.float32, np_t2.dtype) + + nest_dict = {"foo": t1, "bar": {"zoo": t2}} + np_nest_dict = dataset_utils.dataset_as_numpy(nest_dict) + np_t1 = np_nest_dict["foo"] + np_t2 = np_nest_dict["bar"]["zoo"] + self.assertEqual((10, 10), np_t1.shape) + self.assertEqual(np.float32, np_t1.dtype) + self.assertEqual((10, 20), np_t2.shape) + self.assertEqual(np.float32, np_t2.dtype) + + @tf.contrib.eager.run_test_in_graph_and_eager_modes() + def test_singleton_dataset(self): + ds = _create_dataset(range(10)) + np_ds = dataset_utils.dataset_as_numpy(ds) + self.assertEqual(list(range(10)), [int(el) for el in list(np_ds)]) + + def test_with_graph(self): + with tf.Graph().as_default() as g: + ds = _create_dataset(range(10)) + np_ds = dataset_utils.dataset_as_numpy(ds, graph=g) + self.assertEqual(list(range(10)), [int(el) for el in list(np_ds)]) + + @tf.contrib.eager.run_test_in_graph_and_eager_modes() + def test_singleton_dataset_with_nested_elements(self): + ds = _create_dataset(range(10)) + ds = ds.map(lambda el: {"a": el, "b": el + 1, "c": (el + 2, el + 3)}) + np_ds = dataset_utils.dataset_as_numpy(ds) + for i, el in enumerate(np_ds): + self.assertEqual(i, el["a"]) + self.assertEqual(i + 1, el["b"]) + self.assertEqual(i + 2, el["c"][0]) + self.assertEqual(i + 3, el["c"][1]) + + @tf.contrib.eager.run_test_in_graph_and_eager_modes() + def test_nested_dataset_sequential_access(self): + ds1 = _create_dataset(range(10)) + ds2 = _create_dataset(range(10, 20)) + np_ds = dataset_utils.dataset_as_numpy((ds1, {"a": ds2})) + np_ds1 = np_ds[0] + np_ds2 = np_ds[1]["a"] + + self.assertEqual(list(range(10)), [int(el) for el in list(np_ds1)]) + self.assertEqual(list(range(10, 20)), [int(el) for el in list(np_ds2)]) + + @tf.contrib.eager.run_test_in_graph_and_eager_modes() + def test_nested_dataset_simultaneous_access(self): + ds1 = _create_dataset(range(10)) + ds2 = _create_dataset(range(10, 20)) + np_ds = dataset_utils.dataset_as_numpy((ds1, {"a": ds2})) + np_ds1 = np_ds[0] + np_ds2 = np_ds[1]["a"] + + for i1, i2 in zip(np_ds1, np_ds2): + self.assertEqual(i2, int(i1) + 10) + + @tf.contrib.eager.run_test_in_graph_and_eager_modes() + def test_nested_dataset_nested_elements(self): + ds1 = _create_dataset(range(10)) + ds1 = ds1.map(lambda el: {"a": el, "b": el + 1, "c": (el + 2, el + 3)}) + ds2 = _create_dataset(range(10, 20)) + np_ds = dataset_utils.dataset_as_numpy((ds1, {"a": ds2})) + np_ds1 = np_ds[0] + np_ds2 = np_ds[1]["a"] + + for i, (el1, el2) in enumerate(zip(np_ds1, np_ds2)): + 
+      self.assertEqual(i + 10, el2)
+      self.assertEqual(i, el1["a"])
+      self.assertEqual(i + 1, el1["b"])
+      self.assertEqual(i + 2, el1["c"][0])
+      self.assertEqual(i + 3, el1["c"][1])
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow_datasets/core/registered.py b/tensorflow_datasets/core/registered.py
index 13f0091b2f5..e5477c43200 100644
--- a/tensorflow_datasets/core/registered.py
+++ b/tensorflow_datasets/core/registered.py
@@ -135,58 +135,6 @@ def builder(name, **ctor_kwargs):
     raise
 
 
-def _load(name,
-          split=None,
-          data_dir=None,
-          batch_size=1,
-          download=True,
-          as_numpy=False,
-          as_supervised=False,
-          with_info=False,
-          builder_kwargs=None,
-          download_and_prepare_kwargs=None,
-          as_dataset_kwargs=None):
-  """Shared implementation for `tfds.load` and `tfds.load_numpy`."""
-  if data_dir is None:
-    data_dir = constants.DATA_DIR
-  builder_kwargs = builder_kwargs or {}
-  dbuilder = builder(name, data_dir=data_dir, **builder_kwargs)
-  if download:
-    download_and_prepare_kwargs = download_and_prepare_kwargs or {}
-    dbuilder.download_and_prepare(**download_and_prepare_kwargs)
-
-  if as_dataset_kwargs is None:
-    as_dataset_kwargs = {}
-  as_dataset_kwargs = dict(as_dataset_kwargs)
-  as_dataset_kwargs["split"] = split
-  as_dataset_kwargs["as_supervised"] = as_supervised
-  as_dataset_kwargs["batch_size"] = batch_size
-
-  if as_numpy:
-    ds = dbuilder.as_numpy(**as_dataset_kwargs)
-  else:
-    ds = dbuilder.as_dataset(**as_dataset_kwargs)
-  if with_info:
-    return ds, dbuilder.info
-  return ds
-
-
-def load_numpy(**kwargs):
-  """`tfds.load` with NumPy generators/arrays instead of Datasets/Tensors.
-
-  Uses `tfds.core.DatasetBuilder.as_numpy` instead of `as_dataset`.
-
-  Args:
-    **kwargs: passed to `tfds.load`.
-
-  Returns:
-    Generator(s) of NumPy arrays (or just the NumPy arrays if `batch_size=-1`).
-    If `split=None` (default), returns a `dict` with all the data splits.
-  """
-  kwargs["as_numpy"] = True
-  return _load(**kwargs)
-
-
 @api_utils.disallow_positional_args(allowed=["name"])
 def load(name,
          split=None,
@@ -272,17 +220,25 @@ def load(name,
     will return a tuple (ds, ds_info) containing the dataset info (version,
     features, splits, num_examples,...).
""" - return _load( - name=name, - split=split, - data_dir=data_dir, - batch_size=batch_size, - download=download, - as_supervised=as_supervised, - with_info=with_info, - builder_kwargs=builder_kwargs, - download_and_prepare_kwargs=download_and_prepare_kwargs, - as_dataset_kwargs=as_dataset_kwargs) + if data_dir is None: + data_dir = constants.DATA_DIR + builder_kwargs = builder_kwargs or {} + dbuilder = builder(name, data_dir=data_dir, **builder_kwargs) + if download: + download_and_prepare_kwargs = download_and_prepare_kwargs or {} + dbuilder.download_and_prepare(**download_and_prepare_kwargs) + + if as_dataset_kwargs is None: + as_dataset_kwargs = {} + as_dataset_kwargs = dict(as_dataset_kwargs) + as_dataset_kwargs["split"] = split + as_dataset_kwargs["as_supervised"] = as_supervised + as_dataset_kwargs["batch_size"] = batch_size + + ds = dbuilder.as_dataset(**as_dataset_kwargs) + if with_info: + return ds, dbuilder.info + return ds def _dataset_name_and_kwargs_from_name_str(name_str): diff --git a/tensorflow_datasets/core/splits_test.py b/tensorflow_datasets/core/splits_test.py index cd2ad06ff4b..c32106baf35 100644 --- a/tensorflow_datasets/core/splits_test.py +++ b/tensorflow_datasets/core/splits_test.py @@ -69,7 +69,8 @@ def _generate_examples(self, data): } def values(self, split): - return [int(v["value"]) for v in self.as_numpy(split=split)] + return [int(v["value"]) for v in + tfds.dataset_as_numpy(self.as_dataset(split=split))] class SplitsUnitTest(tf.test.TestCase): diff --git a/tensorflow_datasets/core/test_utils.py b/tensorflow_datasets/core/test_utils.py index 78734a8af96..0a2082b0801 100644 --- a/tensorflow_datasets/core/test_utils.py +++ b/tensorflow_datasets/core/test_utils.py @@ -223,7 +223,7 @@ def features_encode_decode(features_dict, example, as_tensor=False): dataset = dataset.map(features_dict.decode_example) if not as_tensor: # Evaluate to numpy array - for el in dataset_utils.iterate_over_dataset(dataset): + for el in dataset_utils.dataset_as_numpy(dataset): return el else: if tf.executing_eagerly(): diff --git a/tensorflow_datasets/public_api.py b/tensorflow_datasets/public_api.py index f1c0782a03a..0a9254abe49 100644 --- a/tensorflow_datasets/public_api.py +++ b/tensorflow_datasets/public_api.py @@ -22,6 +22,7 @@ from tensorflow_datasets.core import features from tensorflow_datasets.core import file_format_adapter as file_adapter from tensorflow_datasets.core import units +from tensorflow_datasets.core.dataset_utils import dataset_as_numpy from tensorflow_datasets.core.download import GenerateMode from tensorflow_datasets.core.registered import builder from tensorflow_datasets.core.registered import list_builders @@ -31,6 +32,7 @@ __all__ = [ "core", + "dataset_as_numpy", "download", "features", "file_adapter", diff --git a/tensorflow_datasets/testing/dataset_builder_testing.py b/tensorflow_datasets/testing/dataset_builder_testing.py index 95208f8e7f4..75289dd8068 100644 --- a/tensorflow_datasets/testing/dataset_builder_testing.py +++ b/tensorflow_datasets/testing/dataset_builder_testing.py @@ -29,6 +29,7 @@ from tensorflow_datasets.core import dataset_builder from tensorflow_datasets.core import dataset_info +from tensorflow_datasets.core import dataset_utils from tensorflow_datasets.core import registered from tensorflow_datasets.core import test_utils from tensorflow_datasets.core.utils import tf_utils @@ -226,7 +227,8 @@ def _assertAsDataset(self, builder): dataset = builder.as_dataset(split=split_name) 
       compare_shapes_and_types(builder.info.features.get_tensor_info(),
                                dataset.output_types,
                                dataset.output_shapes)
-      examples = list(builder.as_numpy(split=split_name))
+      examples = list(dataset_utils.dataset_as_numpy(
+          builder.as_dataset(split=split_name)))
       split_to_checksums[split_name] = set(checksum(rec) for rec in examples)
       self.assertLen(examples, expected_examples_number)
     for (split1, hashes1), (split2, hashes2) in itertools.combinations(
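
Taken together, the new `tfds.dataset_as_numpy` accepts arbitrary nestings of
`tf.Tensor`s and `tf.data.Dataset`s and returns a matching structure of NumPy
arrays and NumPy-yielding generators, in both graph and eager modes. A short
usage sketch distilled from the tests above (values are illustrative only):

```
import tensorflow as tf
import tensorflow_datasets as tfds

t = tf.random_normal((2, 3))  # a plain Tensor
ds = tf.data.Dataset.from_tensor_slices(list(range(5)))

# Tensors come back as NumPy arrays; Datasets come back as generators of
# NumPy values, packed into the same (tuple) structure as the input.
np_t, np_ds = tfds.dataset_as_numpy((t, ds))
print(np_t.shape)               # (2, 3)
print([int(x) for x in np_ds])  # [0, 1, 2, 3, 4]
```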