
A First Look at Classification Algorithms (2): The ID3 Decision Tree Algorithm

Example: a classification task, play or not? (yes / no)

Goal: build a decision tree from the training set S, so that a sample of unknown class can then be classified by running it through the tree.

Problem: how do we build the decision tree, and with which node (attribute) do we start? This is the attribute-selection problem.

Methods: ID3 (information gain) and C4.5 (gain ratio); both criteria measure how well a given attribute separates the training examples.

1. To understand information gain, first understand entropy. Entropy measures the homogeneity, i.e. the purity, of a sample set: the larger the entropy, the greater the uncertainty and the more mixed the information.

Training set S: the 14 weather examples used throughout this post (9 labeled "Yes" and 5 labeled "No").

The entropy of S with respect to the target classification is:

Entropy(p1, p2, ..., pn) = -p1*log2(p1) - p2*log2(p2) - ... - pn*log2(pn)
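For example, with the 14-example table above, which contains 9 "Yes" and 5 "No" (the same counts the code in section 8.1 below produces), the entropy of S works out to about:

Entropy(S) = -9/14*log2(9/14) - 5/14*log2(5/14) ≈ 0.940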

2. Compute the entropy of the data set after it has been partitioned by attribute A (outlook).
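As a concrete illustration, assuming the standard 14-example weather table (outlook = sunny covering 2 "Yes" / 3 "No", overcast 4 / 0, rainy 3 / 2; these per-branch counts are not given explicitly in the post), the entropy after the outlook split is the weighted sum of the branch entropies:

Entropy_outlook(S) = 5/14*0.971 + 4/14*0.0 + 5/14*0.971 ≈ 0.694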

3. Therefore the information gain GAIN(A) of partitioning S by attribute A, compared with the original S, is the resulting reduction in entropy: GAIN(A) = Entropy(S) - sum over the values v of A of (|Sv|/|S|)*Entropy(Sv).
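Using the worked numbers above (again under the standard-table assumption):

GAIN(outlook) = 0.940 - 0.694 ≈ 0.246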

4. Likewise compute the information gain of the remaining three attributes relative to the original S. The attribute with the largest gain is the one whose partition increases the certainty of S the most, i.e. reduces the entropy the most.
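For reference, the textbook values of the four gains on this data set (numbers from the well-known play-or-not example, not from the post's own figures) are roughly:

GAIN(outlook) ≈ 0.246, GAIN(humidity) ≈ 0.152, GAIN(windy) ≈ 0.048, GAIN(temperature) ≈ 0.029

so outlook has the largest gain.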

5. In the example above, outlook is therefore chosen as the root node of the decision tree. The same procedure is then repeated on each of outlook's three branch values, eventually yielding a complete decision tree.
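Assuming the standard weather table, the finished tree, written in the nested-dictionary form that the createTree code in section 8.3 below produces, would look roughly like this (the branch labels are illustrative):

{'outlook': {'overcast': 'Yes',
             'sunny': {'humidity': {'high': 'No', 'normal': 'Yes'}},
             'rainy': {'windy': {'false': 'Yes', 'true': 'No'}}}}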

6. Now consider a problem with information gain: suppose some attribute has a large number of distinct values, for example an ID column (add an ID column to the example above, numbered a through n). When splitting on it, every value becomes its own branch, as follows:

The entropy of S after being partitioned by ID is 0, because every branch contains a single example whose class is already determined. The information gain of splitting on ID is therefore the maximum possible, so ID would always be chosen as the split node. In other words, information gain is biased towards attributes with many values, which is clearly wrong (it leads to overfitting).
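In formula form, with the 14-example table: each ID value covers exactly one example, so every subset has zero entropy, and

GAIN(ID) = Entropy(S) - 14*(1/14)*0 = Entropy(S) ≈ 0.940

which is the largest gain any attribute can achieve.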

7. The C4.5 algorithm therefore uses the gain ratio (GainRatio) as the criterion for choosing split nodes.

For attribute A = outlook, whose three values cover 5, 4 and 5 of the 14 examples, the intrinsic information is: IntrinsicInfo(A) = -5/14*log2(5/14) - 4/14*log2(4/14) - 5/14*log2(5/14)
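The gain ratio divides the information gain by this intrinsic (split) information of the attribute itself:

GainRatio(A) = GAIN(A) / IntrinsicInfo(A)

Assuming the standard table, IntrinsicInfo(outlook) ≈ 1.577 and GainRatio(outlook) ≈ 0.246 / 1.577 ≈ 0.156, while an ID-style attribute has IntrinsicInfo(ID) = log2(14) ≈ 3.807, which sharply discounts its large gain.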

The gain ratio compensates for this weakness of information gain, but it can over-compensate and end up favouring attributes whose intrinsic information is very small. A common remedy is a two-step rule: first consider only those attributes whose information gain is above the average, and then compare the remaining candidates by gain ratio, as the sketch below illustrates.
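A minimal sketch of this two-step rule, reusing the calcShannonEnt and splitDataSet functions defined in section 8 below (chooseBestFeatureC45 is a hypothetical name, not part of the original post):

from math import log

def chooseBestFeatureC45(dataSet):
    # C4.5-style selection: keep only attributes with above-average information
    # gain, then pick the one with the best gain ratio among them
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    gains, splitInfos = [], []
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy, splitInfo = 0.0, 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
            splitInfo -= prob * log(prob, 2)        # intrinsic (split) information
        gains.append(baseEntropy - newEntropy)
        splitInfos.append(splitInfo)
    avgGain = sum(gains) / numFeatures
    bestFeature, bestRatio = -1, 0.0
    for i in range(numFeatures):
        if gains[i] < avgGain or splitInfos[i] == 0.0:  # skip weak or single-valued attributes
            continue
        gainRatio = gains[i] / splitInfos[i]
        if gainRatio > bestRatio:
            bestFeature, bestRatio = i, gainRatio
    return bestFeature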

8. Python implementation

8.1 Function to compute the entropy

from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for exam in dataSet:                   # count the occurrences of each class label
        currentLabel = exam[-1]            # the class label ("Yes" or "No") is the last column
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    # labelCounts is now e.g. {'Yes': 9, 'No': 5}
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # log base 2
    return shannonEnt
Entropy calculation function
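As a quick sanity check (a hypothetical toy call, not taken from the original post), three examples with two "Yes" and one "No" should give about 0.918 bits:

toy = [['sunny', 'Yes'], ['rainy', 'Yes'], ['sunny', 'No']]
print(calcShannonEnt(toy))   # -2/3*log2(2/3) - 1/3*log2(1/3) ≈ 0.918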

8.2 Choose the attribute with the largest information gain to split the data set

def splitDataSet(dataSet, axis, value):
    # return the subset of dataSet whose feature at position 'axis' equals 'value',
    # with that feature column removed
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]            # chop out the axis used for splitting
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1                  # the last column holds the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):                       # iterate over all the features
        featList = [example[i] for example in dataSet] # all values of this feature
        uniqueVals = set(featList)                     # unique values of this feature
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy            # info gain = reduction in entropy
        if infoGain > bestInfoGain:                    # keep the best gain seen so far
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature                                 # index of the best feature

# driver: build the data set, compute its entropy, and pick the best split feature
mydata, labels = createDataSet()
calcShannonEnt(mydata)
chooseBestFeatureToSplit(mydata)
chooseBestFeatureToSplit(dataSet)
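The driver lines above call createDataSet(), which the post does not show. A hypothetical stand-in, assuming the classic 14-example weather table (the individual rows are an assumption; only the 9/5 class split and the 5/4/5 outlook split are confirmed by the post):

def createDataSet():
    # hypothetical reconstruction of the training table:
    # outlook, temperature, humidity, windy -> play ("Yes"/"No"); 9 "Yes" and 5 "No"
    dataSet = [
        ['sunny',    'hot',  'high',   'false', 'No'],
        ['sunny',    'hot',  'high',   'true',  'No'],
        ['overcast', 'hot',  'high',   'false', 'Yes'],
        ['rainy',    'mild', 'high',   'false', 'Yes'],
        ['rainy',    'cool', 'normal', 'false', 'Yes'],
        ['rainy',    'cool', 'normal', 'true',  'No'],
        ['overcast', 'cool', 'normal', 'true',  'Yes'],
        ['sunny',    'mild', 'high',   'false', 'No'],
        ['sunny',    'cool', 'normal', 'false', 'Yes'],
        ['rainy',    'mild', 'normal', 'false', 'Yes'],
        ['sunny',    'mild', 'normal', 'true',  'Yes'],
        ['overcast', 'mild', 'high',   'true',  'Yes'],
        ['overcast', 'hot',  'normal', 'false', 'Yes'],
        ['rainy',    'mild', 'high',   'true',  'No'],
    ]
    labels = ['outlook', 'temperature', 'humidity', 'windy']
    return dataSet, labels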

8.3 Building and plotting the decision tree

import treePlotter          # external helper module that draws the tree (not included in this post)

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]                 # stop splitting when all classes are equal
    if len(dataSet[0]) == 1:                # stop splitting when no more features are left
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]               # copy labels so the recursion doesn't mutate the caller's list
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

tree = createTree(mydata, labels)
treePlotter.createPlot(tree)
plot
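createTree also calls majorityCnt(), which is not shown in the post. A minimal sketch of the usual majority-vote helper (an assumption, not the author's original code):

import operator

def majorityCnt(classList):
    # return the class label that occurs most often in classList
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]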

 

 

 

 
