首页 > 代码库 > OpenCV HaarTraining代码解析(二)cvCreateMTStumpClassifier(建立决策树)

OpenCV HaarTraining代码解析(二)cvCreateMTStumpClassifier(建立决策树)

HaarTraining关键的部分是建立基分类器classifier,OpenCV中所采用的是CART(决策树的一种):通过调用cvCreateMTStumpClassifier来完成。

这里我讨论利用回归的方法来分裂结点,分类的方法只是在分裂结点的方法与之不同而已。


cvCreateMTStumpClassifier

	//设置决策树分类误差计算方法
    stumperror = (int) ((CvMTStumpTrainParams*) trainParams)->error;
	
	//设置class step和ydata
    ydata = http://www.mamicode.com/trainClasses->data.ptr;>可能研究代码到这里的朋友仍然不清楚idxCache和valCache的作用。我这里做一点简单的说明:

valCache是设置在训练前有多少特征值被提前算出存放在内存中,idxCache是valCache中每种特征按特征值从小到大排列的样本的序号。内存大小是通过运行程序的命令行参数设置的。在cvHaartraining.cpp中我们可以找到这句话,其中float和short,分别是valCache和idxCache存放内容的基本类型。

	//1MB == 1048576B  计算一个样本中有多少个特征能被pre计算放在内存中
    numprecalculated = (int) ( ((size_t) mem) * ((size_t) 1048576) /
        ( ((size_t) (npos + nneg)) * (sizeof( float ) + sizeof( short )) ) );


为了方便理解,我把两者的内存模型画了出来



要注意idxCache中每行的index排列是示意图。比如第一行代表feature1从小到大的index顺序,从图中可以看出,sample1的特征值feature1 < sample0的特征值feature1<...<sample n < sample n-1。利用idxCache数组我们可以方便按特征值的从小到大遍历valCache,并且节省了空间。从float->short。


理解了这两个cache之后我们再回到上面的代码,可以发现这里做得只是设置步长和cache首地址的一些操作,为下面开始的遍历做好准备。


