Skip to content

Commit

Permalink
Merge pull request #198 from us:add-emnist-dataset
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 238280883
  • Loading branch information
copybara-github committed Mar 13, 2019
2 parents c8c00f0 + 7699041 commit a266ec3
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 34 deletions.
1 change: 1 addition & 0 deletions tensorflow_datasets/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from tensorflow_datasets.image.image_folder import ImageLabelFolder
from tensorflow_datasets.image.imagenet import Imagenet2012
from tensorflow_datasets.image.lsun import Lsun
from tensorflow_datasets.image.mnist import EMNIST
from tensorflow_datasets.image.mnist import FashionMNIST
from tensorflow_datasets.image.mnist import KMNIST
from tensorflow_datasets.image.mnist import MNIST
Expand Down
232 changes: 199 additions & 33 deletions tensorflow_datasets/image/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""MNIST and Fashion MNIST."""
"""MNIST, Fashion MNIST, KMNIST and EMNIST."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import six.moves.urllib as urllib
import tensorflow as tf

from tensorflow_datasets.core import api_utils
import tensorflow_datasets.public_api as tfds

# MNIST constants
Expand All @@ -36,7 +38,6 @@
_TRAIN_EXAMPLES = 60000
_TEST_EXAMPLES = 10000


_MNIST_CITATION = """\
@article{lecun2010mnist,
title={MNIST handwritten digit database},
Expand All @@ -47,7 +48,6 @@
}
"""


_FASHION_MNIST_CITATION = """\
@article{DBLP:journals/corr/abs-1708-07747,
author = {Han Xiao and
Expand All @@ -67,25 +67,25 @@
}
"""


_K_MNIST_CITATION = """\
@article{DBLP:journals/corr/abs-1812-01718,
author = {Tarin Clanuwat and
Mikel Bober{-}Irizar and
Asanobu Kitamoto and
Alex Lamb and
Kazuaki Yamamoto and
David Ha},
title = {Deep Learning for Classical Japanese Literature},
journal = {CoRR},
volume = {abs/1812.01718},
year = {2018},
url = {http://arxiv.org/abs/1812.01718},
archivePrefix = {arXiv},
eprint = {1812.01718},
timestamp = {Tue, 01 Jan 2019 15:01:25 +0100},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1812-01718},
bibsource = {dblp computer science bibliography, https://dblp.org}
@online{clanuwat2018deep,
author = {Tarin Clanuwat and Mikel Bober-Irizar and Asanobu Kitamoto and Alex Lamb and Kazuaki Yamamoto and David Ha},
title = {Deep Learning for Classical Japanese Literature},
date = {2018-12-03},
year = {2018},
eprintclass = {cs.CV},
eprinttype = {arXiv},
eprint = {cs.CV/1812.01718},
}
"""

_EMNIST_CITATION = """\
@article{cohen_afshar_tapson_schaik_2017,
title={EMNIST: Extending MNIST to handwritten letters},
DOI={10.1109/ijcnn.2017.7966217},
journal={2017 International Joint Conference on Neural Networks (IJCNN)},
author={Cohen, Gregory and Afshar, Saeed and Tapson, Jonathan and Schaik, Andre Van},
year={2017}
}
"""

Expand Down Expand Up @@ -118,9 +118,8 @@ def _split_generators(self, dl_manager):
"test_data": _MNIST_TEST_DATA_FILENAME,
"test_labels": _MNIST_TEST_LABELS_FILENAME,
}
mnist_files = dl_manager.download_and_extract({
k: urllib.parse.urljoin(self.URL, v) for k, v in filenames.items()
})
mnist_files = dl_manager.download_and_extract(
{k: urllib.parse.urljoin(self.URL, v) for k, v in filenames.items()})

# MNIST provides TRAIN and TEST splits, not a VALIDATION split, so we only
# write the TRAIN and TEST splits to disk.
Expand Down Expand Up @@ -181,11 +180,13 @@ def _info(self):
"grayscale image, associated with a label from 10 "
"classes."),
features=tfds.features.FeaturesDict({
"image": tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
"label": tfds.features.ClassLabel(names=[
"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
]),
"image":
tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
"label":
tfds.features.ClassLabel(names=[
"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
]),
}),
supervised_keys=("image", "label"),
urls=["https://github.com/zalandoresearch/fashion-mnist"],
Expand All @@ -208,17 +209,182 @@ def _info(self):
"character to represent each of the 10 rows of Hiragana "
"when creating Kuzushiji-MNIST."),
features=tfds.features.FeaturesDict({
"image": tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
"label": tfds.features.ClassLabel(names=[
"o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"
]),
"image":
tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
"label":
tfds.features.ClassLabel(names=[
"o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"
]),
}),
supervised_keys=("image", "label"),
urls=["http://codh.rois.ac.jp/kmnist/index.html.en"],
citation=_K_MNIST_CITATION,
)


