首页 > 代码库 > ID3决策树算法实现(Python版)
ID3决策树算法实现(Python版)
# -*- coding:utf-8 -*-
# ID3 decision tree supporting both discrete (str) and continuous (int/float)
# features, plus matplotlib helpers to draw the resulting nested-dict tree.

from numpy import *
import numpy as np
import pandas as pd
# Must follow `from numpy import *` so that `log` is math.log, which accepts
# a base argument (used as log(p, 2) below).
from math import log
import operator


def _is_numeric(value):
    """Return True if *value* is treated as a continuous feature value.

    Only values whose type is exactly ``int`` or ``float`` count (so ``bool``
    and strings are discrete), matching the type-name checks used throughout.
    """
    return type(value).__name__ in ('float', 'int')


def calcShannonEnt(dataSet):
    """Return the base-2 Shannon entropy of the class labels in *dataSet*.

    Each sample is a list whose last element is its class label.
    An empty data set yields 0.0.
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def splitDataSet(dataSet, axis, value):
    """Discrete split: return the samples whose column *axis* equals *value*.

    The returned samples are new lists with column *axis* removed.
    """
    return [featVec[:axis] + featVec[axis + 1:]
            for featVec in dataSet if featVec[axis] == value]


def splitContinuousDataSet(dataSet, axis, value, direction):
    """Continuous split on threshold *value*, removing column *axis*.

    direction == 0 keeps samples whose feature is > value; any other
    direction keeps samples whose feature is <= value.
    """
    return [featVec[:axis] + featVec[axis + 1:]
            for featVec in dataSet
            if (featVec[axis] > value if direction == 0
                else featVec[axis] <= value)]


def chooseBestFeatureToSplit(dataSet, labels):
    """Return the index of the feature with the largest information gain.

    Numeric features are scored at their best threshold, chosen among the
    midpoints of adjacent sorted values.

    Side effects when the winner is numeric: its column in *dataSet* is
    binarised in place (1 if value <= threshold else 0), and
    ``labels[index]`` becomes ``'<label><=<threshold>'``.

    Returns -1 when no feature gives a strictly positive gain. In that case
    neither *dataSet* nor *labels* is modified.
    """
    numFeatures = len(dataSet[0]) - 1
    numSamples = float(len(dataSet))
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    bestSplitByFeature = {}
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        if _is_numeric(featList[0]):
            sortedFeats = sorted(featList)
            # n-1 candidate thresholds: midpoints of adjacent sorted values.
            splitList = [(lo + hi) / 2.0
                         for lo, hi in zip(sortedFeats, sortedFeats[1:])]
            if not splitList:
                # A single sample offers no threshold to evaluate.
                continue
            bestSplitEntropy = float('inf')
            bestSplitValue = splitList[0]
            for value in splitList:
                greater = splitContinuousDataSet(dataSet, i, value, 0)
                lessEq = splitContinuousDataSet(dataSet, i, value, 1)
                newEntropy = (len(greater) / numSamples * calcShannonEnt(greater)
                              + len(lessEq) / numSamples * calcShannonEnt(lessEq))
                if newEntropy < bestSplitEntropy:
                    bestSplitEntropy = newEntropy
                    bestSplitValue = value
            bestSplitByFeature[i] = bestSplitValue
            infoGain = baseEntropy - bestSplitEntropy
        else:
            newEntropy = 0.0
            for value in set(featList):
                subDataSet = splitDataSet(dataSet, i, value)
                newEntropy += len(subDataSet) / numSamples * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    if bestFeature == -1:
        # Without this guard, column -1 (the class label) would be inspected
        # and possibly binarised.
        return bestFeature
    if _is_numeric(dataSet[0][bestFeature]):
        bestSplitValue = bestSplitByFeature[bestFeature]
        labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
        for row in dataSet:
            row[bestFeature] = 1 if row[bestFeature] <= bestSplitValue else 0
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent label in *classList*; ties go to the first seen."""
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # sorted() is stable even with reverse=True, so equal counts keep
    # first-seen order.
    ranked = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return ranked[0][0]


