-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathplotroc_svm.py
219 lines (195 loc) · 6.16 KB
/
plotroc_svm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
#!/usr/bin/env python
#This tool allow users to plot SVM-prob ROC curve from data
from svmutil import *
from sys import argv, platform
from os import path, popen
from random import randrange , seed
from operator import itemgetter
from time import sleep
#search path for gnuplot executable
#be careful on using windows LONG filename, surround it with double quotes.
#and leading 'r' to make it raw string, otherwise, repeat \\.
gnuplot_exe_list = [r'"C:\Program Files\gnuplot\pgnuplot.exe"', "/usr/bin/gnuplot","/usr/local/bin/gnuplot"]
def get_pos_deci(train_y, train_x, test_y, test_x, param):
model = svm_train(train_y, train_x, param)
#predict and grab decision value, assure deci>0 for label+,
#the positive descision value = val[0]*labels[0]
labels = model.get_labels()
py, evals, deci = svm_predict(test_y, test_x, model)
deci = [labels[0]*val[0] for val in deci]
return deci,model
#get_cv_deci(prob_y[], prob_x[], svm_parameter param, nr_fold)
#input raw attributes, labels, param, cv_fold in decision value building
#output list of decision value, remember to seed(0)
def get_cv_deci(prob_y, prob_x, param, nr_fold):
if nr_fold == 1 or nr_fold==0:
deci,model = get_pos_deci(prob_y, prob_x, prob_y, prob_x, param)
return deci
deci, model = [], []
prob_l = len(prob_y)
#random permutation by swapping i and j instance
for i in range(prob_l):
j = randrange(i,prob_l)
prob_x[i], prob_x[j] = prob_x[j], prob_x[i]
prob_y[i], prob_y[j] = prob_y[j], prob_y[i]
#cross training : folding
for i in range(nr_fold):
begin = i * prob_l // nr_fold
end = (i + 1) * prob_l // nr_fold
train_x = prob_x[:begin] + prob_x[end:]
train_y = prob_y[:begin] + prob_y[end:]
test_x = prob_x[begin:end]
test_y = prob_y[begin:end]
subdeci, submdel = get_pos_deci(train_y, train_x, test_y, test_x, param)
deci += subdeci
return deci
#a simple gnuplot object
class gnuplot:
def __init__(self, term='onscreen'):
# -persists leave plot window on screen after gnuplot terminates
if platform == 'win32':
cmdline = gnuplot_exe
self.__dict__['screen_term'] = 'windows'
else:
cmdline = gnuplot_exe + ' -persist'
self.__dict__['screen_term'] = 'x11'
self.__dict__['iface'] = popen(cmdline,'w')
self.set_term(term)
def set_term(self, term):
if term=='onscreen':
self.writeln("set term %s" % self.screen_term)
else:
#term must be either x.ps or x.png
if term.find('.ps')>0:
self.writeln("set term postscript eps color 22")
elif term.find('.png')>0:
self.writeln("set term png")
else:
print("You must set term to either *.ps or *.png")
raise SystemExit
self.output = term
def writeln(self,cmdline):
self.iface.write(cmdline + '\n')
def __setattr__(self, attr, val):
if type(val) == str:
self.writeln('set %s \"%s\"' % (attr, val))
else:
print("Unsupport format:", attr, val)
raise SystemExit
#terminate gnuplot
def __del__(self):
self.writeln("quit")
self.iface.flush()
self.iface.close()
def __repr__(self):
return "<gnuplot instance: output=%s>" % term
#data is a list of [x,y]
def plotline(self, data):
self.writeln("plot \"-\" notitle with lines linewidth 1")
for i in range(len(data)):
self.writeln("%f %f" % (data[i][0], data[i][1]))
sleep(0) #delay
self.writeln("e")
if platform=='win32':
sleep(3)
#processing argv and set some global variables
def proc_argv(argv = argv):
#print("Usage: %s " % argv[0])
#The command line : ./plotroc.py [-v cv_fold | -T testing_file] [libsvm-options] training_file
train_file = argv[-1]
test_file = None
fold = 5
options = []
i = 1
while i < len(argv)-1:
if argv[i] == '-T':
test_file = argv[i+1]
i += 1
elif argv[i] == '-v':
fold = int(argv[i+1])
i += 1
else :
options += [argv[i]]
i += 1
return ' '.join(options), fold, train_file, test_file
def plot_roc(deci, label, output, title):
#count of postive and negative labels
db = []
pos, neg = 0, 0
for i in range(len(label)):
if label[i]>0:
pos+=1
else:
neg+=1
db.append([deci[i], label[i]])
#sorting by decision value
db = sorted(db, key=itemgetter(0), reverse=True)
#calculate ROC
xy_arr = []
tp, fp = 0., 0. #assure float division
for i in range(len(db)):
if db[i][1]>0: #positive
tp+=1
else:
fp+=1
xy_arr.append([fp/neg,tp/pos])
#area under curve
aoc = 0.
prev_x = 0
for x,y in xy_arr:
if x != prev_x:
aoc += (x - prev_x) * y
prev_x = x
#begin gnuplot
if title == None:
title = output
#also write to file
g = gnuplot(output)
g.xlabel = "False Positive Rate"
g.ylabel = "True Positive Rate"
g.title = "ROC curve of %s (AUC = %.4f)" % (title,aoc)
g.plotline(xy_arr)
#display on screen
s = gnuplot('onscreen')
s.xlabel = "False Positive Rate"
s.ylabel = "True Positive Rate"
s.title = "ROC curve of %s (AUC = %.4f)" % (title,aoc)
s.plotline(xy_arr)
def check_gnuplot_exe():
global gnuplot_exe
gnuplot_exe = None
for g in gnuplot_exe_list:
if path.exists(g.replace('"','')):
gnuplot_exe=g
break
if gnuplot_exe == None:
print("You must add correct path of 'gnuplot' into gnuplot_exe_list")
raise SystemExit
def main():
check_gnuplot_exe()
if len(argv) <= 1:
print("Usage: %s [-v cv_fold | -T testing_file] [libsvm-options] training_file" % argv[0])
raise SystemExit
param,fold,train_file,test_file = proc_argv()
output_file = path.split(train_file)[1] + '-roc.png'
#read data
train_y, train_x = svm_read_problem(train_file)
if set(train_y) != set([1,-1]):
print("ROC is only applicable to binary classes with labels 1, -1")
raise SystemExit
#get decision value, with positive = label+
seed(0) #reset random seed
if test_file: #go with test_file
output_title = "%s on %s" % (path.split(test_file)[1], path.split(train_file)[1])
test_y, test_x = svm_read_problem(test_file)
if set(test_y) != set([1,-1]):
print("ROC is only applicable to binary classes with labels 1, -1")
raise SystemExit
deci,model = get_pos_deci(train_y, train_x, test_y, test_x, param)
plot_roc(deci, test_y, output_file, output_title)
else: #single file -> CV
output_title = path.split(train_file)[1]
deci = get_cv_deci(train_y, train_x, param, fold)
plot_roc(deci, train_y, output_file, output_title)
if __name__ == '__main__':
main()