class EMNISTConfig(tfds.core.BuilderConfig):
  """Builder configuration describing a single EMNIST variant."""

  @api_utils.disallow_positional_args
  def __init__(self, class_number, train_examples, test_examples, **kwargs):
    """Stores the per-variant sizes on top of the base builder config.

    Arguments are expected to be given by keyword (see the
    `disallow_positional_args` decorator applied to this constructor).

    Args:
      class_number: number of label classes of this EMNIST variant. Each
        variant provided by the dataset has its own class count.
      train_examples: number of train examples.
      test_examples: number of test examples.
      **kwargs: keyword arguments forwarded to super.
    """
    super(EMNISTConfig, self).__init__(**kwargs)
    # Variant-specific sizes, read back through `builder_config` by the
    # dataset builder.
    self.test_examples = test_examples
    self.train_examples = train_examples
    self.class_number = class_number


class EMNIST(MNIST):
  """EMNIST dataset builder.

  One builder config exists per EMNIST variant (`byclass`, `bymerge`,
  `balanced`, `letters`, `digits`, `mnist`) plus a tiny `test` config. The
  data files are not downloaded automatically: they must already be present,
  extracted, in `dl_manager.manual_dir`.
  """

  VERSION = tfds.core.Version("1.0.1")

  BUILDER_CONFIGS = [
      EMNISTConfig(
          name="byclass",
          class_number=62,
          train_examples=697932,
          test_examples=116323,
          description=(
              "EMNIST ByClass: 814,255 characters. 62 unbalanced classes."),
          version="1.0.1",
      ),
      EMNISTConfig(
          name="bymerge",
          class_number=47,
          train_examples=697932,
          test_examples=116323,
          description=(
              "EMNIST ByMerge: 814,255 characters. 47 unbalanced classes."),
          version="1.0.1",
      ),
      EMNISTConfig(
          name="balanced",
          class_number=47,
          train_examples=112800,
          test_examples=18800,
          description=(
              "EMNIST Balanced: 131,600 characters. 47 balanced classes."),
          version="1.0.1",
      ),
      # NOTE(review): class_number is 37 while the description states 26
      # classes -- confirm which value is intended before changing either.
      EMNISTConfig(
          name="letters",
          class_number=37,
          train_examples=88800,
          test_examples=14800,
          description=(
              "EMNIST Letters: 103,600 characters. 26 balanced classes."),
          version="1.0.1",
      ),
      EMNISTConfig(
          name="digits",
          class_number=10,
          train_examples=240000,
          test_examples=40000,
          description=(
              "EMNIST Digits: 280,000 characters. 10 balanced classes."),
          version="1.0.1",
      ),
      EMNISTConfig(
          name="mnist",
          class_number=10,
          train_examples=60000,
          test_examples=10000,
          description=(
              "EMNIST MNIST: 70,000 characters. 10 balanced classes."),
          version="1.0.1",
      ),
      # Small config selected by the unit tests.
      EMNISTConfig(
          name="test",
          class_number=62,
          train_examples=10,
          test_examples=2,
          description="EMNIST test data config.",
          version="1.0.1",
      ),
  ]

  def _info(self):
    """Returns the `tfds.core.DatasetInfo` for the selected config.

    The number of label classes is taken from the active builder config.
    """
    return tfds.core.DatasetInfo(
        builder=self,
        # Adjacent literals are concatenated verbatim, so every fragment but
        # the last must end with a space.
        description=(
            "The EMNIST dataset is a set of handwritten character digits "
            "derived from the NIST Special Database 19 and converted to "
            "a 28x28 pixel image format and dataset structure that directly "
            "matches the MNIST dataset."),
        features=tfds.features.FeaturesDict({
            "image":
                tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
            "label":
                tfds.features.ClassLabel(
                    num_classes=self.builder_config.class_number),
        }),
        supervised_keys=("image", "label"),
        urls=["https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"],
        citation=_EMNIST_CITATION,
    )

  def _split_generators(self, dl_manager):
    """Returns the TRAIN and TEST split generators for the selected config.

    Args:
      dl_manager: download manager; only its `manual_dir` attribute is used,
        as the location of the manually extracted EMNIST files.

    Returns:
      A list with the TRAIN and TEST `tfds.core.SplitGenerator`s.

    Raises:
      Exception: if any of the four expected data files is missing from
        `dl_manager.manual_dir`.
    """
    config_name = self.builder_config.name
    filenames = {
        "train_data":
            "emnist-{}-train-images-idx3-ubyte".format(config_name),
        "train_labels":
            "emnist-{}-train-labels-idx1-ubyte".format(config_name),
        "test_data":
            "emnist-{}-test-images-idx3-ubyte".format(config_name),
        "test_labels":
            "emnist-{}-test-labels-idx1-ubyte".format(config_name),
    }

    dir_name = dl_manager.manual_dir
    paths = {
        key: os.path.join(dir_name, filename)
        for key, filename in filenames.items()
    }

    # Check every required file up front so the user gets one actionable
    # message instead of a late failure while reading labels or test data.
    missing = sorted(
        path for path in paths.values() if not tf.io.gfile.exists(path))
    if missing:
      # The current tfds.core.download_manager is unable to
      # extract multiple and nested files.
      # We'll add soon! (Issue 234)
      msg = (
          "You must download and extract the dataset files manually and "
          "place them in: {manual_dir}\n"
          "Missing files:\n"
          "{missing}\n"
          "File tree must be like this:\n"
          ".\n"
          "|-- emnist\n"
          "|   |-- emnist-byclass-train-images-idx3-ubyte\n"
          "|   |-- emnist-byclass-train-labels-idx1-ubyte\n"
          "|   |-- emnist-byclass-test-images-idx3-ubyte\n"
          "|   |-- emnist-byclass-test-labels-idx1-ubyte\n"
          "|   |-- emnist-bymerge-train-images-idx3-ubyte\n"
          "|   |-- emnist-bymerge-train-labels-idx1-ubyte\n"
          "|   |-- emnist-bymerge-test-images-idx3-ubyte\n"
          "|   |-- emnist-bymerge-test-labels-idx1-ubyte\n"
          "|   |-- .......\n"
      ).format(
          manual_dir=dir_name,
          missing="\n".join("  " + path for path in missing))
      raise Exception(msg)

    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            num_shards=10,
            gen_kwargs=dict(
                num_examples=self.builder_config.train_examples,
                data_path=paths["train_data"],
                label_path=paths["train_labels"],
            )),
        tfds.core.SplitGenerator(
            name=tfds.Split.TEST,
            num_shards=1,
            gen_kwargs=dict(
                num_examples=self.builder_config.test_examples,
                data_path=paths["test_data"],
                label_path=paths["test_labels"],
            )),
    ]


