首页 > 代码库 > 《机器学习实战》——决策树
《机器学习实战》——决策树
原理(ID3):
依次选定每个特征,计算信息增益(基本信息熵-当前信息熵),选择信息增益最大的一个作为最佳特征;
以该特征作为树的根节点,以该最佳特征的每一个值作为分支,建立子树;
重复上述过程,直到:1) 所有类别一致 2) 特征用尽
优点:
简单容易理解;
可处理有缺失值的特征、非数值型数据;
与训练数据拟合很好,复杂度小
缺点:
选择特征时候需要多次扫描与排序,适合常驻内存的小数据集;
容易过拟合;
改进:
ID3算法偏向取值较多的特征(宽且浅的树),适合离散型数据,难处理连续型数据。香农熵里面概率的统计根本上是在统计不同分类的概率。
C4.5用信息增益比筛选特征,分母为-∑p*log(p), p为: 特征A某个取值包含的数据行数/总数据行数,可处理缺失值和连续型数据
CART用Gini系数筛选特征。同时是二叉决策树,能处理离散特征、连续特征、分类问题、回归问题。
Gini系数(杂质度量法):Gini(A)=1-∑(Pk)2. Pk表示观测点中k类的概率,当Gini(A)=0时只有一类,当所有类同概率出现时最大Gini(A)=(C-1)C/2。
如果目标变量是标称的,并且是具有两个以上的类别,则CART可能考虑将目标类别合并成两个超类别(双化);
如果目标变量是连续的,则CART算法找出一组基于树的回归方程来预测目标变量。
事后剪枝,停止条件是:1) 样本个数小于预定阀值 2) 样本的Gini系数小于预定阀值 3)没有更多特征
代码:
1 #coding: utf-8 2 from __future__ import division 3 from numpy import * 4 5 6 class myClass(object): 7 def __init__(self): 8 group, labels = self.loadData() 9 myTree = self.createTree(group, labels) 10 print myTree 11 12 def loadData(self): 13 group = [[1,1,"yes"], 14 [1,1,"yes"], 15 [1,0,"no"], 16 [0,1,"no"], 17 [0,1,"no"], 18 ] 19 labels = ["no surfacing", "flippers"] # 表示特征的含义:海洋生物不露出水面是否可以生存,是否有脚蹼 20 return group, labels 21 22 def calShannonEnt(self, group): 23 numEntrories = len(group) 24 labelCount = dict() 25 for feaVec in group: 26 currentLabel = feaVec[-1] 27 labelCount[currentLabel] = labelCount.get(currentLabel, 0) + 1 28 shannonEnt = 0.0 # 用之前定义 29 for key in labelCount.keys(): 30 prob = labelCount[key]/numEntrories 31 shannonEnt -= prob * log(prob) 32 return shannonEnt 33 34 def splitDataSet(self, group, axis, value): # 特征所有取值的信息熵之和才有意义,这里只计算条件熵。 35 retDataSet = [] 36 for feaVec in group: 37 if feaVec[axis] == value: 38 retDataSet.append(feaVec[:axis] + feaVec[axis+1:]) 39 return retDataSet 40 41 def chooseBestFeature(self, group): 42 numFeatures = len(group[0]) - 1 43 baseEntrory = self.calShannonEnt(group) 44 bestInfoGain = 0.0; bestFeature = -1 45 for i in range(numFeatures): 46 uniqueVals = set([it[i] for it in group]) 47 newEntrory = 0.0 48 for value in uniqueVals: 49 subGroup = self.splitDataSet(group, i, value) 50 prob = len(subGroup)/len(group) 51 newEntrory += prob * self.calShannonEnt(subGroup) 52 infoGain = baseEntrory - newEntrory 53 if (infoGain > bestInfoGain): 54 bestInfoGain = infoGain; bestFeature = i 55 return bestFeature 56 57 def majority(self, classList): 58 classCountDict = {} 59 for vote in classList: 60 classCountDict[vote] = classCountDict.get(vote, 0) + 1 61 return sorted(classCountDict.items(), key = lambda x:x[1], reverse = True)[0][0] 62 63 def createTree(self, myGroup, labels): 64 # 1.两个终止条件; 2.建立根树(求得最佳根特征)并从labels中删除根标签(取得根标签); 3.根据根特征的每个值建立子树(取得唯一特征值) 65 classList = [it[-1] for it in myGroup] 66 if classList.count(classList[0]) == len(classList): 67 return classList[0] 68 if len(myGroup[0]) == 1: 69 return self.majorityCnt(classList) 70 rootFeature = self.chooseBestFeature(myGroup) 71 rootLabel = labels[rootFeature] 72 myTree = {rootLabel:{}} 73 del(labels[rootFeature]) 74 uniqueVals = set([it[rootFeature] for it in myGroup]) 75 for val in uniqueVals: # 开始递归创建子树 76 subLabels = labels[:] # 每次都要定义新的subLabels 77 myTree[rootLabel][val] = self.createTree(self.splitDataSet(myGroup, rootFeature, val), subLabels) # 函数作为参数传入 78 return myTree 79 80 81 if __name__ == ‘__main__‘: 82 A = myClass()
《机器学习实战》——决策树
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。