def createTree(dataSet, labels, data_full, labels_full):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: list of samples (lists); the last element is the class label.
        labels: names of dataSet's feature columns. Mutated: the chosen
            feature's name is removed, and it is renamed first if numeric.
        data_full: the complete data set. It is used to give every value a
            discrete feature can take its own branch. Values absent from the
            current subset get the subset's majority class.
        labels_full: names of data_full's feature columns.

    Returns:
        A class label (leaf), or ``{feature_label: {value: subtree}}``.
        Numeric features are split into branches 1 (<= threshold) and 0.
    """
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet, labels)
    if bestFeat == -1:
        # No feature separates the classes at all: make a majority leaf.
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Branch values in first-seen order (deterministic, unlike set order).
    uniqueVals = list(dict.fromkeys(example[bestFeat] for example in dataSet))
    isDiscrete = type(dataSet[0][bestFeat]).__name__ == 'str'
    missingVals = []
    if isDiscrete:
        fullCol = labels_full.index(labels[bestFeat])
        seen = set(uniqueVals)
        missingVals = [v for v in dict.fromkeys(example[fullCol] for example in data_full)
                       if v not in seen]
    del labels[bestFeat]
    # One subtree per value of bestFeat present in this subset.
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels, data_full, labels_full)
    if missingVals:
        majority = majorityCnt(classList)
        for value in missingVals:
            myTree[bestFeatLabel][value] = majority
    return myTree


decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def getNumLeafs(myTree):
    """Return the number of leaves (non-dict values) in a nested-dict tree."""
    secondDict = myTree[next(iter(myTree))]
    numLeafs = 0
    for child in secondDict.values():
        numLeafs += getNumLeafs(child) if isinstance(child, dict) else 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict tree (one split = 1)."""
    secondDict = myTree[next(iter(myTree))]
    maxDepth = 0
    for child in secondDict.values():
        thisDepth = 1 + getTreeDepth(child) if isinstance(child, dict) else 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw a boxed node at *centerPt*, with an arrow linking it to *parentPt*."""
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    """Write the edge label *txtString* midway between child and parent points."""
    xMid = (parentPt[0] + cntrPt[0]) / 2.0 - len(txtString) * 0.002
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw *myTree* below *parentPt*, labelling the incoming edge *nodeTxt*.

    Layout state lives in function attributes initialised by createPlot:
    totalW / totalD (leaf count / depth of the whole tree) and
    x0ff / y0ff (the current drawing cursor, in axes-fraction units).
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = next(iter(myTree))
    cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.y0ff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
            plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
    plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Draw the nested-dict decision tree *inTree* with matplotlib and show it."""
    # Imported lazily so the tree-building code does not require matplotlib.
    import matplotlib.pyplot as plt
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.x0ff = -0.5 / plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


if __name__ == '__main__':
    df = pd.read_csv('watermelon_4_3.csv')
    # The first CSV column is skipped (presumably a row id - confirm against
    # the file); the last column is the class label.
    data = df.values[:, 1:].tolist()
    # Copy each row: chooseBestFeatureToSplit binarises numeric columns in
    # place, and a shallow copy would let that leak into data_full.
    data_full = [row[:] for row in data]
    labels = df.columns.values[1:-1].tolist()
    labels_full = labels[:]
    myTree = createTree(data, labels, data_full, labels_full)
    print(myTree)
    createPlot(myTree)
最终结果如下:
{'texture': {'blur': 0, 'little_blur': {'touch': {'soft_stick': 1, 'hard_smooth': 0}}, 'distinct': {'density<=0.38149999999999995': {0: 1, 1: 0}}}}
得到的决策树如下:
参考资料:
《机器学习实战》
《机器学习》周志华著
ID3决策树算法实现(Python版)
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。