Skip to content

Commit a43df1e

Browse files
committed
[add] Neural network classifier and results metrics
1 parent dd74d31 commit a43df1e

20 files changed

+37391
-12409
lines changed

DAP/platt_params_SVM

Whitespace-only changes.

DAP/prediction_NN

Lines changed: 6180 additions & 0 deletions
Large diffs are not rendered by default.

DAP/prediction_NN2

Lines changed: 6180 additions & 0 deletions
Large diffs are not rendered by default.

DAP/prediction renamed to DAP/prediction_SVM

Lines changed: 6163 additions & 6163 deletions
Large diffs are not rendered by default.

DAP/probabilities

Lines changed: 0 additions & 6180 deletions
This file was deleted.

DAP/probabilities_NN

Lines changed: 6180 additions & 0 deletions
Large diffs are not rendered by default.

DAP/probabilities_NN2

Lines changed: 6180 additions & 0 deletions
Large diffs are not rendered by default.

DAP/probabilities_SVM

Lines changed: 6180 additions & 0 deletions
Large diffs are not rendered by default.

DAP_eval.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
import numpy as np
1010
import matplotlib.pyplot as plt
1111
from sklearn.metrics import roc_curve, auc
12+
from utils import bzPickle, bzUnpickle, get_class_attributes, create_data
13+
import warnings
14+
warnings.filterwarnings('ignore')
1215

1316
def nameonly(x):
1417
return x.split('\t')[1]
@@ -29,9 +32,8 @@ def loaddict(filename,converter=str):
2932
classnames = loadstr('classes.txt',nameonly)
3033
numexamples = loaddict('numexamples.txt',int)
3134

32-
def evaluate(split,C):
35+
def evaluate(split,C, attributepattern):
3336
global test_classnames
34-
attributepattern = 'DAP/probabilities'
3537

3638
if split == 0:
3739
test_classnames=loadstr('testclasses.txt')
@@ -81,7 +83,7 @@ def evaluate(split,C):
8183
return confusion,np.asarray(prob),L
8284

8385

84-
def plot_confusion(confusion):
86+
def plot_confusion(confusion, clf):
8587
fig=plt.figure(figsize=(10,9))
8688
plt.imshow(confusion,interpolation='nearest',origin='upper')
8789
plt.clim(0,1)
@@ -96,10 +98,10 @@ def plot_confusion(confusion):
9698
fig.subplots_adjust(bottom=0.22)
9799
plt.gray()
98100
plt.colorbar(shrink=0.79)
99-
plt.savefig('results/AwA-ROC-confusion-DAP.pdf')
101+
plt.savefig('results/AwA-ROC-confusion-DAP-%s.pdf' %clf)
100102
return
101103

102-
def plot_roc(P,GT):
104+
def plot_roc(P,GT, clf):
103105
AUC=[]
104106
CURVE=[]
105107
for i,c in enumerate(test_classnames):
@@ -109,6 +111,10 @@ def plot_roc(P,GT):
109111
print ("AUC: %s %5.3f" % (c,roc_auc))
110112
AUC.append(roc_auc)
111113
CURVE.append(np.array([fp,tp]))
114+
115+
print ("----------------------------------")
116+
print ("Mean classAUC %g" % (np.mean(AUC)*100))
117+
112118
order = np.argsort(AUC)[::-1]
113119
styles=['-','-','-','-','-','-','-','--','--','--']
114120
plt.figure(figsize=(9,5))
@@ -121,23 +127,61 @@ def plot_roc(P,GT):
121127
plt.yticks([0.0,0.2,0.4,0.6,0.8,1.0], [r'$0$', r'$0.2$',r'$0.4$',r'$0.6$',r'$0.8$',r'$1.0$'],fontsize=18)
122128
plt.xlabel('false negative rate',fontsize=18)
123129
plt.ylabel('true positive rate',fontsize=18)
124-
plt.savefig('results/AwA-ROC-DAP.pdf')
130+
plt.savefig('results/AwA-ROC-DAP-%s.pdf' %clf)
131+
132+
133+
def plot_attAUC(GT, attributepattern, clf):
134+
AUC=[]
135+
P = np.loadtxt(attributepattern)
136+
137+
# Loading ground truth
138+
test_index = bzUnpickle('./CreatedData/test_features_index.txt')
139+
test_attributes = get_class_attributes('./', name='test')
140+
_, y_true = create_data('./CreatedData/test_featuresVGG19.pic.bz2',test_index, test_attributes)
141+
142+
for i in range(y_true.shape[1]):
143+
fp, tp, _ = roc_curve(y_true[:,i], P[:,i])
144+
roc_auc = auc(fp, tp)
145+
AUC.append(roc_auc)
146+
print ("Mean attrAUC %g" % (np.nanmean(AUC)*100) )
147+
148+
xs = np.arange(y_true.shape[1])
149+
width = 2
150+
plt.figure(figsize=(9,5))
151+
plt.bar(xs, AUC, width, align='center')
152+
plt.xticks(xs) #Replace default x-ticks with xs, then replace xs with labels
153+
plt.yticks(AUC)
154+
plt.ylabel('Percent AUC',fontsize=18)
155+
plt.savefig('results/AwA-AttAUC-DAP-%s.pdf' %clf)
156+
125157

