Skip to content

Commit

Permalink
Fashion mnist dataset (keras-team#7809)
Browse files Browse the repository at this point in the history
* fixed typo

* added fashion-mnist dataset

* added docs

* pep8

* grammer

* use offset instead of struct

* reshape as in docs
  • Loading branch information
kashif authored and fchollet committed Sep 6, 2017
1 parent 5625d70 commit a379b42
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 0 deletions.
33 changes: 33 additions & 0 deletions docs/templates/datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,39 @@ from keras.datasets import mnist
- __path__: if you do not have the index file locally (at `'~/.keras/datasets/' + path`), it will be downloaded to this location.


---

## Fashion-MNIST database of fashion articles

Dataset of 60,000 28x28 grayscale images of 10 fashion categories, along with a test set of 10,000 images. This dataset can be used as a drop-in replacement for MNIST. The class labels are:

| Label | Description |
| --- | --- |
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |

### Usage:

```python
from keras.datasets import fashion_mnist

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
```

- __Returns:__
    - 2 tuples:
        - __x_train, x_test__: uint8 array of grayscale image data with shape (num_samples, 28, 28).
        - __y_train, y_test__: uint8 array of labels (integers in range 0-9) with shape (num_samples,).


---

## Boston housing price regression dataset
Expand Down
1 change: 1 addition & 0 deletions keras/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from . import cifar10
from . import cifar100
from . import boston_housing
from . import fashion_mnist
37 changes: 37 additions & 0 deletions keras/datasets/fashion_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import gzip
import os

from ..utils.data_utils import get_file
import numpy as np


def load_data():
    """Loads the Fashion-MNIST dataset.

    The four gzip archives (train/test labels and images) are resolved
    through `get_file` under the `datasets/fashion-mnist` cache
    sub-directory and then decoded into Numpy arrays.

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
        Label arrays are 1-D `uint8`; image arrays are `uint8` reshaped
        to `(number_of_labels, 28, 28)`.
    """
    cache_dir = os.path.join('datasets', 'fashion-mnist')
    url_root = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
    archive_names = ['train-labels-idx1-ubyte.gz',
                     'train-images-idx3-ubyte.gz',
                     't10k-labels-idx1-ubyte.gz',
                     't10k-images-idx3-ubyte.gz']

    # Resolve every archive before decoding anything, in the listed order.
    local_paths = [
        get_file(name, origin=url_root + name, cache_subdir=cache_dir)
        for name in archive_names
    ]

    def _read_labels(path):
        # The first 8 bytes of a label archive are skipped.
        with gzip.open(path, 'rb') as stream:
            return np.frombuffer(stream.read(), np.uint8, offset=8)

    def _read_images(path, count):
        # The first 16 bytes of an image archive are skipped; the rest is
        # viewed as `count` images of 28x28 pixels.
        with gzip.open(path, 'rb') as stream:
            pixels = np.frombuffer(stream.read(), np.uint8, offset=16)
            return pixels.reshape(count, 28, 28)

    train_labels_path, train_images_path = local_paths[0], local_paths[1]
    test_labels_path, test_images_path = local_paths[2], local_paths[3]

    y_train = _read_labels(train_labels_path)
    x_train = _read_images(train_images_path, len(y_train))
    y_test = _read_labels(test_labels_path)
    x_test = _read_images(test_images_path, len(y_test))

    return (x_train, y_train), (x_test, y_test)
11 changes: 11 additions & 0 deletions tests/keras/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from keras.datasets import imdb
from keras.datasets import mnist
from keras.datasets import boston_housing
from keras.datasets import fashion_mnist


def test_cifar():
Expand Down Expand Up @@ -75,5 +76,15 @@ def test_boston_housing():
assert len(x_test) == len(y_test)


def test_fashion_mnist():
    """Check the Fashion-MNIST split sizes on a random subset of runs."""
    # only run data download tests 20% of the time
    # to speed up frequent testing
    random.seed(time.time())
    if random.random() <= 0.8:
        # Not selected this run: skip the download entirely.
        return

    train_split, test_split = fashion_mnist.load_data()
    x_train, y_train = train_split
    x_test, y_test = test_split
    assert len(x_train) == len(y_train) == 60000
    assert len(x_test) == len(y_test) == 10000


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit a379b42

Please sign in to comment.