def _extract_mnist_images(image_filepath, num_images):
with tf.io.gfile.GFile(image_filepath, "rb") as f:
f.read(16) # header
Expand Down
5 changes: 5 additions & 0 deletions tensorflow_datasets/image/mnist_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,10 @@ class KMNISTTest(MNISTTest):
DATASET_CLASS = mnist.KMNIST


class EMNISTTest(MNISTTest):
  """Runs the checks inherited from `MNISTTest` against `mnist.EMNIST`."""
  DATASET_CLASS = mnist.EMNIST
  # Restrict the run to the builder config named "test".
  # NOTE(review): presumably this matches the small fake-data files checked in
  # for the tests -- confirm against the testing data directory.
  BUILDER_CONFIG_NAMES_TO_TEST = ["test"]


if __name__ == "__main__":
testing.test_main()
2 changes: 1 addition & 1 deletion tensorflow_datasets/testing/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def write_label_file(filename, num_labels):


def main(_):
for mnist in ["mnist", "fashion_mnist", "kmnist"]:
for mnist in ["mnist", "fashion_mnist", "kmnist", "emnist"]:
output_dir = mnist_dir(mnist)
test_utils.remake_dir(output_dir)
write_image_file(os.path.join(output_dir, _TRAIN_DATA_FILENAME), 10)
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
11111111
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
11111111

0 comments on commit a266ec3

Please sign in to comment.