
Commit

Use semantic feature names and add DatasetInfo.supervised_keys
PiperOrigin-RevId: 220899366
Ryan Sepassi authored and Copybara-Service committed Nov 10, 2018
1 parent 4bc312d commit 3003377
Showing 12 changed files with 93 additions and 37 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -33,7 +33,7 @@ dataset = tfds.load(name="mnist", split=tfds.Split.TRAIN)
# Build your input pipeline
dataset = dataset.shuffle(1000).batch(128).prefetch(tf.contrib.data.AUTOTUNE)
features = dataset.make_one_shot_iterator().get_next()
-image, label = features["input"], features["target"]
+image, label = features["image"], features["label"]
```

### `DatasetBuilder`
@@ -65,7 +65,7 @@ import tensorflow_datasets as tfds
mnist_builder = tfds.builder("mnist")
mnist_builder.download_and_prepare()
for element in mnist_builder.numpy_iterator(split=tfds.Split.TRAIN):
-numpy_image, numpy_label = element["input"], element["target"]
+numpy_image, numpy_label = element["image"], element["label"]
```

Note that the library still requires `tensorflow` as an internal dependency.
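A minimal sketch of the tuple-based pipeline that the `as_supervised` flag added later in this commit enables (illustrative only, not part of this diff; assumes the same TF 1.x setup as the README snippet above):

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# With as_supervised=True the dataset yields (input, label) tuples instead of a
# feature dictionary, using the keys declared in DatasetInfo.supervised_keys.
dataset = tfds.load(name="mnist", split=tfds.Split.TRAIN, as_supervised=True)
dataset = dataset.shuffle(1000).batch(128).prefetch(tf.contrib.data.AUTOTUNE)
image, label = dataset.make_one_shot_iterator().get_next()
```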
2 changes: 1 addition & 1 deletion docs/_index.ipynb
@@ -48,7 +48,7 @@
"# Build your input pipeline\n",
"dataset = dataset.shuffle(1024).batch(32).prefetch(tf.contrib.data.AUTOTUNE)\n",
"for features in dataset.take(1):\n",
" image, label = features[\"input\"], features[\"target\"]"
" image, label = features[\"image\"], features[\"label\"]"
]
}
],
2 changes: 1 addition & 1 deletion docs/_index.yaml
@@ -36,7 +36,7 @@ landing_page:
# Build your input pipeline
dataset = dataset.shuffle(1024).batch(32).prefetch(tf.contrib.data.AUTOTUNE)
for features in dataset.take(1):
image, label = features["input"], features["target"]
image, label = features["image"], features["label"]
</pre>
{% dynamic if request.tld != 'cn' %}
8 changes: 4 additions & 4 deletions docs/overview.ipynb
@@ -259,7 +259,7 @@
"output_type": "execute_result",
"data": {
"text/plain": [
"<MapDataset shapes: {input: <unknown>, target: ()}, types: {input: tf.uint8, target: tf.int64}>"
"<MapDataset shapes: {image: <unknown>, label: ()}, types: {image: tf.uint8, label: tf.int64}>"
]
},
"metadata": {
@@ -278,7 +278,7 @@
"source": [
"## Feature dictionaries\n",
"\n",
"All `tfds` datasets contain feature dictionaries mapping feature names to Tensor values. A typical dataset, like MNIST, will have 2 keys: `\"input\"` and `\"target\"`. Below we inspect a single record."
"All `tfds` datasets contain feature dictionaries mapping feature names to Tensor values. A typical dataset, like MNIST, will have 2 keys: `\"image\"` and `\"label\"`. Below we inspect a single record."
]
},
{
@@ -294,7 +294,7 @@
"cell_type": "code",
"source": [
"mnist_example, = mnist_train.take(1)\n",
"image, label = mnist_example[\"input\"], mnist_example[\"target\"]\n",
"image, label = mnist_example[\"image\"], mnist_example[\"label\"]\n",
"\n",
"plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap('gray'))\n",
"print(\"Label: %d\" % label.numpy())"
@@ -364,7 +364,7 @@
"output_type": "execute_result",
"data": {
"text/plain": [
"<MapDataset shapes: {input: <unknown>, target: ()}, types: {input: tf.uint8, target: tf.int64}>"
"<MapDataset shapes: {image: <unknown>, label: ()}, types: {image: tf.uint8, label: tf.int64}>"
]
},
"metadata": {
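An illustrative follow-on to the renamed keys (not part of the notebook diff): a typical preprocessing `map` over the feature dictionary, assuming `mnist_train` as built earlier in the notebook:

```python
import tensorflow as tf

def normalize(features):
  # Cast the uint8 image to float32 in [0, 1]; keep the renamed keys.
  image = tf.cast(features["image"], tf.float32) / 255.0
  return {"image": image, "label": features["label"]}

mnist_train = mnist_train.map(normalize)
```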
23 changes: 20 additions & 3 deletions tensorflow_datasets/core/dataset_builder.py
@@ -64,7 +64,7 @@ class DatasetBuilder(object):
# Use tf.contrib.data.AUTOTUNE to automatically optimize the input pipeline
train_dataset = train_dataset.prefetch(tf.contrib.data.AUTOTUNE)
features = train_dataset.make_one_shot_iterator().get_next()
-image, label = features['input'], features['target']
+image, label = features['image'], features['label']
```
"""

@@ -154,7 +154,10 @@ def download_and_prepare(self, cache_dir=None, dl_manager=None):
# TODO(rsepassi): Make it easy to further shard the TRAIN data (e.g. for
# synthetic VALIDATION splits).
@api_utils.disallow_positional_args
-def as_dataset(self, split, shuffle_files=None):
+def as_dataset(self,
+               split,
+               shuffle_files=None,
+               as_supervised=False):
"""Constructs a `tf.data.Dataset`.
Callers must pass arguments as keyword arguments.
@@ -165,6 +168,11 @@ def as_dataset(self, split, shuffle_files=None):
split: `tfds.Split`, which subset of the data to read.
shuffle_files: `bool` (optional), whether to shuffle the input files.
Defaults to `True` if `split == tfds.Split.TRAIN` and `False` otherwise.
+as_supervised: `bool`, if `True`, the returned `tf.data.Dataset`
+  will have a 2-tuple structure `(input, label)` according to
+  `builder.info.supervised_keys`. If `False`, the default,
+  the returned `tf.data.Dataset` will have a dictionary with all the
+  features.
Returns:
`tf.data.Dataset`
@@ -175,7 +183,16 @@ def as_dataset(self, split, shuffle_files=None):
"dataset_builder.download_and_prepare(), or pass download=True to "
"tfds.load() before trying to access the tf.data.Dataset object."
) % (self.name, self._data_dir_root))
-return self._as_dataset(split=split, shuffle_files=shuffle_files)
+dataset = self._as_dataset(split=split, shuffle_files=shuffle_files)
+if as_supervised:
+  if not self.info.supervised_keys:
+    raise ValueError(
+        "as_supervised=True but %s does not support a supervised "
+        "(input, label) structure." % self.name)
+  input_f, target_f = self.info.supervised_keys
+  dataset = dataset.map(lambda fs: (fs[input_f], fs[target_f]))
+dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
+return dataset

def numpy_iterator(self, **as_dataset_kwargs):
"""Generates numpy elements from the given `tfds.Split`.
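To make the new `as_supervised` branch concrete, a standalone sketch of the dictionary-to-tuple mapping it performs (toy tensors, not part of the diff):

```python
import tensorflow as tf

# Toy feature dictionaries standing in for what _as_dataset() yields.
features = {"image": tf.zeros([4, 28, 28, 1], dtype=tf.uint8),
            "label": tf.constant([7, 2, 1, 0], dtype=tf.int64)}
ds = tf.data.Dataset.from_tensor_slices(features)    # elements are dicts

supervised_keys = ("image", "label")                  # as declared in DatasetInfo
input_f, target_f = supervised_keys
ds = ds.map(lambda fs: (fs[input_f], fs[target_f]))   # elements are now 2-tuples
print(ds.output_types)                                # (tf.uint8, tf.int64)
```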
36 changes: 29 additions & 7 deletions tensorflow_datasets/core/dataset_builder_test.py
@@ -46,6 +46,7 @@ def _split_generators(self, dl_manager):
def _info(self):
return dataset_info.DatasetInfo(
specs=features.SpecDict({"x": tf.int64}),
supervised_keys=("x", "x"),
)

def _generate_samples(self):
@@ -92,15 +93,36 @@ def test_load(self):
split=splits.Split.TRAIN)
data = list(dataset)
self.assertEqual(20, len(data))
+self.assertLess(data[0]["x"], 30)


class DatasetBuilderReadTest(tf.test.TestCase):

+@classmethod
+def setUpClass(cls):
+  cls._tfds_tmp_dir = test_utils.make_tmp_dir()
+  builder = DummyDatasetSharedGenerator(data_dir=cls._tfds_tmp_dir)
+  builder.download_and_prepare()
+
+@classmethod
+def tearDownClass(cls):
+  test_utils.rm_tmp_dir(cls._tfds_tmp_dir)

def test_numpy_iterator(self):
-with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
-  builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
-  builder.download_and_prepare()
-  items = []
-  for item in builder.numpy_iterator(split=splits.Split.TRAIN):
-    items.append(item)
-  self.assertEqual(20, len(items))
+builder = DummyDatasetSharedGenerator(data_dir=self._tfds_tmp_dir)
+items = []
+for item in builder.numpy_iterator(split=splits.Split.TRAIN):
+  items.append(item)
+self.assertEqual(20, len(items))
+self.assertLess(items[0]["x"], 30)

+def test_supervised_keys(self):
+  builder = DummyDatasetSharedGenerator(data_dir=self._tfds_tmp_dir)
+  for item in builder.numpy_iterator(
+      split=splits.Split.TRAIN, as_supervised=True):
+    self.assertIsInstance(item, tuple)
+    self.assertEqual(len(item), 2)
+    break


if __name__ == "__main__":
13 changes: 11 additions & 2 deletions tensorflow_datasets/core/dataset_info.py
@@ -66,23 +66,32 @@ class DatasetInfo(object):
"""

@api_utils.disallow_positional_args
-def __init__(self, specs):
+def __init__(self, specs, supervised_keys=None):
"""Constructor of the DatasetInfo.
Args:
specs: (`tfds.features.SpecDict`) Information on the feature dict of the
`tf.data.Dataset()` object from the `builder.as_dataset()` method.
+supervised_keys: (`tuple`) Specifies the input feature and the
+  label for supervised learning, if applicable for the dataset.
"""
self._specs = specs
self._splits = splits.SplitDict()
+self._supervised_keys = supervised_keys
+if supervised_keys is not None:
+  assert isinstance(supervised_keys, tuple)
+  assert len(supervised_keys) == 2
# TODO(pierrot): Move SIZE here
# TODO(afrozm): Should add other metadata here (num samples, hash,...)

@property
def specs(self):
return self._specs

+@property
+def supervised_keys(self):
+  return self._supervised_keys

@property
def splits(self):
return self._splits
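A sketch of how a builder's `_info()` would declare this, mirroring the MNIST/CIFAR hunks later in this commit (the feature names and shape are illustrative):

```python
# Hypothetical _info() on a DatasetBuilder subclass.
def _info(self):
  return dataset_info.DatasetInfo(
      specs=features.SpecDict({
          "image": features.Image(shape=(28, 28, 1)),
          "label": tf.int64,
      }),
      # Must be a 2-tuple of (input_feature, target_feature) names;
      # the constructor asserts tuple-ness and length 2.
      supervised_keys=("image", "label"),
  )
```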
7 changes: 7 additions & 0 deletions tensorflow_datasets/core/registered.py
@@ -108,6 +108,7 @@ def load(name,
split,
data_dir=None,
download=True,
+as_supervised=False,
as_dataset_kwargs=None):
"""Loads the given `tfds.Split` as a `tf.data.Dataset`.
@@ -135,6 +136,11 @@
expected to be in `data_dir`. If `True` and the data is already in
`data_dir`, `download_and_prepare` is a no-op.
Defaults to `True`.
+as_supervised: `bool`, if `True`, the returned `tf.data.Dataset`
+  will have a 2-tuple structure `(input, label)` according to
+  `builder.info.supervised_keys`. If `False`, the default,
+  the returned `tf.data.Dataset` will have a dictionary with all the
+  features.
as_dataset_kwargs: `dict` (optional), keyword arguments passed to
`tfds.DatasetBuilder.as_dataset`. `split` will be passed through by
default.
@@ -152,6 +158,7 @@
as_dataset_kwargs = {}
as_dataset_kwargs = dict(as_dataset_kwargs)
as_dataset_kwargs["split"] = split
as_dataset_kwargs["as_supervised"] = as_supervised

return dbuilder.as_dataset(**as_dataset_kwargs)

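An illustrative call combining the new flag with the existing kwargs pass-through (not part of the diff; the argument values are arbitrary):

```python
import tensorflow_datasets as tfds

# as_supervised is forwarded into as_dataset_kwargs alongside split, so it can
# be combined with other builder.as_dataset() arguments such as shuffle_files.
dataset = tfds.load(
    name="mnist",
    split=tfds.Split.TEST,
    as_supervised=True,                          # yields (image, label) tuples
    as_dataset_kwargs={"shuffle_files": False},  # forwarded to builder.as_dataset
)
```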
4 changes: 1 addition & 3 deletions tensorflow_datasets/core/registered_test.py
@@ -89,11 +89,9 @@ def test_load(self):
download=False, as_dataset_kwargs=as_dataset_kwargs)
self.assertTrue(builder.as_dataset_called)
self.assertFalse(builder.download_called)
-print(as_dataset_kwargs)
-print(builder.as_dataset_kwargs)
self.assertEqual(splits.Split.TEST,
builder.as_dataset_kwargs.pop("split"))
-print(builder.as_dataset_kwargs)
+self.assertFalse(builder.as_dataset_kwargs.pop("as_supervised"))
self.assertEqual(builder.as_dataset_kwargs, as_dataset_kwargs)
self.assertEqual(dict(data_dir=data_dir, k1=1), builder.kwargs)

18 changes: 10 additions & 8 deletions tensorflow_datasets/image/cifar.py
@@ -61,9 +61,10 @@ def _info(self):
cifar_shape = (_CIFAR_IMAGE_SIZE, _CIFAR_IMAGE_SIZE, 3)
return dataset_info.DatasetInfo(
specs=features.SpecDict({
"input": features.Image(shape=cifar_shape),
"target": tf.int64, # Could replace by features.Label()
"image": features.Image(shape=cifar_shape),
"label": tf.int64, # Could replace by features.Label()
}),
supervised_keys=("image", "label"),
)

@property
@@ -139,8 +140,8 @@ def _generate_samples(self, filepaths):
if len(label) == 1:
label = label[self._cifar_info.label_keys[0]]
yield self.info.specs.encode_sample({
"input": image,
"target": label,
"image": image,
"label": label,
})


@@ -152,7 +153,7 @@ def __init__(self, use_coarse_labels=False, **kwargs):
Args:
use_coarse_labels (bool): whether to set the coarse labels or the fine
labels as "target". Note that in either case, both features will be
labels as "label". Note that in either case, both features will be
present in the features dictionary as "fine_label" and "coarse_label".
Note also that this does NOT affect the data on disk and is only used in
the `tf.data.Dataset` input pipeline.
@@ -177,12 +178,13 @@ def _info(self):
label_to_use = "coarse_labels" if self._use_coarse_labels else "fine_labels"
return dataset_info.DatasetInfo(
specs=features.SpecDict({
"input": features.Image(shape=cifar_shape),
"target": features.OneOf(choice=label_to_use, feature_dict={
"image": features.Image(shape=cifar_shape),
"label": features.OneOf(choice=label_to_use, feature_dict={
"coarse_labels": tf.int64,
"fine_labels": tf.int64,
}),
}),
supervised_keys=("image", "label"),
)


@@ -205,5 +207,5 @@ class CifarInfo(collections.namedtuple("_CifarInfo", [
test_files (list<str>): name of test files within `prefix`.
label_keys (list<str>): names of the label keys in the data. If longer than
1, provide `out_label_keys` to specify output names in feature
-dictionaries. Otherwise will use "target".
+dictionaries. Otherwise will use "label".
"""
4 changes: 2 additions & 2 deletions tensorflow_datasets/image/cifar_test.py
@@ -31,8 +31,8 @@ class Cifar10Test(dataset_builder_testing.TestCase):
"test": 2, # See testing/generate_cifar10_like_sample.py
}
SPEC = {
"target": (tf.int64, ()),
"input": (tf.uint8, (32, 32, 3)),
"label": (tf.int64, ()),
"image": (tf.uint8, (32, 32, 3)),
}


9 changes: 5 additions & 4 deletions tensorflow_datasets/image/mnist.py
@@ -53,9 +53,10 @@ def _info(self):
mnist_shape = (_MNIST_IMAGE_SIZE, _MNIST_IMAGE_SIZE, 1)
return dataset_info.DatasetInfo(
specs=features.SpecDict({
"input": features.Image(shape=mnist_shape),
"target": tf.int64,
"image": features.Image(shape=mnist_shape),
"label": tf.int64,
}),
supervised_keys=("image", "label"),
)

def _split_generators(self, dl_manager):
@@ -111,8 +112,8 @@ def _generate_samples(self, num_examples, data_path, label_path):

for image, label in data:
yield self.info.specs.encode_sample({
"input": image,
"target": label,
"image": image,
"label": label,
})


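With MNIST now declaring `supervised_keys=("image", "label")`, the supervised tuples drop straight into the `(features, labels)` contract of a `tf.estimator` input function. A sketch under that assumption (the estimator itself is outside this change):

```python
import tensorflow as tf
import tensorflow_datasets as tfds

def train_input_fn():
  # as_supervised=True returns (image, label) tuples, which is the
  # (features, labels) structure tf.estimator expects from an input_fn.
  dataset = tfds.load(name="mnist", split=tfds.Split.TRAIN, as_supervised=True)
  dataset = dataset.shuffle(1024).batch(32).prefetch(tf.contrib.data.AUTOTUNE)
  return dataset
```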
