Skip to content

Commit 926f596

Browse files
add bayes classifier
1 parent 2c37509 commit 926f596

File tree

7 files changed

+1670
-1
lines changed

7 files changed

+1670
-1
lines changed

bayes.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import pandas as pd
2+
import numpy as np
3+
import matplotlib.pyplot as plt
4+
import sys
5+
6+
# easier to work with as pandas dataframes because we can filter classes
7+
Xtest = pd.read_csv("mnist_csv/Xtest.txt", header=None)
8+
Xtrain = pd.read_csv("mnist_csv/Xtrain.txt", header=None)
9+
Ytest = pd.read_csv("mnist_csv/label_test.txt", header=None)
10+
Ytrain = pd.read_csv("mnist_csv/label_train.txt", header=None)
11+
12+
class Bayes(object):
13+
def fit(self, X, y):
14+
self.gaussians = dict()
15+
labels = set(y.as_matrix().flatten())
16+
for c in labels:
17+
current_x = Xtrain[Ytrain[0] == c]
18+
self.gaussians[c] = {
19+
'mu': current_x.mean(),
20+
'sigma': np.cov(current_x.T),
21+
}
22+
# plt.imshow(self.gaussians[c]['sigma'])
23+
# plt.show()
24+
25+
def predict_one(self, x):
26+
lls = self.distributions(x)
27+
return np.argmax(lls)
28+
29+
def predict(self, X):
30+
Ypred = X.apply(lambda x: self.predict_one(x), axis=1)
31+
return Ypred
32+
33+
def distributions(self, x):
34+
lls = np.zeros(len(self.gaussians))
35+
for c,g in self.gaussians.iteritems():
36+
x_minus_mu = x - g['mu']
37+
k1 = np.log(2*np.pi)*x.shape + np.log(np.linalg.det(g['sigma']))
38+
k2 = np.dot( np.dot(x_minus_mu, np.linalg.inv(g['sigma'])), x_minus_mu)
39+
ll = -0.5*(k1 + k2)
40+
lls[c] = ll
41+
return lls
42+
43+
44+
if __name__ == '__main__':
45+
bayes = Bayes()
46+
bayes.fit(Xtrain, Ytrain)
47+
Ypred = bayes.predict(Xtest)
48+
C = np.zeros((10,10), dtype=np.int)
49+
# print len(Ypred), len(Ytest)
50+
for p,t in zip(Ypred.as_matrix().flatten(), Ytest.as_matrix().flatten()):
51+
C[t,p] += 1
52+
print "Confusion matrix:"
53+
print C
54+
print "Accuracy:", np.trace(C) / 500.0
55+
56+
if len(sys.argv) > 1 and sys.argv[1] == 'reconstruct':
57+
# show means as images
58+
Q = pd.read_csv("mnist_csv/Q.txt", header=None).as_matrix()
59+
for c,g in bayes.gaussians.iteritems():
60+
y = np.dot(Q, g['mu'].as_matrix())
61+
y = np.reshape(y, (28,28))
62+
plt.imshow(y)
63+
plt.title(c)
64+
plt.show()
65+
66+
# show distributions for 3 misclassified examples
67+
print "distributions for 3 misclassified examples:"
68+
count = 0
69+
for i,p in Ypred.iteritems():
70+
if p != Ytest.loc[i][0]:
71+
print "predicted:", p, "actual:", Ytest.loc[i][0]
72+
print bayes.distributions(Xtest.loc[i])
73+
count += 1
74+
if count >= 3:
75+
break

0 commit comments

Comments
 (0)