
Machine Learning in Action Notes -- Decision Trees

tree.py code

#encoding:utf-8
from math import log
import operator
import treePlotter as tp


def createDataSet():   # create a simple test data set
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    # change to discrete values
    return dataSet, labels


def calcShannonEnt(dataSet):   # compute the Shannon entropy of a data set
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

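# A quick sanity check on the toy data from createDataSet() (2 'yes', 3 'no'):
#   H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) = 0.971
# so calcShannonEnt(dataSet) should return roughly 0.9710.
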
# split the data set on a given feature
def splitDataSet(dataSet, axis, value):   # dataSet: data set; axis: index of the feature to split on; value: the feature value to keep
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)  # reducedFeatVec no longer contains the chosen feature; note the difference between append and extend
    return retDataSet

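# For example, splitting the toy data on feature 0 with value 1 keeps the
# three rows whose first column is 1 and strips that column:
#   splitDataSet(dataSet, 0, 1)  ->  [[1, 'yes'], [1, 'yes'], [0, 'no']]
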

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):   # i-th column
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)   # build the set of unique values of this feature
        newEntropy = 0.0
        for value in uniqueVals:   # compute the entropy of each branch and accumulate; a feature may take several values
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted entropy after the split
        infoGain = baseEntropy - newEntropy  # information gain
        if (infoGain > bestInfoGain):  # keep the best split found so far
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature  # index of the best feature to split on, an integer

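# Worked example on the toy data (base entropy 0.971):
#   feature 0: value 1 -> {yes, yes, no} (H = 0.918), value 0 -> {no, no} (H = 0)
#       weighted entropy = 3/5 * 0.918 = 0.551, gain = 0.971 - 0.551 = 0.420
#   feature 1: value 1 -> {yes, yes, no, no} (H = 1.0), value 0 -> {no} (H = 0)
#       weighted entropy = 4/5 * 1.0 = 0.800, gain = 0.171
# so chooseBestFeatureToSplit(dataSet) returns 0, i.e. 'no surfacing'.
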

def majorityCnt(classList):
    classCount = {}    # works like a map
    for vote in classList:   # count how often each class label appears
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)   # sort by count, descending
    return sortedClassCount[0][0]    # return the most frequent class label


def createTree(dataSet, labels):  # build the decision tree
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):    # stop splitting when all classes are equal
        return classList[0]
    if len(dataSet[0]) == 1:  # no features left: return the majority class
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # best split
    bestFeatLabel = labels[bestFeat]   # name of the best split feature
    myTree = {bestFeatLabel: {}}
    del (labels[bestFeat])   # remove that feature's label
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)   # all values this feature takes in the data
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree  # return the decision tree

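# On the toy data this returns exactly the nested dict that retrieveTree(0)
# in treePlotter.py hard-codes:
#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
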

def classify(inputTree, featLabels, testVec):
    firstStr = inputTree.keys()[0]   # root label of inputTree
    secondDict = inputTree[firstStr]   # e.g. {0: 'yes', 1: {subtree}}
    featIndex = featLabels.index(firstStr)   # translate the label string into an index
    key = testVec[featIndex]   # testVec's value for the current feature
    valueOfFeat = secondDict[key]  # the subtree (or leaf) under that value
    if isinstance(valueOfFeat, dict):  # if valueOfFeat is a dict, recurse
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        classLabel = valueOfFeat   # otherwise it is the answer
    return classLabel

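# Usage sketch, with myTree as above and featLabels = ['no surfacing', 'flippers']:
#   classify(myTree, featLabels, [1, 0])  ->  'no'
#   classify(myTree, featLabels, [1, 1])  ->  'yes'
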

def storeTree(inputTree, filename):   # persist a decision tree
    import pickle
    fw = open(filename, 'w')
    pickle.dump(inputTree, fw)
    fw.close()


def grabTree(filename):  # load a decision tree
    import pickle
    fr = open(filename)
    return pickle.load(fr)

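# Round-trip sketch ('myTree.txt' is just an example filename):
#   storeTree(myTree, 'myTree.txt')
#   grabTree('myTree.txt')  ->  the same nested dict that was stored
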
if __name__ == '__main__':
    # dataSet, labels = createDataSet()
    # print dataSet
    # print labels
    # shannonEnt = calcShannonEnt(dataSet)
    # print "Shannon entropy: %f" % (shannonEnt)
    # myMat = splitDataSet(dataSet, 0, 1)
    # print myMat
    # index = chooseBestFeatureToSplit(dataSet)
    # print index
    # mytree = createTree(dataSet, labels)
    # print "decision tree:"
    # print mytree
    # myTree = tp.retrieveTree(0)
    # print myTree
    # storeTree(myTree, 'myTree.txt')
    # myTree = grabTree('myTree.txt')
    # print myTree
    # print classify(myTree, labels, [1, 0])

    # predict contact lens type with a decision tree
    fr = open('lenses.txt')
    lenses = [line.strip().split('\t') for line in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(lenses, lensesLabels)
    print lensesTree
    tp.createPlot(lensesTree)
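The lenses demo above expects a tab-separated lenses.txt (shipped with the book's example code) in the working directory: each line holds the four feature values for age, prescript, astigmatic and tearRate, followed by the class label (no lenses, soft, or hard). A line looks roughly like this (illustrative only, quoted from memory of the book's data):

    young	myope	no	reduced	no lenses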

treePlotter.py code

#encoding:utf-8
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def getNumLeafs(myTree):  # count the leaf nodes of a tree
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict is an internal node, anything else is a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):   # compute the depth of a tree
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # internal node: recurse one level deeper
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

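# For the first test tree from retrieveTree(0) below,
# getNumLeafs(...) returns 3 and getTreeDepth(...) returns 2.
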

def plotNode(nodeTxt, centerPt, parentPt, nodeType):   # draw an annotated node with an arrow
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):  # fill in text midway between parent and child nodes
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

# plot the tree
def plotTree(myTree, parentPt, nodeTxt):  # the first key tells you what feature was split on
    numLeafs = getNumLeafs(myTree)  # this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]  # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # move down one level
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # internal node: recurse
            plotTree(secondDict[key], cntrPt, str(key))
        else:  # leaf node: plot it here
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # move back up after drawing this subtree

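# Layout bookkeeping: the x axis is split into totalW slots, one per leaf, and
# xOff starts half a slot to the left (-0.5/totalW) so each leaf lands at the
# center of its slot; yOff drops by 1/totalD per level and is restored on the
# way back up, so sibling subtrees share rows.
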

# if you do get a dictionary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


# def createPlot():
#    fig = plt.figure(1, facecolor='white')
#    fig.clf()
#    createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
#    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#    plt.show()


def retrieveTree(i):  # return one of two hard-coded test trees
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]

    # createPlot(thisTree)

if __name__ == '__main__':
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    arrow_args = dict(arrowstyle="<-")
    # createPlot()
    myTree = retrieveTree(0)
    createPlot(myTree)
    # print myTree
    # print getNumLeafs(myTree)
    # print getTreeDepth(myTree)

 
