-
Notifications
You must be signed in to change notification settings - Fork 11.5k
/
DecisionTree.py
executable file
·390 lines (338 loc) · 15.4 KB
/
DecisionTree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
#!/usr/bin/python
# coding:utf-8
'''
Created on Oct 12, 2010
Update on 2017-05-18
Decision Tree Source Code for Machine Learning in Action Ch. 3
Author: Peter Harrington/片刻
GitHub: https://github.com/apachecn/AiLearning
'''
from __future__ import print_function
print(__doc__)
import operator
from math import log
import decisionTreePlot as dtPlot
from collections import Counter
def createDataSet():
"""DateSet 基础数据集
Args:
无需传入参数
Returns:
返回数据集和对应的label标签
"""
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
# dataSet = [['yes'],
# ['yes'],
# ['no'],
# ['no'],
# ['no']]
# labels 露出水面 脚蹼
labels = ['no surfacing', 'flippers']
# change to discrete values
return dataSet, labels
def calcShannonEnt(dataSet):
"""calcShannonEnt(calculate Shannon entropy 计算给定数据集的香农熵)
Args:
dataSet 数据集
Returns:
返回 每一组feature下的某个分类下,香农熵的信息期望
"""
# -----------计算香农熵的第一种实现方式start--------------------------------------------------------------------------------
# 求list的长度,表示计算参与训练的数据量
numEntries = len(dataSet)
# 下面输出我们测试的数据集的一些信息
# 例如: <type 'list'> numEntries: 5 是下面的代码的输出
# print type(dataSet), 'numEntries: ', numEntries
# 计算分类标签label出现的次数
labelCounts = {}
# the the number of unique elements and their occurance
for featVec in dataSet:
# 将当前实例的标签存储,即每一行数据的最后一个数据代表的是标签
currentLabel = featVec[-1]
# 为所有可能的分类创建字典,如果当前的键值不存在,则扩展字典并将当前键值加入字典。每个键值都记录了当前类别出现的次数。
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
# print '-----', featVec, labelCounts
# 对于label标签的占比,求出label标签的香农熵
shannonEnt = 0.0
for key in labelCounts:
# 使用所有类标签的发生频率计算类别出现的概率。
prob = float(labelCounts[key])/numEntries
# log base 2
# 计算香农熵,以 2 为底求对数
shannonEnt -= prob * log(prob, 2)
# print '---', prob, prob * log(prob, 2), shannonEnt
# -----------计算香农熵的第一种实现方式end--------------------------------------------------------------------------------
# # -----------计算香农熵的第二种实现方式start--------------------------------------------------------------------------------
# # 统计标签出现的次数
# label_count = Counter(data[-1] for data in dataSet)
# # 计算概率
# probs = [p[1] / len(dataSet) for p in label_count.items()]
# # 计算香农熵
# shannonEnt = sum([-p * log(p, 2) for p in probs])
# # -----------计算香农熵的第二种实现方式end--------------------------------------------------------------------------------
return shannonEnt
def splitDataSet(dataSet, index, value):
"""splitDataSet(通过遍历dataSet数据集,求出index对应的colnum列的值为value的行)
就是依据index列进行分类,如果index列的数据等于 value的时候,就要将 index 划分到我们创建的新的数据集中
Args:
dataSet 数据集 待划分的数据集
index 表示每一行的index列 划分数据集的特征
value 表示index列对应的value值 需要返回的特征的值。
Returns:
index列为value的数据集【该数据集需要排除index列】
"""
# -----------切分数据集的第一种方式 start------------------------------------
retDataSet = []
for featVec in dataSet:
# index列为value的数据集【该数据集需要排除index列】
# 判断index列的值是否为value
if featVec[index] == value:
# chop out index used for splitting
# [:index]表示前index行,即若 index 为2,就是取 featVec 的前 index 行
reducedFeatVec = featVec[:index]
'''
请百度查询一下: extend和append的区别
list.append(object) 向列表中添加一个对象object
list.extend(sequence) 把一个序列seq的内容添加到列表中
1、使用append的时候,是将new_media看作一个对象,整体打包添加到music_media对象中。
2、使用extend的时候,是将new_media看作一个序列,将这个序列和music_media序列合并,并放在其后面。
result = []
result.extend([1,2,3])
print result
result.append([4,5,6])
print result
result.extend([7,8,9])
print result
结果:
[1, 2, 3]
[1, 2, 3, [4, 5, 6]]
[1, 2, 3, [4, 5, 6], 7, 8, 9]
'''
reducedFeatVec.extend(featVec[index+1:])
# [index+1:]表示从跳过 index 的 index+1行,取接下来的数据
# 收集结果值 index列为value的行【该行需要排除index列】
retDataSet.append(reducedFeatVec)
# -----------切分数据集的第一种方式 end------------------------------------
# # -----------切分数据集的第二种方式 start------------------------------------
# retDataSet = [data for data in dataSet for i, v in enumerate(data) if i == axis and v == value]
# # -----------切分数据集的第二种方式 end------------------------------------
return retDataSet
def chooseBestFeatureToSplit(dataSet):
"""chooseBestFeatureToSplit(选择最好的特征)
Args:
dataSet 数据集
Returns:
bestFeature 最优的特征列
"""
# -----------选择最优特征的第一种方式 start------------------------------------
# 求第一行有多少列的 Feature, 最后一列是label列嘛
numFeatures = len(dataSet[0]) - 1
# label的信息熵
baseEntropy = calcShannonEnt(dataSet)
# 最优的信息增益值, 和最优的Featurn编号
bestInfoGain, bestFeature = 0.0, -1
# iterate over all the features
for i in range(numFeatures):
# create a list of all the examples of this feature
# 获取每一个实例的第i+1个feature,组成list集合
featList = [example[i] for example in dataSet]
# get a set of unique values
# 获取剔重后的集合,使用set对list数据进行去重
uniqueVals = set(featList)
# 创建一个临时的信息熵
newEntropy = 0.0
# 遍历某一列的value集合,计算该列的信息熵
# 遍历当前特征中的所有唯一属性值,对每个唯一属性值划分一次数据集,计算数据集的新熵值,并对所有唯一特征值得到的熵求和。
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
# gain[信息增益]: 划分数据集前后的信息变化, 获取信息熵最大的值
# 信息增益是熵的减少或者是数据无序度的减少。最后,比较所有特征中的信息增益,返回最好特征划分的索引值。
infoGain = baseEntropy - newEntropy
print('infoGain=', infoGain, 'bestFeature=', i, baseEntropy, newEntropy)
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
# -----------选择最优特征的第一种方式 end------------------------------------
# # -----------选择最优特征的第二种方式 start------------------------------------
# # 计算初始香农熵
# base_entropy = calcShannonEnt(dataSet)
# best_info_gain = 0
# best_feature = -1
# # 遍历每一个特征
# for i in range(len(dataSet[0]) - 1):
# # 对当前特征进行统计
# feature_count = Counter([data[i] for data in dataSet])
# # 计算分割后的香农熵
# new_entropy = sum(feature[1] / float(len(dataSet)) * calcShannonEnt(splitDataSet(dataSet, i, feature[0])) \
# for feature in feature_count.items())
# # 更新值
# info_gain = base_entropy - new_entropy
# print('No. {0} feature info gain is {1:.3f}'.format(i, info_gain))
# if info_gain > best_info_gain:
# best_info_gain = info_gain
# best_feature = i
# return best_feature
# # -----------选择最优特征的第二种方式 end------------------------------------
def majorityCnt(classList):
"""majorityCnt(选择出现次数最多的一个结果)
Args:
classList label列的集合
Returns:
bestFeature 最优的特征列
"""
# -----------majorityCnt的第一种方式 start------------------------------------
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
# 倒叙排列classCount得到一个字典集合,然后取出第一个就是结果(yes/no),即出现次数最多的结果
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
# print 'sortedClassCount:', sortedClassCount
return sortedClassCount[0][0]
# -----------majorityCnt的第一种方式 end------------------------------------
# # -----------majorityCnt的第二种方式 start------------------------------------
# major_label = Counter(classList).most_common(1)[0]
# return major_label
# # -----------majorityCnt的第二种方式 end------------------------------------
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
# 如果数据集的最后一列的第一个值出现的次数=整个集合的数量,也就说只有一个类别,就只直接返回结果就行
# 第一个停止条件: 所有的类标签完全相同,则直接返回该类标签。
# count() 函数是统计括号中的值在list中出现的次数
if classList.count(classList[0]) == len(classList):
return classList[0]
# 如果数据集只有1列,那么最初出现label次数最多的一类,作为结果
# 第二个停止条件: 使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组。
if len(dataSet[0]) == 1:
return majorityCnt(classList)
# 选择最优的列,得到最优列对应的label含义
bestFeat = chooseBestFeatureToSplit(dataSet)
# 获取label的名称
bestFeatLabel = labels[bestFeat]
# 初始化myTree
myTree = {bestFeatLabel: {}}
# 注: labels列表是可变对象,在PYTHON函数中作为参数时传址引用,能够被全局修改
# 所以这行代码导致函数外的同名变量被删除了元素,造成例句无法执行,提示'no surfacing' is not in list
del(labels[bestFeat])
# 取出最优列,然后它的branch做分类
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
# 求出剩余的标签label
subLabels = labels[:]
# 遍历当前选择特征包含的所有属性值,在每个数据集划分上递归调用函数createTree()
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
# print 'myTree', value, myTree
return myTree
def classify(inputTree, featLabels, testVec):
"""classify(给输入的节点,进行分类)
Args:
inputTree 决策树模型
featLabels Feature标签对应的名称
testVec 测试输入的数据
Returns:
classLabel 分类的结果值,需要映射label才能知道名称
"""
# 获取tree的根节点对于的key值
firstStr = inputTree.keys()[0]
# 通过key得到根节点对应的value
secondDict = inputTree[firstStr]
# 判断根节点名称获取根节点在label中的先后顺序,这样就知道输入的testVec怎么开始对照树来做分类
featIndex = featLabels.index(firstStr)
# 测试数据,找到根节点对应的label位置,也就知道从输入的数据的第几位来开始分类
key = testVec[featIndex]
valueOfFeat = secondDict[key]
print('+++', firstStr, 'xxx', secondDict, '---', key, '>>>', valueOfFeat)
# 判断分枝是否结束: 判断valueOfFeat是否是dict类型
if isinstance(valueOfFeat, dict):
classLabel = classify(valueOfFeat, featLabels, testVec)
else:
classLabel = valueOfFeat
return classLabel
def storeTree(inputTree, filename):
import pickle
# -------------- 第一种方法 start --------------
fw = open(filename, 'wb')
pickle.dump(inputTree, fw)
fw.close()
# -------------- 第一种方法 end --------------
# -------------- 第二种方法 start --------------
with open(filename, 'wb') as fw:
pickle.dump(inputTree, fw)
# -------------- 第二种方法 start --------------
def grabTree(filename):
import pickle
fr = open(filename,'rb')
return pickle.load(fr)
def fishTest():
# 1.创建数据和结果标签
myDat, labels = createDataSet()
# print myDat, labels
# 计算label分类标签的香农熵
# calcShannonEnt(myDat)
# # 求第0列 为 1/0的列的数据集【排除第0列】
# print '1---', splitDataSet(myDat, 0, 1)
# print '0---', splitDataSet(myDat, 0, 0)
# # 计算最好的信息增益的列
# print chooseBestFeatureToSplit(myDat)
import copy
myTree = createTree(myDat, copy.deepcopy(labels))
print(myTree)
# [1, 1]表示要取的分支上的节点位置,对应的结果值
print(classify(myTree, labels, [1, 1]))
# 获得树的高度
print(get_tree_height(myTree))
# 画图可视化展现
dtPlot.createPlot(myTree)
def ContactLensesTest():
"""
Desc:
预测隐形眼镜的测试代码
Args:
none
Returns:
none
"""
# 加载隐形眼镜相关的 文本文件 数据
fr = open('data/3.DecisionTree/lenses.txt')
# 解析数据,获得 features 数据
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
# 得到数据的对应的 Labels
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
# 使用上面的创建决策树的代码,构造预测隐形眼镜的决策树
lensesTree = createTree(lenses, lensesLabels)
print(lensesTree)
# 画图可视化展现
dtPlot.createPlot(lensesTree)
def get_tree_height(tree):
"""
Desc:
递归获得决策树的高度
Args:
tree
Returns:
树高
"""
if not isinstance(tree, dict):
return 1
child_trees = tree.values()[0].values()
# 遍历子树, 获得子树的最大高度
max_height = 0
for child_tree in child_trees:
child_tree_height = get_tree_height(child_tree)
if child_tree_height > max_height:
max_height = child_tree_height
return max_height + 1
if __name__ == "__main__":
fishTest()
# ContactLensesTest()