首页 > 代码库 > Weka算法Classifier-trees-REPTree源码分析(一)

Weka算法Classifier-trees-REPTree源码分析(一)


一、算法

关于REPTree我实在是没找到什么相关其算法的资料,或许是Weka自创的一个关于决策树的改进,也许是其它某种决策树方法的别名,根据类的注释:Fast decision tree learner. Builds a decision/regression tree using information gain/variance and prunes it using reduced-error pruning (with backfitting).  Only sorts values for numeric attributes once. Missing values are dealt with by splitting the corresponding instances into pieces (i.e. as in C4.5).

我们大概知道和C4.5相比,大概多了backfitting过程,并且数值型排序只进行一次(回想一下J48也就是C4.5算法是每个数据子集都要进行排序),并且缺失值的处理方式和C4.5一样,走不同的path再把结果进行加权。

具体和C4.5的比较将在代码分析之后给出一个总结。


二、buildClassifier

“大名鼎鼎”的分类器训练主入口,几乎每篇分析分类器源码都从这个方法入手。

 public void buildClassifier(Instances data) throws Exception {

    // 首先例行公事看一下给定数据集是否能使用REPTree进行分类,REPTREE基本能支持所有类型
    getCapabilities().testWithFail(data);

    // 把classIndex上没有数据的instance干掉,这些数据既不能用于训练也不能用于backfit
    data = http://www.mamicode.com/new Instances(data);>
	classProbs[(int)inst.classValue()] += inst.weight();//如果是枚举类型,就进行简单的统计
	totalWeight += inst.weight();
      } else {
	classProbs[0] += inst.classValue() * inst.weight();//如果是数值型,就相加,到后面进行取平均的操作
	totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
	totalWeight += inst.weight();
      }
    }
    m_Tree = new Tree();//建立决策树节点
    double trainVariance = 0;//训练集的方差
    if (data.classAttribute().isNumeric()) {
      trainVariance = m_Tree.
	singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
      classProbs[0] /= totalWeight;//这里取平均操作
    }

    // Build tree
    m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs,
		     new Instances(train, 0), m_MinNum, m_MinVarianceProp * 
		     trainVariance, 0, m_MaxDepth);//执行具体树上的构建操作,这参数还真多
    
    // Insert pruning data and perform reduced error pruning
    if (!m_NoPruning) {
      m_Tree.insertHoldOutSet(prune);//传入剪枝数据
      m_Tree.reducedErrorPrune();//进行剪枝
      m_Tree.backfitHoldOutSet();//backfit
    }
  }


(2)Tree.buildTree

