|
| 1 | +import pandas as pd |
| 2 | +import numpy as np |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +import sys |
| 5 | + |
# Keep the data as pandas DataFrames: boolean masks make it simple to pull
# out all rows belonging to a single class later on.
Xtest, Xtrain, Ytest, Ytrain = (
    pd.read_csv(csv_path, header=None)
    for csv_path in (
        "mnist_csv/Xtest.txt",
        "mnist_csv/Xtrain.txt",
        "mnist_csv/label_test.txt",
        "mnist_csv/label_train.txt",
    )
)
| 11 | + |
class Bayes(object):
    """Bayes classifier that models each class with a full-covariance Gaussian.

    After ``fit``, ``self.gaussians`` maps each class label to a dict with
    the keys ``'mu'`` (per-feature mean) and ``'sigma'`` (covariance matrix).
    Class priors are not modelled: prediction picks the class with the
    highest Gaussian log-density.
    """

    def fit(self, X, y):
        """Estimate one Gaussian per class label.

        X: samples in rows, features in columns (DataFrame or 2-D array).
        y: labels, one per row of X (DataFrame, Series or array); it is
           flattened, so an (n, 1) frame is accepted.

        Raises ValueError if X and y disagree on the number of samples.
        """
        labels = np.asarray(y).ravel()
        if len(labels) != len(X):
            raise ValueError(
                "X and y must have the same number of samples "
                "(got %d and %d)" % (len(X), len(labels))
            )

        self.gaussians = dict()
        for c in np.unique(labels):
            # Positional boolean mask over the rows of the data that was
            # passed in (not over any module-level training set).
            current_x = X[labels == c]
            self.gaussians[c] = {
                # For a DataFrame this stays a Series of per-column means.
                'mu': current_x.mean(axis=0),
                # atleast_2d keeps the single-feature case a 1x1 matrix.
                'sigma': np.atleast_2d(
                    np.cov(np.asarray(current_x, dtype=float), rowvar=False)
                ),
            }

    @staticmethod
    def _precision_and_logdet(g):
        """Return (inverse covariance, log-determinant) for one class dict.

        Both values are computed on first use and cached in ``g`` under
        ``'sigma_inv'`` and ``'log_det'`` so they are not recomputed for
        every sample. ``fit`` builds fresh dicts, which resets the cache.
        """
        if 'sigma_inv' not in g or 'log_det' not in g:
            sigma = np.atleast_2d(np.asarray(g['sigma'], dtype=float))
            # slogdet avoids the under/overflow that det() followed by
            # log() suffers from on high-dimensional covariances.
            _, log_det = np.linalg.slogdet(sigma)
            g['sigma_inv'] = np.linalg.inv(sigma)
            g['log_det'] = log_det
        return g['sigma_inv'], g['log_det']

    def predict_one(self, x):
        """Return the class label with the highest log-density for sample x."""
        lls = self.distributions(x)
        classes = sorted(self.gaussians)
        return classes[int(np.argmax(lls))]

    def predict(self, X):
        """Return the predicted label for every row of the DataFrame X."""
        return X.apply(self.predict_one, axis=1)

    def distributions(self, x):
        """Return the Gaussian log-density of x under every class.

        The result is a 1-D array with one entry per class, ordered by
        sorted class label (for labels 0..K-1 entry i belongs to class i).
        """
        x = np.asarray(x, dtype=float).ravel()
        n_features = x.shape[0]
        classes = sorted(self.gaussians)
        lls = np.zeros(len(classes))
        for i, c in enumerate(classes):
            g = self.gaussians[c]
            sigma_inv, log_det = self._precision_and_logdet(g)
            x_minus_mu = x - np.asarray(g['mu'], dtype=float).ravel()
            # Normalisation term: d*log(2*pi) + log|Sigma|.
            k1 = n_features * np.log(2 * np.pi) + log_det
            # Squared Mahalanobis distance of x from the class mean.
            k2 = np.dot(np.dot(x_minus_mu, sigma_inv), x_minus_mu)
            lls[i] = -0.5 * (k1 + k2)
        return lls
| 42 | + |
| 43 | + |
def confusion_matrix(y_true, y_pred, n_classes=10):
    """Count (true label, predicted label) pairs.

    Returns an (n_classes, n_classes) integer array whose entry [t, p] is
    the number of samples with true label t that were predicted as p.
    Labels must be integers in the range 0..n_classes-1.
    """
    counts = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        counts[int(t), int(p)] += 1
    return counts


if __name__ == '__main__':
    # Train on the training split, then evaluate on the test split.
    bayes = Bayes()
    bayes.fit(Xtrain, Ytrain)
    Ypred = bayes.predict(Xtest)

    C = confusion_matrix(Ytest.values.flatten(), Ypred.values.flatten())
    print("Confusion matrix:")
    print(C)
    # Divide by the number of evaluated samples rather than a fixed 500.
    total = C.sum()
    accuracy = np.trace(C) / float(total) if total else float('nan')
    print("Accuracy: %s" % accuracy)

    if len(sys.argv) > 1 and sys.argv[1] == 'reconstruct':
        # show means as images
        Q = pd.read_csv("mnist_csv/Q.txt", header=None).values
        for c, g in bayes.gaussians.items():
            # Multiply the class mean by Q and view the result as a
            # 28x28 image.
            y = np.dot(Q, np.asarray(g['mu']))
            y = np.reshape(y, (28, 28))
            plt.imshow(y)
            plt.title(c)
            plt.show()

    # show distributions for 3 misclassified examples
    # NOTE(review): the original indentation was lost; this section is
    # assumed to run regardless of the 'reconstruct' flag -- confirm.
    print("distributions for 3 misclassified examples:")
    count = 0
    for i, p in zip(Ypred.index, Ypred.values):
        actual = Ytest.loc[i][0]
        if p != actual:
            print("predicted: %s actual: %s" % (p, actual))
            print(bayes.distributions(Xtest.loc[i]))
            count += 1
            if count >= 3:
                break
0 commit comments