跳过一些变量的初始化步骤,我们来到构建决策树stump的部分,并且为了方便阅读核心代码,去掉了其他一些基于移植的代码

 while( t_compidx < n )
        {
			//选择计算前100种特征
            t_n = portion;
            if( t_compidx < datan )
            {
                t_n = ( t_n < (datan - t_compidx) ) ? t_n : (datan - t_compidx);
                t_data = http://www.mamicode.com/data;>这里datan代表的是一个检测窗口包含的特征数目,portion代表以多少的行为单位进行计算。每个循环选取valCache中的portion行进行计算,应该是为了发挥并行计算的优势,假如设置了并行计算的宏的话。


findStumpThreshold_32[stumperror]是一个函数指针,利用这个函数我们可以选择某个样本的某个特征值作为决策树的一个结点。这里结点分裂方法我选择的是最小残差和的方法。即统计利用某个特征值进行分类后,左右子树中类间的残差之和。最小残差和对应的特征值就是满足要求的结点。

OpenCV1.0中利用宏定义的方式实现了这个函数

#define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )                              CV_BOOST_IMPL int icvFindStumpThreshold_##suffix(                                                      uchar* data, size_t datastep,                                                            uchar* wdata, size_t wstep,                                                              uchar* ydata, size_t ystep,                                                              uchar* idxdata, size_t idxstep, int num,                                                 float* lerror,                                                                           float* rerror,                                                                           float* threshold, float* left, float* right,                                             float* sumw, float* sumwy, float* sumwyy )                                       {                                                                                            int found = 0;                                                                           float wyl  = 0.0F;                                                                       float wl   = 0.0F;                                                                       float wyyl = 0.0F;                                                                       float wyr  = 0.0F;                                                                       float wr   = 0.0F;                                                                                                                                                                float curleft  = 0.0F;                                                                   float curright = 0.0F;                                                                   float* prevval = NULL;                                                                   float* curval  = NULL;                                                                   float curlerror = 0.0F;                                                                  float currerror = 0.0F;                                                                  float wposl;                                                                             float wposr;                                                                                                                                                                      int i = 0;                                                                               int idx = 0;                                                                                                                                                                      wposl = wposr = 0.0F;                                                                    if( *sumw == FLT_MAX )                                                                   {                                                                                            /* calculate sums */                                                                     float *y = NULL;                                                                         float *w = NULL;                                                                         float wy = 0.0F;                                                                                                                                                                  *sumw   = 0.0F;                                                                          *sumwy  = 0.0F;                                                                          *sumwyy = 0.0F;                                                                          for( i = 0; i < num; i++ )                                                               {                                                                                            idx = (int) ( *((type*) (idxdata + i*idxstep)) );                                        w = (float*) (wdata + idx * wstep);                                                      *sumw += *w;                                                                             y = (float*) (ydata + idx * ystep);                                                      wy = (*w) * (*y);                                                                        *sumwy += wy;                                                                            *sumwyy += wy * (*y);                                                                }                                                                                    }                                                                                                                                                                                 for( i = 0; i < num; i++ )                                                               {                                                                                            idx = (int) ( *((type*) (idxdata + i*idxstep)) );                                        curval = (float*) (data + idx * datastep);                                                /* for debug purpose */                                                                 if( i > 0 ) assert( (*prevval) <= (*curval) );                                                                                                                                    wyr  = *sumwy - wyl;                                                                     wr   = *sumw  - wl;                                                                                                                                                               if( wl > 0.0 ) curleft = wyl / wl;                                                       else curleft = 0.0F;                                                                                                                                                              if( wr > 0.0 ) curright = wyr / wr;                                                      else curright = 0.0F;                                                                                                                                                             error                                                                                                                                                                             if( curlerror + currerror < (*lerror) + (*rerror) )                                      {                                                                                            (*lerror) = curlerror;                                                                   (*rerror) = currerror;                                                                   *threshold = *curval;                                                                    if( i > 0 ) {                                                                                *threshold = 0.5F * (*threshold + *prevval);                                         }                                                                                        *left  = curleft;                                                                        *right = curright;                                                                       found = 1;                                                                           }                                                                                                                                                                                 do                                                                                       {                                                                                            wl  += *((float*) (wdata + idx * wstep));                                                wyl += (*((float*) (wdata + idx * wstep)))                                                   * (*((float*) (ydata + idx * ystep)));                                               wyyl += *((float*) (wdata + idx * wstep))                                                    * (*((float*) (ydata + idx * ystep)))                                                    * (*((float*) (ydata + idx * ystep)));                                           }                                                                                        while( (++i) < num &&                                                                        ( *((float*) (data + (idx =                                                                  (int) ( *((type*) (idxdata + i*idxstep))) ) * datastep))                                 == *curval ) );                                                                  --i;                                                                                     prevval = curval;                                                                    } /* for each value */                                                                                                                                                            return found;                                                                        }

这里几个关键的变量:ydata是某个特征所代表的类别,正负样本分别以1,-1进行标注,wdata是正负样本对应的权值,data指的就是valCache的某一行。

程序进来的时候判断sumw是否初始化,没有初始化就进行赋值。由于同一个训练集每个样本都只对应一个ydata和wdata(每个样本对应很多个Haar特征,两者有区别),因此这里的sumw,sumwyy,sumwy都是一个确定的值。提前计算好,在后面的迭代中就不必重复计算。

接下来,根据idxCache中某一行(视迭代次数而定)的index,按从小到大的顺序遍历ValCache中对应行的特征值,也就是不同样本的同一特征值。并将其作为结点,尝试对样本进行划分。curleft和curright分别代表左右子树的类别的加权平均值。然后利用error宏计算左右子树的残差

#define ICV_DEF_FIND_STUMP_THRESHOLD_SQ( suffix, type )                                      ICV_DEF_FIND_STUMP_THRESHOLD( sq_##suffix, type,                                             /* calculate error (sum of squares)          */                                          /* err = sum( w * (y - left(rigt)Val)^2 )    */                                          curlerror = wyyl + curleft * curleft * wl - 2.0F * curleft * wyl;                        currerror = (*sumwyy) - wyyl + curright * curright * wr - 2.0F * curright * wyr;     )

最后的一个do-while循环,就是用来跳过和当前结点相同的特征值。虽然以后续的、相同的值作为结点划分左右子树,残差平方和可能会改变,但是决策树划分的最小单位是特征值的种类,因为在利用决策树进行分类的时候,必须对相同的特征值做出一样的决策(该划入左子树还是该划入右子树)。

总结

HaarTraining的代码是有4,5K行,但是认真学习之后会收获很多机器学习的算法和优秀代码的书写习惯。我会随着学习的深入不断更新自己的源码研究体会,写的虽然不是很详细,但是力求把重点突出来,将自己在阅读代码时碰到的困惑总结出来,给同样学习Training算法的朋友一点点帮助



OpenCV HaarTraining代码解析(二)cvCreateMTStumpClassifier(建立决策树)