From 30033778289a792ca124640f235120b2c755e6da Mon Sep 17 00:00:00 2001
From: Ryan Sepassi
Date: Fri, 9 Nov 2018 19:26:35 -0800
Subject: [PATCH] Use semantic feature names and add
 DatasetInfo.supervised_keys

PiperOrigin-RevId: 220899366
---
 README.md                                   |  4 +--
 docs/_index.ipynb                           |  2 +-
 docs/_index.yaml                            |  2 +-
 docs/overview.ipynb                         |  8 ++---
 tensorflow_datasets/core/dataset_builder.py | 23 ++++++++++--
 .../core/dataset_builder_test.py            | 36 +++++++++++++++----
 tensorflow_datasets/core/dataset_info.py    | 13 +++++--
 tensorflow_datasets/core/registered.py      |  7 ++++
 tensorflow_datasets/core/registered_test.py |  4 +--
 tensorflow_datasets/image/cifar.py          | 18 +++++-----
 tensorflow_datasets/image/cifar_test.py     |  4 +--
 tensorflow_datasets/image/mnist.py          |  9 ++---
 12 files changed, 93 insertions(+), 37 deletions(-)

diff --git a/README.md b/README.md
index a244a1b94d3..7c44fa86ea1 100644
--- a/README.md
+++ b/README.md
@@ -33,7 +33,7 @@ dataset = tfds.load(name="mnist", split=tfds.Split.TRAIN)
 # Build your input pipeline
 dataset = dataset.shuffle(1000).batch(128).prefetch(tf.contrib.data.AUTOTUNE)
 features = dataset.make_one_shot_iterator().get_next()
-image, label = features["input"], features["target"]
+image, label = features["image"], features["label"]
 ```
 
 ### `DatasetBuilder`
@@ -65,7 +65,7 @@ import tensorflow_datasets as tfds
 mnist_builder = tfds.builder("mnist")
 mnist_builder.download_and_prepare()
 for element in mnist_builder.numpy_iterator(split=tfds.Split.TRAIN):
-  numpy_image, numpy_label = element["input"], element["target"]
+  numpy_image, numpy_label = element["image"], element["label"]
 ```
 
 Note that the library still requires `tensorflow` as an internal dependency.
diff --git a/docs/_index.ipynb b/docs/_index.ipynb
index 9b46b77b825..2d9dcea1b4e 100644
--- a/docs/_index.ipynb
+++ b/docs/_index.ipynb
@@ -48,7 +48,7 @@
     "# Build your input pipeline\n",
     "dataset = dataset.shuffle(1024).batch(32).prefetch(tf.contrib.data.AUTOTUNE)\n",
     "for features in dataset.take(1):\n",
-    "  image, label = features[\"input\"], features[\"target\"]"
+    "  image, label = features[\"image\"], features[\"label\"]"
    ]
   }
  ],
diff --git a/docs/_index.yaml b/docs/_index.yaml
index 2f79a97bd8a..53596f5774f 100644
--- a/docs/_index.yaml
+++ b/docs/_index.yaml
@@ -36,7 +36,7 @@ landing_page:
             # Build your input pipeline
             dataset = dataset.shuffle(1024).batch(32).prefetch(tf.contrib.data.AUTOTUNE)
             for features in dataset.take(1):
-              image, label = features["input"], features["target"]
+              image, label = features["image"], features["label"]
 
         {% dynamic if request.tld != 'cn' %}
diff --git a/docs/overview.ipynb b/docs/overview.ipynb
index 8255d6c90cc..0740fc567fc 100644
--- a/docs/overview.ipynb
+++ b/docs/overview.ipynb
@@ -259,7 +259,7 @@
      "output_type": "execute_result",
      "data": {
        "text/plain": [
-        ", target: ()}, types: {input: tf.uint8, target: tf.int64}>"
+        ", label: ()}, types: {image: tf.uint8, label: tf.int64}>"
        ]
      },
      "metadata": {
@@ -278,7 +278,7 @@
    "source": [
     "## Feature dictionaries\n",
     "\n",
-    "All `tfds` datasets contain feature dictionaries mapping feature names to Tensor values. A typical dataset, like MNIST, will have 2 keys: `\"input\"` and `\"target\"`. Below we inspect a single record."
+    "All `tfds` datasets contain feature dictionaries mapping feature names to Tensor values. A typical dataset, like MNIST, will have 2 keys: `\"image\"` and `\"label\"`. Below we inspect a single record."
] }, { @@ -294,7 +294,7 @@ "cell_type": "code", "source": [ "mnist_example, = mnist_train.take(1)\n", - "image, label = mnist_example[\"input\"], mnist_example[\"target\"]\n", + "image, label = mnist_example[\"image\"], mnist_example[\"label\"]\n", "\n", "plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap('gray'))\n", "print(\"Label: %d\" % label.numpy())" @@ -364,7 +364,7 @@ "output_type": "execute_result", "data": { "text/plain": [ - ", target: ()}, types: {input: tf.uint8, target: tf.int64}>" + ", label: ()}, types: {image: tf.uint8, label: tf.int64}>" ] }, "metadata": { diff --git a/tensorflow_datasets/core/dataset_builder.py b/tensorflow_datasets/core/dataset_builder.py index 285c0383310..f00edf00f24 100644 --- a/tensorflow_datasets/core/dataset_builder.py +++ b/tensorflow_datasets/core/dataset_builder.py @@ -64,7 +64,7 @@ class DatasetBuilder(object): # Use tf.contrib.data.AUTOTUNE to automatically optimize the input pipeline train_dataset = train_dataset.prefetch(tf.contrib.data.AUTOTUNE) features = train_dataset.make_one_shot_iterator().get_next() - image, label = features['input'], features['target'] + image, label = features['image'], features['label'] ``` """ @@ -154,7 +154,10 @@ def download_and_prepare(self, cache_dir=None, dl_manager=None): # TODO(rsepassi): Make it easy to further shard the TRAIN data (e.g. for # synthetic VALIDATION splits). @api_utils.disallow_positional_args - def as_dataset(self, split, shuffle_files=None): + def as_dataset(self, + split, + shuffle_files=None, + as_supervised=False): """Constructs a `tf.data.Dataset`. Callers must pass arguments as keyword arguments. @@ -165,6 +168,11 @@ def as_dataset(self, split, shuffle_files=None): split: `tfds.Split`, which subset of the data to read. shuffle_files: `bool` (optional), whether to shuffle the input files. Defaults to `True` if `split == tfds.Split.TRAIN` and `False` otherwise. + as_supervised: `bool`, if `True`, the returned `tf.data.Dataset` + will have a 2-tuple structure `(input, label)` according to + `builder.info.supervised_keys`. If `False`, the default, + the returned `tf.data.Dataset` will have a dictionary with all the + features. Returns: `tf.data.Dataset` @@ -175,7 +183,16 @@ def as_dataset(self, split, shuffle_files=None): "dataset_builder.download_and_prepare(), or pass download=True to " "tfds.load() before trying to access the tf.data.Dataset object." ) % (self.name, self._data_dir_root)) - return self._as_dataset(split=split, shuffle_files=shuffle_files) + dataset = self._as_dataset(split=split, shuffle_files=shuffle_files) + if as_supervised: + if not self.info.supervised_keys: + raise ValueError( + "as_supervised=True but %s does not support a supervised " + "(input, label) structure." % self.name) + input_f, target_f = self.info.supervised_keys + dataset = dataset.map(lambda fs: (fs[input_f], fs[target_f])) + dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE) + return dataset def numpy_iterator(self, **as_dataset_kwargs): """Generates numpy elements from the given `tfds.Split`. 
diff --git a/tensorflow_datasets/core/dataset_builder_test.py b/tensorflow_datasets/core/dataset_builder_test.py index d738157cd32..7b352488402 100644 --- a/tensorflow_datasets/core/dataset_builder_test.py +++ b/tensorflow_datasets/core/dataset_builder_test.py @@ -46,6 +46,7 @@ def _split_generators(self, dl_manager): def _info(self): return dataset_info.DatasetInfo( specs=features.SpecDict({"x": tf.int64}), + supervised_keys=("x", "x"), ) def _generate_samples(self): @@ -92,15 +93,36 @@ def test_load(self): split=splits.Split.TRAIN) data = list(dataset) self.assertEqual(20, len(data)) + self.assertLess(data[0]["x"], 30) + + +class DatasetBuilderReadTest(tf.test.TestCase): + + @classmethod + def setUpClass(cls): + cls._tfds_tmp_dir = test_utils.make_tmp_dir() + builder = DummyDatasetSharedGenerator(data_dir=cls._tfds_tmp_dir) + builder.download_and_prepare() + + @classmethod + def tearDownClass(cls): + test_utils.rm_tmp_dir(cls._tfds_tmp_dir) def test_numpy_iterator(self): - with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir: - builder = DummyDatasetSharedGenerator(data_dir=tmp_dir) - builder.download_and_prepare() - items = [] - for item in builder.numpy_iterator(split=splits.Split.TRAIN): - items.append(item) - self.assertEqual(20, len(items)) + builder = DummyDatasetSharedGenerator(data_dir=self._tfds_tmp_dir) + items = [] + for item in builder.numpy_iterator(split=splits.Split.TRAIN): + items.append(item) + self.assertEqual(20, len(items)) + self.assertLess(items[0]["x"], 30) + + def test_supervised_keys(self): + builder = DummyDatasetSharedGenerator(data_dir=self._tfds_tmp_dir) + for item in builder.numpy_iterator( + split=splits.Split.TRAIN, as_supervised=True): + self.assertIsInstance(item, tuple) + self.assertEqual(len(item), 2) + break if __name__ == "__main__": diff --git a/tensorflow_datasets/core/dataset_info.py b/tensorflow_datasets/core/dataset_info.py index f10da5b5618..1ad716f0381 100644 --- a/tensorflow_datasets/core/dataset_info.py +++ b/tensorflow_datasets/core/dataset_info.py @@ -66,16 +66,21 @@ class DatasetInfo(object): """ @api_utils.disallow_positional_args - def __init__(self, specs): + def __init__(self, specs, supervised_keys=None): """Constructor of the DatasetInfo. Args: specs: (`tfds.features.SpecDict`) Information on the feature dict of the `tf.data.Dataset()` object from the `builder.as_dataset()` method. - + supervised_keys: (`tuple`) Specifies the input feature and the + label for supervised learning, if applicable for the dataset. """ self._specs = specs self._splits = splits.SplitDict() + self._supervised_keys = supervised_keys + if supervised_keys is not None: + assert isinstance(supervised_keys, tuple) + assert len(supervised_keys) == 2 # TODO(pierrot): Move SIZE here # TODO(afrozm): Should add other metadata here (num samples, hash,...) @@ -83,6 +88,10 @@ def __init__(self, specs): def specs(self): return self._specs + @property + def supervised_keys(self): + return self._supervised_keys + @property def splits(self): return self._splits diff --git a/tensorflow_datasets/core/registered.py b/tensorflow_datasets/core/registered.py index 55bccf70444..3f638d7ebff 100644 --- a/tensorflow_datasets/core/registered.py +++ b/tensorflow_datasets/core/registered.py @@ -108,6 +108,7 @@ def load(name, split, data_dir=None, download=True, + as_supervised=False, as_dataset_kwargs=None): """Loads the given `tfds.Split` as a `tf.data.Dataset`. @@ -135,6 +136,11 @@ def load(name, expected to be in `data_dir`. 
        If `True` and the data is already in `data_dir`,
        `download_and_prepare` is a no-op.
        Defaults to `True`.
+    as_supervised: `bool`, if `True`, the returned `tf.data.Dataset`
+      will have a 2-tuple structure `(input, label)` according to
+      `builder.info.supervised_keys`. If `False`, the default,
+      the returned `tf.data.Dataset` will have a dictionary with all the
+      features.
     as_dataset_kwargs: `dict` (optional), keyword arguments passed to
       `tfds.DatasetBuilder.as_dataset`. `split` will be passed through by
       default.
@@ -152,6 +158,7 @@ def load(name,
     as_dataset_kwargs = {}
   as_dataset_kwargs = dict(as_dataset_kwargs)
   as_dataset_kwargs["split"] = split
+  as_dataset_kwargs["as_supervised"] = as_supervised
   return dbuilder.as_dataset(**as_dataset_kwargs)
 
diff --git a/tensorflow_datasets/core/registered_test.py b/tensorflow_datasets/core/registered_test.py
index 50ae1e80c54..a0e7c9d2f8d 100644
--- a/tensorflow_datasets/core/registered_test.py
+++ b/tensorflow_datasets/core/registered_test.py
@@ -89,11 +89,9 @@ def test_load(self):
                        download=False, as_dataset_kwargs=as_dataset_kwargs)
     self.assertTrue(builder.as_dataset_called)
     self.assertFalse(builder.download_called)
-    print(as_dataset_kwargs)
-    print(builder.as_dataset_kwargs)
     self.assertEqual(splits.Split.TEST, builder.as_dataset_kwargs.pop("split"))
-    print(builder.as_dataset_kwargs)
+    self.assertFalse(builder.as_dataset_kwargs.pop("as_supervised"))
     self.assertEqual(builder.as_dataset_kwargs, as_dataset_kwargs)
     self.assertEqual(dict(data_dir=data_dir, k1=1), builder.kwargs)
 
diff --git a/tensorflow_datasets/image/cifar.py b/tensorflow_datasets/image/cifar.py
index 742da742c64..6544b87c981 100644
--- a/tensorflow_datasets/image/cifar.py
+++ b/tensorflow_datasets/image/cifar.py
@@ -61,9 +61,10 @@ def _info(self):
     cifar_shape = (_CIFAR_IMAGE_SIZE, _CIFAR_IMAGE_SIZE, 3)
     return dataset_info.DatasetInfo(
         specs=features.SpecDict({
-            "input": features.Image(shape=cifar_shape),
-            "target": tf.int64,  # Could replace by features.Label()
+            "image": features.Image(shape=cifar_shape),
+            "label": tf.int64,  # Could replace by features.Label()
         }),
+        supervised_keys=("image", "label"),
     )
 
   @property
@@ -139,8 +140,8 @@ def _generate_samples(self, filepaths):
       if len(label) == 1:
         label = label[self._cifar_info.label_keys[0]]
       yield self.info.specs.encode_sample({
-          "input": image,
-          "target": label,
+          "image": image,
+          "label": label,
       })
 
 
@@ -152,7 +153,7 @@ def __init__(self, use_coarse_labels=False, **kwargs):
 
     Args:
       use_coarse_labels (bool): whether to set the coarse labels or the fine
-        labels as "target". Note that in either case, both features will be
+        labels as "label". Note that in either case, both features will be
         present in the features dictionary as "fine_labels" and "coarse_labels".
         Note also that this does NOT affect the data on disk and is only used
         in the `tf.data.Dataset` input pipeline.
@@ -177,12 +178,13 @@ def _info(self):
     label_to_use = "coarse_labels" if self._use_coarse_labels else "fine_labels"
     return dataset_info.DatasetInfo(
         specs=features.SpecDict({
-            "input": features.Image(shape=cifar_shape),
-            "target": features.OneOf(choice=label_to_use, feature_dict={
+            "image": features.Image(shape=cifar_shape),
+            "label": features.OneOf(choice=label_to_use, feature_dict={
                 "coarse_labels": tf.int64,
                 "fine_labels": tf.int64,
             }),
         }),
+        supervised_keys=("image", "label"),
     )
 
 
@@ -205,5 +207,5 @@ class CifarInfo(collections.namedtuple("_CifarInfo", [
     test_files (list): name of test files within `prefix`.
     label_keys (list): names of the label keys in the data.
If longer than 1, provide `out_label_keys` to specify output names in feature - dictionaries. Otherwise will use "target". + dictionaries. Otherwise will use "label". """ diff --git a/tensorflow_datasets/image/cifar_test.py b/tensorflow_datasets/image/cifar_test.py index 0dd435506cd..b4682c1054d 100644 --- a/tensorflow_datasets/image/cifar_test.py +++ b/tensorflow_datasets/image/cifar_test.py @@ -31,8 +31,8 @@ class Cifar10Test(dataset_builder_testing.TestCase): "test": 2, # See testing/generate_cifar10_like_sample.py } SPEC = { - "target": (tf.int64, ()), - "input": (tf.uint8, (32, 32, 3)), + "label": (tf.int64, ()), + "image": (tf.uint8, (32, 32, 3)), } diff --git a/tensorflow_datasets/image/mnist.py b/tensorflow_datasets/image/mnist.py index 742cfc26235..0db7b52e4c8 100644 --- a/tensorflow_datasets/image/mnist.py +++ b/tensorflow_datasets/image/mnist.py @@ -53,9 +53,10 @@ def _info(self): mnist_shape = (_MNIST_IMAGE_SIZE, _MNIST_IMAGE_SIZE, 1) return dataset_info.DatasetInfo( specs=features.SpecDict({ - "input": features.Image(shape=mnist_shape), - "target": tf.int64, + "image": features.Image(shape=mnist_shape), + "label": tf.int64, }), + supervised_keys=("image", "label"), ) def _split_generators(self, dl_manager): @@ -111,8 +112,8 @@ def _generate_samples(self, num_examples, data_path, label_path): for image, label in data: yield self.info.specs.encode_sample({ - "input": image, - "target": label, + "image": image, + "label": label, })
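
A minimal end-to-end sketch of the new flag, assuming `mnist` has already been downloaded and prepared. With `as_supervised=True`, elements come back as `(input, label)` tuples ordered by `builder.info.supervised_keys`, which this patch sets to `("image", "label")` for MNIST and CIFAR; the supervised path inside `as_dataset` already prefetches with `tf.contrib.data.AUTOTUNE`, so no extra `prefetch` is added here.

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Default behavior: elements are feature dictionaries with semantic keys.
dataset = tfds.load(name="mnist", split=tfds.Split.TRAIN)

# New behavior: elements are (input, label) 2-tuples, ready for APIs that
# expect (features, labels) pairs. as_dataset raises ValueError if the
# dataset never declared supervised_keys.
dataset = tfds.load(name="mnist", split=tfds.Split.TRAIN, as_supervised=True)
dataset = dataset.shuffle(1024).batch(32)
image, label = dataset.make_one_shot_iterator().get_next()
```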
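For dataset owners, a sketch of the `DatasetInfo` a builder's `_info()` returns to opt in; it mirrors the mnist.py hunk above, with the 28x28x1 MNIST image shape filled in for illustration only.

```python
import tensorflow as tf
from tensorflow_datasets.core import dataset_info
from tensorflow_datasets.core import features

# Returned from a builder's _info(). The two names must be keys of `specs`;
# the DatasetInfo constructor asserts it receives a tuple of length 2.
info = dataset_info.DatasetInfo(
    specs=features.SpecDict({
        "image": features.Image(shape=(28, 28, 1)),
        "label": tf.int64,
    }),
    supervised_keys=("image", "label"),
)
```

Declaring the pairing once in `DatasetInfo`, rather than at each call site, lets `builder.as_dataset`, `numpy_iterator`, and `tfds.load` all share the same `(input, label)` contract, while datasets that omit `supervised_keys` reject `as_supervised=True` with a `ValueError` at read time.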