create_model.py
from sklearn.ensemble import RandomForestClassifier
from sklearn import metrics
from sklearn.utils import shuffle
from sklearn.datasets import fetch_mldata
from sklearn.externals import joblib
from six.moves import urllib

if __name__ == '__main__':
    # Load the MNIST digit dataset (70000 samples of 28x28 grayscale images).
    try:
        mnist = fetch_mldata('MNIST original')
    except Exception:
        print("Could not download MNIST data from mldata.org, trying alternative...")
        # Alternative method to load MNIST, if mldata.org is down
        from scipy.io import loadmat
        mnist_alternative_url = "https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat"
        mnist_path = "./mnist-original.mat"
        response = urllib.request.urlopen(mnist_alternative_url)
        with open(mnist_path, "wb") as f:
            content = response.read()
            f.write(content)
        mnist_raw = loadmat(mnist_path)
        # Rebuild the same dict-like structure that fetch_mldata would return.
        mnist = {
            "data": mnist_raw["data"].T,
            "target": mnist_raw["label"][0],
            "COL_NAMES": ["label", "data"],
            "DESCR": "mldata.org dataset: mnist-original",
        }
        print("Success!")

    #mnist = fetch_mldata('MNIST original', data_home="./mnist_sklearn")
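
    # NOTE: fetch_mldata was removed in newer scikit-learn releases. A rough,
    # untested sketch of the equivalent download via OpenML (scikit-learn 0.22+,
    # assumed here, not part of the original script) would be:
    #
    #   from sklearn.datasets import fetch_openml
    #   mnist = fetch_openml('mnist_784', version=1, as_frame=False)
    #   data = mnist.data                     # shape (70000, 784)
    #   targets = mnist.target.astype(int)    # OpenML returns string labels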
    # To apply a classifier on this data, we need to flatten each image into
    # a (samples, features) matrix:
    n_samples = len(mnist['data'])
    data = mnist['data'].reshape((n_samples, -1))
    targets = mnist['target']

    # Shuffle so the train/test split below is not ordered by digit class.
    data, targets = shuffle(data, targets)
    classifier = RandomForestClassifier(n_estimators=30)

    # Learn the digits on the first half of the data
    classifier.fit(data[:n_samples // 2], targets[:n_samples // 2])

    # Now predict the value of the digit on the second half:
    expected = targets[n_samples // 2:]
    test_data = data[n_samples // 2:]
    print(classifier.score(test_data, expected))
    predicted = classifier.predict(test_data)

    print("Classification report for classifier %s:\n%s\n"
          % (classifier, metrics.classification_report(expected, predicted)))
    print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))

    # Persist the trained model for later use.
    joblib.dump(classifier, '/data/sk.pkl')
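
    # Usage sketch (an assumption, not part of the original script): the pickled
    # classifier could be re-loaded elsewhere for inference, e.g.
    #
    #   from sklearn.externals import joblib   # plain `import joblib` on newer scikit-learn
    #   clf = joblib.load('/data/sk.pkl')
    #   # `samples` is a hypothetical array of shape (n_samples, 784), same layout as above
    #   print(clf.predict(samples))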