首页 > 代码库 > Weka算法Classifier-trees-REPTree源码分析(二)

Weka算法Classifier-trees-REPTree源码分析(二)

(接上篇)


一、剪枝过程

上篇分析完了tree节点的构建过程,在REPTree.buildClassifier之后如果设置了剪枝选项,则还有一个剪枝和backfit过程。

    if (!m_NoPruning) {
      m_Tree.insertHoldOutSet(prune);
      m_Tree.reducedErrorPrune();
      m_Tree.backfitHoldOutSet();
    }


其中insertHoldOutSet就是把剪枝用到的数据集传进去,不具体的区跟代码了。

重点卡一下reducedErrorPrune和backfitHoldOutSet过程。


二、Tree.reducedErrorPrune

    protected double reducedErrorPrune() throws Exception {
<span style="white-space:pre">	</span>//这个函数会返回该树及其子树的一个错误情况,如果是枚举类型返回的是分错的instance数量,数值类型返回的是与正确值的偏差的平方和
      // 如果是叶子节点就不做任何操作
      if (m_Attribute == -1) {
	return m_HoldOutError;//简单的说一下这个error怎么计算来的,使用<span style="font-size:18px;">insertHoldOutSet传入数据时会根据原先训练时的分布,来预测出传入数据的class,然后根据这个结果和真正的class值进行比对,就知道是否分的正确了</span>
      }

      //计算一下所有的子树的偏差
      double errorTree = 0;
      for (int i = 0; i < m_Successors.length; i++) {
	errorTree += m_Successors[i].reducedErrorPrune();
      }

      if (errorTree >= m_HoldOutError) {
	m_Attribute = -1;//如果子树偏差大于本身的偏差,那子树就没啥存在的意义了,直接去掉。
	m_Successors = null;
	return m_HoldOutError;
      } else {
	return errorTree;
      }
    }
可以看出,这个剪枝过程和J48相比还是简单不少的。


三、Tree.backfitHoldOutSet

protected void backfitHoldOutSet() throws Exception {
      
      // Insert instance into hold-out class distribution
      if (m_Info.classAttribute().isNominal()) {
	
	// Nominal case
	if (m_ClassProbs == null) {
	  m_ClassProbs = new double[m_Info.numClasses()];
	}
	System.arraycopy(m_Distribution, 0, m_ClassProbs, 0, m_Info.numClasses());
        for (int i = 0; i < m_HoldOutDist.length; i++) {
          m_ClassProbs[i] += m_HoldOutDist[i];
        }
        if (Utils.sum(m_ClassProbs) > 0) {
          Utils.normalize(m_ClassProbs);
        } else {
          m_ClassProbs = null;
        }
      } else {
	
	// Numeric case
        double sumOfWeightsTrainAndHoldout = m_Distribution[1] + m_HoldOutDist[0];
        if (sumOfWeightsTrainAndHoldout <= 0) {
          return;
        }
	if (m_ClassProbs == null) {
	  m_ClassProbs = new double[1];
	} else {
          m_ClassProbs[0] *= m_Distribution[1];
        }
	m_ClassProbs[0] += m_HoldOutDist[1];
	m_ClassProbs[0] /= sumOfWeightsTrainAndHoldout;
      }	
      
      // The process is recursive
      if (m_Attribute != -1) {
        for (int i = 0; i < m_Successors.length; i++) {
          m_Successors[i].backfitHoldOutSet();
        }
      }
    }
可以看出,就是一个根据新传入的数据集对原数据的分布进行重新计算,并且再对子树进行递归的调用backfit的过程,不再详细对代码进行注释了。


四、REPTree和J48的比较

同样都是分类树,REPTree和J48有很多不同点,下面简单的说一说这些差异。

1、对连续值排序的处理

J48在处理连续值的时候,每一个subset都要进行排序,而REPTree是先在主流程中对所有属性进行排序,并生成index传给Tree节点来进行处理的。

因此J48所耗时间比较长,而REPTree则占用较大内存(数据数量*数据属性列数量,因此也可以看到REPTree的代码中不断的有显式置空去尝试释放内存的操作),这是一个典型的时间和空间的tradeoff。

2、递归退出条件

J48的分裂停止条件有5个,

(1)所有的instances已经属于同一个分类(selectModel里)

(2)instances数量小于2*minNoObj(selectModel里)

(3)一个分裂产生的信息增益石0(selectModel里)

(4)对离散值进行分裂节点的计算时,超过一个的Bag里的instance数量小于minNoObj(spliter里)

(5)对连续值进行分裂计算时,有效instances数量小于2*minNoObj(spliter里)

REPTree的停止条件有4个

(1)训练集数量小于2*minNum

(2)如果枚举类型,且在一个类中

(3)如果数值类型,方差小于一个给定值

(4)达到最大深度

可以看出,主要的不同在于REPTree使用方差来判断连续值是否结束分裂。

3、节点选择方式

J48使用信息增益率,REPTree使用信息增益

4、剪枝与backfit

J48的剪枝较为复杂,分成了collapse()和prune()两个操作,而REPTree的剪枝从逻辑上讲只是J48的collapse操作,并没有子树上提等较为激进的剪枝策略。

J48没有backfit,REPTree有backfit,这是因为J48就自己独立的classifyInstance过程并不依赖样本集的分布,而J48的classifyInstance是调用基类过程,需要自己存储一个分布,进而使用backfit来防止过拟合。







Weka算法Classifier-trees-REPTree源码分析(二)