-
Notifications
You must be signed in to change notification settings - Fork 0
/
decision_tree.py
210 lines (178 loc) · 6.34 KB
/
decision_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# -*-coding:utf-8-*-
from math import log
import operator
def createDataset():
    """Return the toy 'is it a fish?' data set and its feature names.

    Each row is [no surfacing, flippers, label] where the label is
    'yes' (fish) or 'no' (not a fish).
    """
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    # 'no surfacing': survives without surfacing; 'flippers': has flippers.
    featureNames = ['no surfacing', 'flippers']
    return samples, featureNames
# Build the sample data set once at import time; `myDat` and `labels`
# are reused by the script statements below.
myDat, labels = createDataset()
# myDat[0][-1]='maybe'
# print(myDat)
# Compute the Shannon entropy of a data set.
def calcShannonEnt(dataSet):
    """Return the Shannon entropy of dataSet.

    The last element of each row is taken as its class label; entropy is
    -sum(p * log2(p)) over the label distribution.
    """
    numEntries = len(dataSet)
    # Tally how many rows carry each class label.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        # dict.get folds the "key not seen yet" case into one line.
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        # p(x_i): fraction of rows with this label.
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
# res=calcShannonEnt(myDat)
# print(res)
# dataSet: data set to split; axis: index of the feature column to split on;
# value: feature value the returned rows must match.
def splitDataSet(dataSet, axis, value):
    """Return the rows of dataSet whose feature at `axis` equals `value`,
    with that feature column removed from each returned row.

    Rows are rebuilt via slice + extend, so the input rows are not mutated.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Drop the `axis` column: keep everything before and after it.
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
# res1=splitDataSet(myDat,0,1)
# res2=splitDataSet(myDat,0,0)
# Choose the best feature to split the data set on (ID3 criterion).
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature whose split yields the largest
    information gain, or -1 if no split improves on the base entropy.
    """
    # Number of feature columns (the last column is the class label).
    numFeatures = len(dataSet[0]) - 1
    # Entropy of the unsplit data set.
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Distinct values taken by feature i.
        uniqueVals = set(example[i] for example in dataSet)
        # Expected entropy after splitting on feature i: weighted sum of
        # the entropies of each value's subset.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
# data=chooseBestFeatureToSplit(myDat)
# print(data)
def majorityCnt(classList):
    """Return the class label that occurs most often in classList."""
    tally = {}
    for label in classList:
        # get() folds the "first occurrence" case into one line.
        tally[label] = tally.get(label, 0) + 1
    # Sort (label, count) pairs by descending count; ties keep insertion order.
    ranked = sorted(tally.items(), key=operator.itemgetter(1), reverse=True)
    return ranked[0][0]
# Build a decision tree recursively (ID3).
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    dataSet: rows whose last element is the class label.
    labels:  human-readable names for the feature columns.
    Returns a class label (leaf) or a nested dict of the form
    {featureName: {featureValue: subtree, ...}}.

    The caller's `labels` list is no longer mutated (the original
    `del labels[bestFeat]` altered the argument in place).
    """
    classList = [example[-1] for example in dataSet]
    # All rows share one class: this branch is a pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No feature columns left: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Label list for the subtrees, minus the feature just consumed
    # (built as a copy so the caller's list survives).
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    featValues = set(example[bestFeat] for example in dataSet)
    for value in featValues:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels[:])
    return myTree
# Build and print the full decision tree for the toy data set.
res = createTree(myDat, labels)
print(res)
import matplotlib.pyplot as plt
# Load a CJK font so the Chinese node labels render correctly.
from matplotlib.font_manager import FontProperties
# NOTE(review): hard-coded Windows font path — fails on other platforms.
font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)
# annotate() styles: box shapes for decision/leaf nodes and the arrow style.
decisionNode = dict(boxstyle='sawtooth', fc="0.8")
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')
# Draw one annotated tree node with an arrow from its parent.
# nodeTxt: label text; centerPt: node box position; parentPt: arrow origin;
# nodeType: decisionNode or leafNode box style dict.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # Draws into the axes stashed on createPlot (createPlot.ax1) by createPlot().
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
                            va="center", bbox=nodeType, arrowprops=arrow_args, fontproperties=font)
def createPlot():
    # Create a fresh white figure with one frameless axes, and stash the
    # axes on the function object so plotNode() can draw into it.
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    # Demo: one decision node and one leaf node (labels are Chinese for
    # "decision node" / "leaf node").
    plotNode("决策节点", (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode("叶节点", (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
# Render the demo plot at import time (opens a window and blocks).
createPlot()
def getNumLeafs(myTree):
    """Count the leaf (non-dict) nodes of a tree built by createTree.

    Fixes two defects in the original: dict.keys() is a view and cannot
    be indexed in Python 3, and the commented-out `else:` made the
    unconditional `numleafs += 1` count internal nodes as leaves too.
    """
    numLeafs = 0
    # Root key = the feature name tested at this node.
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            # Subtree: recurse into it.
            numLeafs += getNumLeafs(secondDict[key])
        else:
            # Non-dict value = class label = one leaf.
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Return the depth (number of decision levels) of a createTree tree.

    Fixes the Python 3 incompatibility in the original: dict.keys() is a
    view and cannot be indexed, so the root key is taken via next(iter()).
    """
    maxDepth = 0
    # Root key = the feature name tested at this node.
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
# Exercise the tree-metric helpers on the tree built above
# (return values are discarded).
getNumLeafs(res)
getTreeDepth(res)