-
Notifications
You must be signed in to change notification settings - Fork 0
/
nbclassify.py
67 lines (58 loc) · 2.14 KB
/
nbclassify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import sys
import timeit
class Classify:
    """Parameters of one class of the naive Bayes model, loaded from a model file.

    Attributes:
        name: class label exactly as written in the model file.
        totalNumOfWords: word count stored for this class in the model file.
        probWordGivenClass: maps each word to this class's per-word value from
            the model file. The classifier adds ``value * frequency`` to the
            class score, so these are presumably log-probabilities
            (TODO confirm against the model writer).
        probClass: this class's prior value from the model file (presumably a
            log-probability too; the score starts from it).
    """

    def __init__(self, name):
        self.name = name
        self.totalNumOfWords = 0
        # A fresh dict per instance so classes never share word tables.
        self.probWordGivenClass = {}
        self.probClass = 0
def _load_model(model_path, num_classes):
    """Parse a model file and return one populated ``Classify`` per class.

    Layout consumed, in order:
      * a header line; its second field is the vocabulary size, which the
        scorer does not need, so the line is only skipped;
      * ``num_classes`` lines of ``<name> <prior> <word count>``;
      * one column-heading line, skipped;
      * word lines ``<word> <value for class 1> ... <value for class N>``.
    Blank word lines are ignored.
    """
    classes = []
    with open(model_path, "r") as model:
        model.readline()  # header (vocabulary size) - unused for scoring
        for _ in range(num_classes):
            name, prior, count = model.readline().split()
            cls = Classify(name)
            cls.probClass = float(prior)
            cls.totalNumOfWords = int(count)
            classes.append(cls)
        model.readline()  # column heading
        for line in model:
            fields = line.split()
            if not fields:
                continue
            word = fields[0]
            # Column k (1-based) of a word line belongs to the k-th class read.
            for column, cls in enumerate(classes, start=1):
                cls.probWordGivenClass[word] = float(fields[column])
    return classes


def _predict(tokens, classes):
    """Return the name of the highest-scoring class for one document.

    ``tokens`` are ``word:count`` strings. A class score is its prior plus
    ``value * count`` for every word that the class's table knows. Words that
    are missing from a class's table add nothing to that class's score. On a
    tie the earlier class wins, because ``max`` keeps the first maximal item.
    """
    counts = []
    for token in tokens:
        # rsplit keeps any ':' that is part of the word itself.
        word, freq = token.rsplit(":", 1)
        counts.append((word, int(freq)))

    def score(cls):
        total = cls.probClass
        table = cls.probWordGivenClass
        for word, freq in counts:
            if word in table:
                total += table[word] * freq
        return total

    return max(classes, key=score).name


def classify(model_path=None, test_path=None, num_classes=2):
    """Classify each line of a test file with a trained naive Bayes model.

    Args:
        model_path: model file path; defaults to ``sys.argv[1]``.
        test_path: test file with one document per line, written as
            whitespace-separated ``word:count`` tokens; defaults to
            ``sys.argv[2]``.
        num_classes: number of class lines in the model file (default 2).

    Returns:
        The predicted class names, one per test line. Each name is also
        printed to stdout as soon as it is computed.

    Raises:
        ValueError: if ``num_classes`` < 1, or if the model or test data is
            malformed.
    """
    if model_path is None:
        model_path = sys.argv[1]
    if test_path is None:
        test_path = sys.argv[2]
    if num_classes < 1:
        raise ValueError("num_classes must be at least 1")

    classes = _load_model(model_path, num_classes)
    predictions = []
    with open(test_path, "r") as test_file:
        for line in test_file:
            label = _predict(line.split(), classes)
            print(label)
            predictions.append(label)
    return predictions
if __name__ == '__main__':
    # Script entry point: python nbclassify.py MODEL_FILE TEST_FILE
    # (classify reads both paths from sys.argv and prints one label per line).
    classify()