Skip to content

Commit

Permalink
init: commit
Browse files Browse the repository at this point in the history
  • Loading branch information
kkweon committed Apr 21, 2017
0 parents commit d05555c
Show file tree
Hide file tree
Showing 15 changed files with 831 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
model/*.h5 filter=lfs diff=lfs merge=lfs -text
model/resnet.h5 filter=lfs diff=lfs merge=lfs -text
model/vggnet5.h5 filter=lfs diff=lfs merge=lfs -text
model/vggnet.h5 filter=lfs diff=lfs merge=lfs -text
128 changes: 128 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
###DATA###
MNIST/

###Python###

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/


###PyCharm###

# PyCharm
# http://www.jetbrains.com/pycharm/webhelp/project.html
.idea
.iml


###IPythonNotebook###

# Temporary data
.ipynb_checkpoints/


###OSX###

.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon


# Thumbnails
._*

# Files that might appear on external disk
.Spotlight-V100
.Trashes

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk


###Linux###

*~

# KDE directory preferences
.directory


###Windows###

# Windows image file caches
Thumbs.db
ehthumbs.db

# Folder config file
Desktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msm
*.msp

# Windows shortcuts
*.lnk
71 changes: 71 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
[resnet]: images/resnet.png "ResNet Model"
[vggnet]: images/vggnet.png "VggNet Model"
[vggnet5]: images/vggnet5.png "VggNet5 Model"

# MNIST Competition Tensorflow KR Group

* MNIST competition submission files
* Used Keras
* [Model Architecture](#model-architectures)
* [ResNet](#resnet)
* [VggNet](#vggnet)
* [VggNet5](#vggnet5)

## Performance

| **Model** | **Description** | **Accuracy** |
|:-----------:|:-----------------------------------------------:|:------------:|
| VGG-like | VGGNet-like but smaller | 99.71% |
| Resnet-like | ResNet-like but smaller | 99.60% |
| VGG-like | VGGNet-like but even smaller than the first one | 99.63% |
| **Final** | **Ensemble 3 models + Voting** | **99.80%** |

## Run

#### Evaluation
```bash
python evaluation.py
```

#### Train
```bash
python resnet.py 10 # 10 epochs & resnet
python vgg16.py 10 # 10 epochs & vgg
python vgg5.py 10 # 10 epochs & vgg
```

## File descriptions
```bash
├── evaluation.py # evaluation.py
├── images # model architectures
│   ├── resnet.png
│   ├── vggnet5.png
│   └── vggnet.png
├── MNIST # mnist data (not included in this repo)
│   ├── t10k-images-idx3-ubyte.gz
│   ├── t10k-labels-idx1-ubyte.gz
│   ├── train-images-idx3-ubyte.gz
│   └── train-labels-idx1-ubyte.gz
├── model # model weights
│   ├── resnet.h5
│   ├── vggnet5.h5
│   └── vggnet.h5
├── model.py # base model interface
├── README.md
├── utils.py # helper functions
├── resnet.py
├── vgg16.py
└── vgg5.py
```


## Model Architectures

#### ResNet
![resnet][resnet]

#### VggNet
![vggnet][vggnet]

#### VggNet5
![vggnet5][vggnet5]
66 changes: 66 additions & 0 deletions evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy as np

from vgg16 import VGGNet
from vgg5 import VGGNet5
from resnet import ResNet50 as ResNet
from utils import load_mnist


def load_models():
"""Load models """
model_list = []

model_list.append(VGGNet("model/vggnet.h5"))
model_list.append(ResNet("model/resnet.h5"))
model_list.append(VGGNet5("model/vggnet5.h5"))

return model_list


def evaluate(prediction, true_labels):
"""Return an accuracy
Parameters
----------
prediction : 2-d array, shape (n_sample, n_classes)
Onehot encoded predicted array
true_labels : 2-d array, shape (n_sample, n_classes)
Onehot encoded true array
Returns
----------
accuracy : float
Return an accuracy
"""
pred = np.argmax(prediction, 1)
true = np.argmax(true_labels, 1)

equal = (pred == true)

return np.mean(equal)


def main():
model_list = load_models()
_, _, (X_test, y_test) = load_mnist()

pred_list = []

for idx, model in enumerate(model_list):
pred = model.predict(X_test)
pred_list.append(pred)

# Check a single model accuracy
acc = evaluate(pred, y_test)
print(f"Model-{idx}: {acc:>.5%}")

pred_list = np.asarray(pred_list)
pred_mean = np.mean(pred_list, 0)

accuracy = evaluate(pred_mean, y_test)
print(f"Final Test Accuracy: {accuracy:>.5%}")


if __name__ == '__main__':
main()
Binary file added images/resnet.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/vggnet.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/vggnet5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit d05555c

Please sign in to comment.