Tree是REPTree的一个子对象,训练用参数较多。

 protected void buildTree(int[][][] sortedIndices, double[][][] weights,
			     Instances data, double totalWeight, 
			     double[] classProbs, Instances header,
			     double minNum, double minVariance,
			     int depth, int maxDepth) 
      throws Exception {
      //第一个参数是按属性排好序的下标,第二个是这些下标对应的weight,第三个是训练数据
<span style="white-space:pre">	</span>//第四个是总权重,第五个是各类的分布,第六个是表头,第七个是每个节点最小instance数量
<span style="white-space:pre">	</span>//第八个是最小的方差 ,第九个是当前深度(0 base),第十个是最大深度
      
      m_Info = header;//首先存下表头
      if (data.classAttribute().isNumeric()) {
        m_HoldOutDist = new double[2];//这个数组用于存放分布
      } else {
        m_HoldOutDist = new double[data.numClasses()];
      }
	
      // 看看是否有有效数据
      int helpIndex = 0;
      if (data.classIndex() == 0) {
	helpIndex = 1;//传入的数据至少两列,因为一列的话上层就用m_zerO模型了,这个if是为了保证helpIndex对应的肯定是训练数据
      }
      if (sortedIndices[0][helpIndex].length == 0) {//如果没数据,就直接反悔了
	if (data.classAttribute().isNumeric()) {
	  m_Distribution = new double[2];//为什么是二维的?第一维存放方差,第二维存放weight,基于约定的编程方式
	} else {
	  m_Distribution = new double[data.numClasses()];
	}
	m_ClassProbs = null;
        sortedIndices[0] = null;
        weights[0] = null;
	return;
      }
      
      double priorVar = 0;//存放class的方差(其实是方差*num),只有class是数值才有意义,下面就是计算方差的过程。
      if (data.classAttribute().isNumeric()) {

	// 每个sortedIndices[0][i]里面的都是一个Instances的index不同排列而已,使用helpIndex只是为了保证别对应到classIndex上
	double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0; 
	for (int i = 0; i < sortedIndices[0][helpIndex].length; i++) {
	  Instance inst = data.instance(sortedIndices[0][helpIndex][i]);
	  totalSum += inst.classValue() * weights[0][helpIndex][i];
	  totalSumSquared += 
	    inst.classValue() * inst.classValue() * weights[0][helpIndex][i];
	  totalSumOfWeights += weights[0][helpIndex][i];
	}
	priorVar = singleVariance(totalSum, totalSumSquared, 
				  totalSumOfWeights);
      }

      //把分布拷贝一下
      m_ClassProbs = new double[classProbs.length];
      System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
      if ((//退出条件有4个
<span style="white-space:pre">	</span>//第一个是instances里面的totalweight总量(可以理解成里面的instance数量,因为weight默认都是1)小于两倍的minNum,minNum默认是2.
<span style="white-space:pre">	</span>totalWeight < (2 * minNum)) ||

	  // 如果是枚举类型,并且都在一类中
	  (data.classAttribute().isNominal() &&
	   Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
		    Utils.sum(m_ClassProbs))) ||

	  // 数值型则比较方差是否小于minVariance,这个minVariance默认是原始方差的0.001,从上层代码可以得知
	  (data.classAttribute().isNumeric() && 
	   ((priorVar / totalWeight) < minVariance)) ||

	  // 达到最大深度
	  ((m_MaxDepth >= 0) && (depth >= maxDepth))) {

	// 设置成叶子
	m_Attribute = -1;
	if (data.classAttribute().isNominal()) {

	  // 设置枚举类型的分布
	  m_Distribution = new double[m_ClassProbs.length];
	  for (int i = 0; i < m_ClassProbs.length; i++) {
	    m_Distribution[i] = m_ClassProbs[i];
	  }
	  Utils.normalize(m_ClassProbs);
	} else {

	  // 设置数值类型的“分布”
	  m_Distribution = new double[2];
	  m_Distribution[0] = priorVar;
	  m_Distribution[1] = totalWeight;
	}
        sortedIndices[0] = null;
        weights[0] = null;
	return;
      }

      // 下面是寻找分裂点的过程
      double[] vals = new double[data.numAttributes()];//每个属性产生的信息增益
      double[][][] dists = new double[data.numAttributes()][0][0];//每个属性下每个类的分布
      double[][] props = new double[data.numAttributes()][0];//每个属性下class的概率,也就是根据上面这个数组的分布求概率
      double[][] totalSubsetWeights = new double[data.numAttributes()][0];//每个属性下每个subset的数量
      double[] splits = new double[data.numAttributes()];//每个属性的分裂点,如果是枚举型则为NaN
      if (data.classAttribute().isNominal()) { 

	// 首先来看classAttribute是枚举类型的情况
	for (int i = 0; i < data.numAttributes(); i++) {
	  if (i != data.classIndex()) {
	    splits[i] = distribution(props, dists, i, sortedIndices[0][i], 
				     weights[0][i], totalSubsetWeights, data);//得到分裂点、概率和分布
	    vals[i] = gain(dists[i], priorVal(dists[i]));//得到信息增益
	  }
	}
      } else {

	// 如果是数值类型则不算信息增益(为什么数值类型不算增益?只有因为枚举型才算的出信息熵)(吐个槽:话说这个if-else为啥不放在循环里面??)
	for (int i = 0; i < data.numAttributes(); i++) {
	  if (i != data.classIndex()) {
	    splits[i] = 
	      numericDistribution(props, dists, i, sortedIndices[0][i], 
				  weights[0][i], totalSubsetWeights, data, 
				  vals);
	  }
	}
      }

      // 选出信息增益最大的作为分裂属性
      m_Attribute = Utils.maxIndex(vals);
      int numAttVals = dists[m_Attribute].length;

      // 每个subset都要多于minNum,这样才算一个有效subset
      int count = 0;
      for (int i = 0; i < numAttVals; i++) {
	if (totalSubsetWeights[m_Attribute][i] >= minNum) {
	  count++;
	}
	if (count > 1) {
	  break;
	}
      }

      // 至少存在2个有效subset,才算是一个有效的split
      if (Utils.gr(vals[m_Attribute], 0) && (count > 1)) {      

        // Set split point, proportions, and temp arrays
	m_SplitPoint = splits[m_Attribute];
	m_Prop = props[m_Attribute];
        double[][] attSubsetDists = dists[m_Attribute];
        double[] attTotalSubsetWeights = totalSubsetWeights[m_Attribute];

        // 释放内存
        vals = null;
        dists = null;
        props = null;
        totalSubsetWeights = null;
        splits = null;

	// 得到subSet的有序index
	int[][][][] subsetIndices = 
	  new int[numAttVals][1][data.numAttributes()][0];
	double[][][][] subsetWeights = 
	  new double[numAttVals][1][data.numAttributes()][0];
	splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint, 
		  sortedIndices[0], weights[0], data);

        // 释放内存
        sortedIndices[0] = null;
        weights[0] = null;

        //释放内存
	m_Successors = new Tree[numAttVals];
	for (int i = 0; i < numAttVals; i++) {
	  m_Successors[i] = new Tree();//构建孩子节点
	  m_Successors[i].
	    buildTree(subsetIndices[i], subsetWeights[i], 
		      data, attTotalSubsetWeights[i],
		      attSubsetDists[i], header, minNum, 
		      minVariance, depth + 1, maxDepth);

          // 还是释放内存
          attSubsetDists[i] = null;
	}
      } else {
      
	// 如果不存在2个有效的subset,就直接当叶子节点了
	m_Attribute = -1;
        sortedIndices[0] = null;
        weights[0] = null;
      }

      // 构建attribute用于之后的分类过程(当然这是在没有prune和backfit情况下用的)
      if (data.classAttribute().isNominal()) {
	m_Distribution = new double[m_ClassProbs.length];
	for (int i = 0; i < m_ClassProbs.length; i++) {
	    m_Distribution[i] = m_ClassProbs[i];
	}
	Utils.normalize(m_ClassProbs);
      } else {
	m_Distribution = new double[2];
	m_Distribution[0] = priorVar;
	m_Distribution[1] = totalWeight;
      }
    }



Weka算法Classifier-trees-REPTree源码分析(一)