-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcriterion.py
65 lines (52 loc) · 1.91 KB
/
criterion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 8 21:25:01 2018
@author: Arpit
"""
import logger as lg
from statsFuncs import get_chi_square
"""
criterion object which finds the
best feature to split data on
"""
class Criterion:
    """
    Criterion object which finds the best feature to split data on.

    The best feature is selected eagerly in ``__init__`` by information gain.
    ``split`` then decides, via ``get_chi_square``, whether that split should
    actually be applied.

    Attributes:
        data: the dataset being split. It must expose ``getFeatCnt``,
            ``split``, ``getImpurity``, ``getRowCnt`` and ``getDist``
            (the methods called below).
        bestFeature: index of the feature with the highest information gain,
            or ``None`` if the data has no features.
        bestSplits: the value returned by ``data.split(bestFeature)``, or
            ``None`` if the data has no features.
        maxIG: information gain of ``bestFeature`` (``-inf`` when there are
            no features).
    """

    # Gains within this distance of zero are treated as zero. Parent-minus-
    # children impurity arithmetic rarely produces exactly 0.0 in floating point.
    ZERO_GAIN_TOL = 1e-12

    def __init__(self, data):
        self.data = data
        self.selectBestCriterion()

    def selectBestCriterion(self):
        """Select the feature whose split yields the highest information gain.

        Ties go to the lowest feature index, because only a strictly larger
        gain replaces the current best. Sets ``bestSplits``, ``bestFeature``
        and ``maxIG``.
        """
        lg.main.debug("Selecting best criterion for split!")
        maxInfoGain = float("-inf")
        # Defaults so these attributes exist even when there are no features.
        self.bestSplits = None
        self.bestFeature = None
        featsCnt = self.data.getFeatCnt()
        lg.main.debug("Trying %d features", featsCnt)
        for i in range(featsCnt):
            # NOTE(review): `splits` is iterated in getInfoGain and again later
            # in split(). This assumes data.split returns a re-iterable
            # container (e.g. a list), not a one-shot generator. Confirm.
            splits = self.data.split(i)  # split data on the i-th feature
            infoGain = self.getInfoGain(splits)
            lg.main.debug("Info gain %s for feature %d", infoGain, i)
            if infoGain > maxInfoGain:
                self.bestSplits = splits
                self.bestFeature = i
                maxInfoGain = infoGain
        self.maxIG = maxInfoGain
        # %s rather than %d: bestFeature is None when there are no features.
        lg.main.debug("Best feature for splitting is %s with maxIG %f\n", self.bestFeature, self.maxIG)

    def getInfoGain(self, splits):
        """Return the information gain of partitioning ``self.data`` into ``splits``.

        The gain is the parent impurity minus the row-weighted impurity of
        each child::

            IG = I(parent) - sum(rows(child) / rows(parent) * I(child))

        Args:
            splits: iterable of ``(key, childData)`` pairs; only ``childData``
                is used.

        Returns:
            The information gain. An empty parent (zero rows) yields ``0.0``
            instead of raising ``ZeroDivisionError``.
        """
        lg.main.debug("Getting information gain!")
        totalRows = self.data.getRowCnt()
        if totalRows == 0:
            # Nothing to partition, so no gain is possible.
            return 0.0
        result = self.data.getImpurity()
        for _, data in splits:
            result -= (data.getRowCnt() / totalRows) * data.getImpurity()
        return result

    def split(self, alpha):
        """Decide whether the best split found at construction should be applied.

        Args:
            alpha: significance level forwarded to ``get_chi_square``.

        Returns:
            ``self.bestSplits`` if splitting should continue. Otherwise
            ``False``, including when there was no feature to split on.
        """
        if self.bestSplits is None:
            return False
        splitsDataDist = [x.getDist() for _, x in self.bestSplits]
        # The chi-square test result is used only as a truth value: should we
        # split further?
        shouldSplit = get_chi_square(self.data.getDist(), splitsDataDist, alpha)
        # NOTE(review): presumably the test always passes at alpha == 1, so a
        # split with (numerically) zero information gain is rejected
        # explicitly here. Confirm against get_chi_square.
        if shouldSplit and alpha == 1 and abs(self.maxIG) <= self.ZERO_GAIN_TOL:
            shouldSplit = False
        if not shouldSplit:
            return False
        return self.bestSplits