126158
def main():
159+
list_clf = ['SVM', 'NN']
160+
try:
161+
clf = str(sys.argv[1])
162+
except IndexError:
163+
clf = 'SVM'
164+
165+
if clf not in list_clf:
166+
print ("Non valid choice of classifier (SVM, NN)")
167+
raise SystemExit
168+
127169
try:
128-
split = int(sys.argv[1])
170+
split = int(sys.argv[2])
129171
except IndexError:
130172
split = 0
131173

132174
try:
133-
C = float(sys.argv[2])
175+
C = float(sys.argv[3])
134176
except IndexError:
135177
C = 10.
136178

137-
confusion,prob,L = evaluate(split,C)
179+
attributepattern = 'DAP/probabilities_' + clf
180+
confusion,prob,L = evaluate(split,C, attributepattern)
181+
plot_confusion(confusion, clf)
182+
plot_roc(prob,L, clf)
183+
plot_attAUC(L, attributepattern, clf)
138184
print ("Mean class accuracy %g" % np.mean(np.diag(confusion)*100))
139-
plot_confusion(confusion)
140-
plot_roc(prob,L)
141185

142186
if __name__ == '__main__':
143187
main()

DirectAttributePrediction.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from utils import bzPickle, bzUnpickle, get_class_attributes, create_data
55
from sklearn.model_selection import train_test_split
66
from SVMClassifier import SVMClassifier
7+
from NeuralNetworkClassifier import NeuralNetworkClassifier2
78

89

910
def DirectAttributePrediction(classifier='SVM',):
@@ -34,30 +35,47 @@ def DirectAttributePrediction(classifier='SVM',):
3435
print ('X_test to dense...')
3536
X_test = X_test.toarray()
3637

37-
# Training svm
38-
platt_params = []
39-
for i in range(N_ATTRIBUTES):
40-
print ('--------- Attribute %d/%d ---------' % (i+1,N_ATTRIBUTES))
41-
t0 = time()
38+
# CHOOSING SVM
39+
if classifier == 'SVM':
40+
platt_params = []
41+
for i in range(N_ATTRIBUTES):
42+
print ('--------- Attribute %d/%d ---------' % (i+1,N_ATTRIBUTES))
43+
t0 = time()
4244

43-
# Choose classifier
44-
if classifier == 'SVM':
45+
# SVM classifier
4546
clf = SVMClassifier()
4647

47-
# Training
48-
clf.fit(Xplat_train, yplat_train[:,i])
49-
print ('Fitted classifier in: %fs' % (time() - t0))
50-
clf.set_platt_params(Xplat_val, yplat_val[:,i])
51-
52-
# Predicting
53-
print ('Predicting for attribute %d...' % (i+1))
54-
y_pred[:,i] = clf.predict(X_test)
55-
y_proba[:,i] = clf.predict_proba(X_test)
56-
57-
print ('Saving files...')
58-
np.savetxt('./DAP/platt_params', platt_params)
59-
np.savetxt('./DAP/prediction', y_pred)
60-
np.savetxt('./DAP/probabilities', y_proba)
48+
# Training
49+
clf.fit(Xplat_train, yplat_train[:,i])
50+
print ('Fitted classifier in: %fs' % (time() - t0))
51+
clf.set_platt_params(Xplat_val, yplat_val[:,i])
52+
53+
# Predicting
54+
print ('Predicting for attribute %d...' % (i+1))
55+
y_pred[:,i] = clf.predict(X_test)
56+
y_proba[:,i] = clf.predict_proba(X_test)
57+
58+
print ('Saving files...')
59+
np.savetxt('./DAP/platt_params_SVM', platt_params)
60+
np.savetxt('./DAP/prediction_SVM', y_pred)
61+
np.savetxt('./DAP/probabilities_SVM', y_proba)
62+
63+
64+
# CHOOSING NEURAL NETWORK
65+
if classifier == 'NN':
66+
clf = NeuralNetworkClassifier2(dim_features=X_train.shape[1], nb_attributes=N_ATTRIBUTES)
67+
68+
print ('Fitting Neural Network...')
69+
clf.fit(X_train, y_train)
70+
71+
print ('Predicting attributes...')
72+
y_pred = np.array(clf.predict(X_test))
73+
y_pred = y_pred.reshape((y_pred.shape[0], y_pred.shape[1])).T
74+
y_proba = y_pred
75+
76+
print ('Saving files...')
77+
np.savetxt('./DAP/prediction_NN', y_pred)
78+
np.savetxt('./DAP/probabilities_NN', y_proba)
6179

6280

6381
def main():

0 commit comments

Comments
 (0)