forked from microsoft/CNTK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SimpleMNIST.py
126 lines (98 loc) · 5.01 KB
/
SimpleMNIST.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
import numpy as np
import sys
import os
from cntk import Trainer
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
from cntk.device import cpu, set_default_device
from cntk.learner import sgd, learning_rate_schedule, UnitType
from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, relu, element_times, constant
abs_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(abs_path, "..", "..", "..", "..", "..", "Examples", "common"))
from nn import fully_connected_classifier_net, print_training_progress
def check_path(path):
    """Ensure that a data file exists before it is handed to a reader.

    Args:
        path: filesystem path of the expected data file.

    Raises:
        RuntimeError: if ``path`` does not exist. The message names the
            README.md that sits one directory above the file's parent
            directory, where the download/preparation instructions live.
    """
    if os.path.exists(path):
        return
    parent_dir = os.path.dirname(path)
    readme_file = os.path.normpath(os.path.join(parent_dir, "..", "README.md"))
    message = ("File '%s' does not exist. Please follow the instructions at %s "
               "to download and prepare it." % (path, readme_file))
    raise RuntimeError(message)
def create_reader(path, is_training, input_dim, label_dim):
    """Build a MinibatchSource over a CNTK text-format (CTF) file.

    Args:
        path: path of the CTF file containing ``features`` and ``labels`` fields.
        is_training: when True the data is randomized and repeated forever;
            otherwise it is read in order for a single full sweep.
        input_dim: dimension of the dense ``features`` stream.
        label_dim: dimension of the dense ``labels`` stream.

    Returns:
        A MinibatchSource exposing ``streams.features`` and ``streams.labels``.
    """
    feature_stream = StreamDef(field='features', shape=input_dim, is_sparse=False)
    label_stream = StreamDef(field='labels', shape=label_dim, is_sparse=False)
    deserializer = CTFDeserializer(
        path, StreamDefs(features=feature_stream, labels=label_stream))

    # Training loops forever over shuffled data; evaluation makes one ordered pass.
    if is_training:
        sweep_setting = INFINITELY_REPEAT
    else:
        sweep_setting = FULL_DATA_SWEEP

    return MinibatchSource(deserializer, randomize=is_training,
                           epoch_size=sweep_setting)
def _mnist_data_path(file_name):
    """Resolve, normalize and existence-check the path of an MNIST CTF file.

    If CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY is set, the file is looked up
    under ``<that dir>/Image/MNIST/v0``. Otherwise it is looked up under
    ``Examples/Image/DataSets/MNIST`` five levels above this script. A relative
    location is resolved against the script directory (``abs_path``).

    Raises:
        RuntimeError: (from check_path) if the resolved file does not exist.
    """
    try:
        external_dir = os.environ['CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY']
    except KeyError:
        rel_path = os.path.join("..", "..", "..", "..", "..",
                                "Examples", "Image", "DataSets", "MNIST", file_name)
    else:
        rel_path = os.path.join(external_dir, "Image", "MNIST", "v0", file_name)
    path = os.path.normpath(os.path.join(abs_path, rel_path))
    check_path(path)
    return path


# Creates and trains a feedforward classification model for MNIST images
def simple_mnist(debug_output=False):
    """Train a feedforward MNIST classifier and return its average test error.

    Args:
        debug_output: if True, training progress is reported four times as
            often as usual.

    Returns:
        The mean of the per-minibatch classification errors over the test
        minibatches that were evaluated.

    Raises:
        RuntimeError: if the training or test data file cannot be found.
    """
    input_dim = 784
    num_output_classes = 10
    num_hidden_layers = 1
    hidden_layers_dim = 200

    # Input variables denoting the features and label data
    # (named `features` rather than `input` to avoid shadowing the builtin).
    features = input_variable(input_dim, np.float32)
    label = input_variable(num_output_classes, np.float32)

    # Instantiate the feedforward classification model.
    # 0.00390625 == 1/256; presumably rescales 0-255 pixel values -- TODO confirm.
    scaled_input = element_times(constant(0.00390625), features)
    z = fully_connected_classifier_net(
        scaled_input, num_output_classes, hidden_layers_dim, num_hidden_layers, relu)

    ce = cross_entropy_with_softmax(z, label)
    pe = classification_error(z, label)

    reader_train = create_reader(_mnist_data_path("Train-28x28_cntk_text.txt"),
                                 True, input_dim, num_output_classes)
    input_map = {
        features: reader_train.streams.features,
        label: reader_train.streams.labels,
    }

    lr_per_minibatch = learning_rate_schedule(0.2, UnitType.minibatch)
    # Instantiate the trainer object to drive the model training
    trainer = Trainer(z, ce, pe, sgd(z.parameters, lr=lr_per_minibatch))

    # Get minibatches of images to train with and perform model training
    minibatch_size = 64
    num_samples_per_sweep = 60000
    num_sweeps_to_train_with = 10
    num_minibatches_to_train = (
        num_samples_per_sweep * num_sweeps_to_train_with) // minibatch_size
    training_progress_output_freq = 500

    if debug_output:
        training_progress_output_freq = training_progress_output_freq // 4

    for i in range(num_minibatches_to_train):
        mb = reader_train.next_minibatch(minibatch_size, input_map=input_map)
        trainer.train_minibatch(mb)
        print_training_progress(trainer, i, training_progress_output_freq)

    # Load test data
    reader_test = create_reader(_mnist_data_path("Test-28x28_cntk_text.txt"),
                                False, input_dim, num_output_classes)
    input_map = {
        features: reader_test.streams.features,
        label: reader_test.streams.labels,
    }

    # Test data for trained model. Only whole minibatches are evaluated; the
    # trailing num_samples % test_minibatch_size samples are not read.
    test_minibatch_size = 1024
    num_samples = 10000
    num_minibatches_to_test = num_samples // test_minibatch_size
    test_result = 0.0
    for _ in range(num_minibatches_to_test):
        mb = reader_test.next_minibatch(test_minibatch_size, input_map=input_map)
        test_result += trainer.test_minibatch(mb)

    # Average over the minibatches actually evaluated. Dividing by the
    # fractional num_samples / test_minibatch_size would understate the error.
    return test_result / num_minibatches_to_test
if __name__ == '__main__':
    # By default CNTK picks the best available device. To force a specific
    # one (for example the CPU), uncomment the next line:
    # set_default_device(cpu())
    print("Error: %f" % simple_mnist())