首页 > 代码库 > Weka算法Classifier-meta-AdditiveRegression源码分析

Weka算法Classifier-meta-AdditiveRegression源码分析


博主最近迷上了打怪物猎人,这片文章拖了很久才开始动笔


一、算法

AdditiveRegression,换个更出名一点的叫法可以称作GBDT(Grandient Boosting Decision Tree)梯度下降分类树,或者GBRT(Grandient Boosting Regression Tree)梯度下降回归树,是一种多分类器组合的算法,更确切的说,是属于Boosting算法。


谈到Boosting算法,就不能不提AdaBoost,参见之前我写的博客,可以看到AdaBoost的核心是级联分类器,使后一级分类器更加“关注”较为容易分错的数据,即后一级的分类器更有在易出错的数据集上进行训练。。


而GBDT作为Boosting算法,也是将多分类器进行级联训练,后一级的分类器则更多关注前面所有分类器预测结果与实际结果的残差,在这个残差上训练新的分类器,最终预测时将残差级联相加。


关于GBDT相关算法的公式推导可参考:

http://en.wikipedia.org/wiki/Gradient_boosting#Gradient_tree_boosting

http://www.360doc.com/content/12/0428/15/5874309_207282768.shtml


扯了这么多,下面简单说一下算法训练流程。

(1)输入训练集Data和基分类器的数量N

(2)使用训练集Data训练第1个基分类器

(3)for (int i=2;i<N;i++)

(4)使用前i-1个分类器进行预测,计算预测结果和训练数据的残差

(5)如果残差小于某个阈值,则退出循环。

(5)使用此残差训练第i个分类器

(6)转(3)


预测流程:

(1)根据输入数据,计算N个分类器的预测结果。

(2)将预测结果相加并返回。


可以看到,GBDT从原理上来讲并不复杂,“残差”的概念就用梯度来进行标示,抓住这一个线索看懂Wiki中的推导公式也并不是难事。复杂的是“如何证明其有效性”,这远超过本文可论证的范畴。


二、源码实现

就像之前所有的分类器一样,依然从buildClassifier入手。

(1)buildClassifier

public void buildClassifier(Instances data) throws Exception {

    super.buildClassifier(data);

    //additiveRegerssion只支持数值型数据。
    getCapabilities().testWithFail(data);

    //如果训练数据的class列为空,则去掉
    Instances newData = http://www.mamicode.com/new Instances(data);>
算法思想很简单,代码也很直观。

下面分析一下residualReplace函数。


(2)residualReplace

private Instances residualReplace(Instances data, Classifier c, 
				    boolean useShrinkage) throws Exception {
    double pred,residual;
    Instances newInst = new Instances(data);

    for (int i = 0; i < newInst.numInstances(); i++) {
      pred = c.classifyInstance(newInst.instance(i)); //进行预测
      if (useShrinkage) {
	pred *= getShrinkage();//使用shrinkage来防止过拟合
      }
      residual = newInst.instance(i).classValue() - pred;//算出残差
      newInst.instance(i).setClassValue(residual);//原始数据的class用残差替换
    }
    //    System.err.print(newInst);
    return newInst;
  }

什么是shrinkage?

shrinkage(缩减)的思想认为,每次走一小步逐渐逼近结果的效果,要比每次迈一大步很快逼近结果的方式更容易避免过拟合。即它不完全信任每一个棵残差树,它认为每棵树只学到了真理的一小部分,累加的时候只累加一小部分,通过多学几棵树弥补不足。(转自http://blog.csdn.net/w28971023/article/details/8240756)

可以看到,残差本身可以理解成“希望分类器结果前进的向量”,也就是梯度的含义,即包含了方向(分类器往哪个方向调整),也包含了长度(调整多少)。而shrinkage就是缩小这个长度到一定的比值,如10%,这样每次在这个向量方向上前进10%,以此来防止过拟合。

为什么shrinkage能防止过拟合?这又是一个看上去就复杂的不得了的问题啊。。。。


(3)classifyInstance

public double classifyInstance(Instance inst) throws Exception {

    double prediction = m_zeroR.classifyInstance(inst);

    if (!m_SuitableData) {
      return prediction;
    }
    
    for (int i = 0; i < m_NumIterationsPerformed; i++) {
      double toAdd = m_Classifiers[i].classifyInstance(inst);
      toAdd *= getShrinkage();
      prediction += toAdd;
    }

    return prediction;
  }

按照分类器顺序把残差相加得到最终结果。


四、总结

如果非要写个什么总结的话,那么我希望是以下几点:

(1)gbdt思想简单,实现起来也简单,效果非常理想。

(2)weka的additiveRegression是一个gbrt的简单实现,只能处理数值型数据。

(3)其实现的核心逻辑是用残差替换原有数据集的class列。

(4)可以选择性的使用shrinkage来防止过拟合。


Weka算法Classifier-meta-AdditiveRegression源码分析