首页 > 代码库 > Weka算法Classifier-trees-REPTree源码分析(二)
Weka算法Classifier-trees-REPTree源码分析(二)
(接上篇)
一、剪枝过程
上篇分析完了tree节点的构建过程,在REPTree.buildClassifier之后如果设置了剪枝选项,则还有一个剪枝和backfit过程。
if (!m_NoPruning) { m_Tree.insertHoldOutSet(prune); m_Tree.reducedErrorPrune(); m_Tree.backfitHoldOutSet(); }
重点卡一下reducedErrorPrune和backfitHoldOutSet过程。
二、Tree.reducedErrorPrune
protected double reducedErrorPrune() throws Exception { <span style="white-space:pre"> </span>//这个函数会返回该树及其子树的一个错误情况,如果是枚举类型返回的是分错的instance数量,数值类型返回的是与正确值的偏差的平方和 // 如果是叶子节点就不做任何操作 if (m_Attribute == -1) { return m_HoldOutError;//简单的说一下这个error怎么计算来的,使用<span style="font-size:18px;">insertHoldOutSet传入数据时会根据原先训练时的分布,来预测出传入数据的class,然后根据这个结果和真正的class值进行比对,就知道是否分的正确了</span> } //计算一下所有的子树的偏差 double errorTree = 0; for (int i = 0; i < m_Successors.length; i++) { errorTree += m_Successors[i].reducedErrorPrune(); } if (errorTree >= m_HoldOutError) { m_Attribute = -1;//如果子树偏差大于本身的偏差,那子树就没啥存在的意义了,直接去掉。 m_Successors = null; return m_HoldOutError; } else { return errorTree; } }可以看出,这个剪枝过程和J48相比还是简单不少的。
三、Tree.backfitHoldOutSet
protected void backfitHoldOutSet() throws Exception { // Insert instance into hold-out class distribution if (m_Info.classAttribute().isNominal()) { // Nominal case if (m_ClassProbs == null) { m_ClassProbs = new double[m_Info.numClasses()]; } System.arraycopy(m_Distribution, 0, m_ClassProbs, 0, m_Info.numClasses()); for (int i = 0; i < m_HoldOutDist.length; i++) { m_ClassProbs[i] += m_HoldOutDist[i]; } if (Utils.sum(m_ClassProbs) > 0) { Utils.normalize(m_ClassProbs); } else { m_ClassProbs = null; } } else { // Numeric case double sumOfWeightsTrainAndHoldout = m_Distribution[1] + m_HoldOutDist[0]; if (sumOfWeightsTrainAndHoldout <= 0) { return; } if (m_ClassProbs == null) { m_ClassProbs = new double[1]; } else { m_ClassProbs[0] *= m_Distribution[1]; } m_ClassProbs[0] += m_HoldOutDist[1]; m_ClassProbs[0] /= sumOfWeightsTrainAndHoldout; } // The process is recursive if (m_Attribute != -1) { for (int i = 0; i < m_Successors.length; i++) { m_Successors[i].backfitHoldOutSet(); } } }可以看出,就是一个根据新传入的数据集对原数据的分布进行重新计算,并且再对子树进行递归的调用backfit的过程,不再详细对代码进行注释了。
四、REPTree和J48的比较
同样都是分类树,REPTree和J48有很多不同点,下面简单的说一说这些差异。
1、对连续值排序的处理
J48在处理连续值的时候,每一个subset都要进行排序,而REPTree是先在主流程中对所有属性进行排序,并生成index传给Tree节点来进行处理的。
因此J48所耗时间比较长,而REPTree则占用较大内存(数据数量*数据属性列数量,因此也可以看到REPTree的代码中不断的有显式置空去尝试释放内存的操作),这是一个典型的时间和空间的tradeoff。
2、递归退出条件
J48的分裂停止条件有5个,
(1)所有的instances已经属于同一个分类(selectModel里)
(2)instances数量小于2*minNoObj(selectModel里)
(3)一个分裂产生的信息增益石0(selectModel里)
(4)对离散值进行分裂节点的计算时,超过一个的Bag里的instance数量小于minNoObj(spliter里)
(5)对连续值进行分裂计算时,有效instances数量小于2*minNoObj(spliter里)
REPTree的停止条件有4个
(1)训练集数量小于2*minNum
(2)如果枚举类型,且在一个类中
(3)如果数值类型,方差小于一个给定值
(4)达到最大深度
可以看出,主要的不同在于REPTree使用方差来判断连续值是否结束分裂。
3、节点选择方式
J48使用信息增益率,REPTree使用信息增益
4、剪枝与backfit
J48的剪枝较为复杂,分成了collapse()和prune()两个操作,而REPTree的剪枝从逻辑上讲只是J48的collapse操作,并没有子树上提等较为激进的剪枝策略。
J48没有backfit,REPTree有backfit,这是因为J48就自己独立的classifyInstance过程并不依赖样本集的分布,而J48的classifyInstance是调用基类过程,需要自己存储一个分布,进而使用backfit来防止过拟合。
Weka算法Classifier-trees-REPTree源码分析(二)