Weka Algorithm Classifier-trees-REPTree Source Code Analysis (Part 1)
1. The Algorithm
I honestly could not find any material on the REPTree algorithm itself. It may be an improvement on decision-tree learning that is original to Weka, or an alias for some other decision-tree method. According to the class's Javadoc: "Fast decision tree learner. Builds a decision/regression tree using information gain/variance and prunes it using reduced-error pruning (with backfitting). Only sorts values for numeric attributes once. Missing values are dealt with by splitting the corresponding instances into pieces (i.e. as in C4.5)."
So, compared with C4.5, REPTree roughly adds a backfitting pass, sorts numeric attribute values only once (recall that J48, Weka's C4.5 implementation, re-sorts every data subset), and handles missing values the same way C4.5 does: the instance is sent down each path and the results are combined by weight.
A concrete comparison with C4.5 will be summarized after the code analysis.
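To illustrate the C4.5-style missing-value handling mentioned above, here is a minimal standalone sketch (the class and method names are made up for illustration; this is not the Weka implementation): an instance whose split-attribute value is missing is sent down every branch, with its weight scaled by the fraction of known-value weight that went to each branch.

```java
// Sketch of C4.5-style fractional weighting for a missing split value.
// Hypothetical standalone code, not the actual Weka source.
public class MissingValueSketch {

    // props[i] is the fraction of known-value weight that fell into branch i.
    // Returns the weight piece each branch receives for this instance.
    public static double[] splitWeight(double instanceWeight, double[] props) {
        double[] pieces = new double[props.length];
        for (int i = 0; i < props.length; i++) {
            // each branch receives a proportional piece of the instance
            pieces[i] = instanceWeight * props[i];
        }
        return pieces;
    }
}
```

With branch proportions {0.75, 0.25}, an instance of weight 1.0 is split into pieces of 0.75 and 0.25; the pieces always sum back to the original weight, which is what keeps the later weighted class counts consistent.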
2. buildClassifier
The "famous" main entry point for classifier training; nearly every source-code walkthrough of a classifier starts from this method.
(1) REPTree.buildClassifier
```java
public void buildClassifier(Instances data) throws Exception {

    // Routine check first: can REPTree handle this dataset?
    // REPTree supports essentially all attribute and class types.
    getCapabilities().testWithFail(data);

    // Copy the data and drop instances with no class value; such
    // instances are useless both for training and for backfitting
    data = new Instances(data);
    data.deleteWithMissingClass();

    // [A chunk of the listing was lost here in extraction: it randomizes
    //  the data, splits it into a training set `train` and a pruning
    //  hold-out set `prune` (unless m_NoPruning is set), and builds the
    //  per-attribute sortedIndices/weights arrays used below.]

    // Compute the initial class distribution
    double[] classProbs = new double[train.numClasses()];
    double totalWeight = 0, totalSumSquared = 0;
    for (int i = 0; i < train.numInstances(); i++) {
        Instance inst = train.instance(i);
        if (data.classAttribute().isNominal()) {
            // Nominal class: simple weighted counting
            classProbs[(int) inst.classValue()] += inst.weight();
            totalWeight += inst.weight();
        } else {
            // Numeric class: accumulate the sums; the mean is taken later
            classProbs[0] += inst.classValue() * inst.weight();
            totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
            totalWeight += inst.weight();
        }
    }

    m_Tree = new Tree();       // root node of the decision tree
    double trainVariance = 0;  // variance of the training set
    if (data.classAttribute().isNumeric()) {
        trainVariance = m_Tree.singleVariance(classProbs[0], totalSumSquared,
                totalWeight) / totalWeight;
        classProbs[0] /= totalWeight; // take the mean here
    }

    // Build tree (quite a few parameters)
    m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs,
            new Instances(train, 0), m_MinNum, m_MinVarianceProp * trainVariance,
            0, m_MaxDepth);

    // Insert pruning data and perform reduced-error pruning
    if (!m_NoPruning) {
        m_Tree.insertHoldOutSet(prune);  // pass in the pruning set
        m_Tree.reducedErrorPrune();      // prune
        m_Tree.backfitHoldOutSet();      // backfit
    }
}
```
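The singleVariance call above computes the class "variance" from the three running sums kept in the loop. A standalone sketch of that arithmetic, assuming it follows the standard sumSquared minus sum squared over total weight form implied by the code (note the value is really variance times total weight, which is why the caller divides by totalWeight):

```java
// Standalone sketch of the variance arithmetic used by buildClassifier;
// not the Weka source, just the formula its running sums imply.
public class VarianceSketch {

    // "Variance" as used above: actually variance * sumOfWeights,
    // computed from the running sums
    public static double singleVariance(double sum, double sumSquared,
                                        double sumOfWeights) {
        return sumSquared - (sum * sum) / sumOfWeights;
    }

    // The per-instance variance used for the stopping test divides by weight
    public static double meanVariance(double sum, double sumSquared,
                                      double sumOfWeights) {
        return singleVariance(sum, sumSquared, sumOfWeights) / sumOfWeights;
    }
}
```

For class values 1, 2, 3 with unit weights, sum = 6, sumSquared = 14, weight = 3, so singleVariance gives 14 - 36/3 = 2, and the per-instance variance is 2/3.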
(2) Tree.buildTree
Tree is a member object of REPTree, and its tree-building method takes quite a few parameters.
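Before diving into the method body, it helps to see where the sortedIndices parameter gets its content: buildClassifier sorts each numeric attribute once up front and passes the index permutations down the recursion, which is the "only sorts values for numeric attributes once" property from the class comment. A minimal standalone argsort sketch (plain Java, not Weka's Utils.sort):

```java
import java.util.Arrays;
import java.util.Comparator;

// Standalone sketch of the one-time argsort behind sortedIndices;
// illustrative only, not the Weka implementation.
public class ArgsortSketch {

    // Returns the indices that would sort vals ascending, i.e. the kind
    // of permutation stored in sortedIndices[0][attribute]
    public static int[] sortedIndices(double[] vals) {
        Integer[] idx = new Integer[vals.length];
        for (int i = 0; i < vals.length; i++) idx[i] = i;
        // sort the index array by the attribute values it points at
        Arrays.sort(idx, Comparator.comparingDouble(i -> vals[i]));
        int[] out = new int[vals.length];
        for (int i = 0; i < vals.length; i++) out[i] = idx[i];
        return out;
    }
}
```

For attribute values {3.0, 1.0, 2.0} this yields the permutation {1, 2, 0}; child nodes then reuse slices of such permutations instead of re-sorting, which is the key cost difference from J48.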
```java
protected void buildTree(int[][][] sortedIndices, double[][][] weights,
        Instances data, double totalWeight, double[] classProbs,
        Instances header, double minNum, double minVariance, int depth,
        int maxDepth) throws Exception {
    // Parameters, in order: the per-attribute sorted instance indices; the
    // weights matching those indices; the training data; the total weight;
    // the class distribution; the header (empty dataset); the minimum
    // instance count per node; the minimum variance; the current depth
    // (0-based); and the maximum depth.

    // Store the header information first
    m_Info = header;

    if (data.classAttribute().isNumeric()) {
        m_HoldOutDist = new double[2]; // holds the hold-out distribution
    } else {
        m_HoldOutDist = new double[data.numClasses()];
    }

    // Check whether there is any data to work with. helpIndex just has to
    // be some attribute index other than the class index; the data has at
    // least two attributes, because with only one the caller would have
    // fallen back to the m_zeroR model.
    int helpIndex = 0;
    if (data.classIndex() == 0) {
        helpIndex = 1;
    }
    if (sortedIndices[0][helpIndex].length == 0) {
        // No data: make this an empty leaf and return
        if (data.classAttribute().isNumeric()) {
            // Why length two? Element 0 holds the variance and element 1
            // the weight: programming by convention.
            m_Distribution = new double[2];
        } else {
            m_Distribution = new double[data.numClasses()];
        }
        m_ClassProbs = null;
        sortedIndices[0] = null;
        weights[0] = null;
        return;
    }

    // Class variance (actually variance * weight), only meaningful for a
    // numeric class; the loop below computes it.
    double priorVar = 0;
    if (data.classAttribute().isNumeric()) {

        // Each sortedIndices[0][i] is just a different permutation of the
        // same instance indices; helpIndex merely avoids the class index.
        double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0;
        for (int i = 0; i < sortedIndices[0][helpIndex].length; i++) {
            Instance inst = data.instance(sortedIndices[0][helpIndex][i]);
            totalSum += inst.classValue() * weights[0][helpIndex][i];
            totalSumSquared += inst.classValue() * inst.classValue()
                    * weights[0][helpIndex][i];
            totalSumOfWeights += weights[0][helpIndex][i];
        }
        priorVar = singleVariance(totalSum, totalSumSquared, totalSumOfWeights);
    }

    // Copy the class distribution
    m_ClassProbs = new double[classProbs.length];
    System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);

    // There are four stopping conditions:
    // 1. the total weight (roughly the instance count, since weights
    //    default to 1) is below twice minNum (minNum defaults to 2)
    if ((totalWeight < (2 * minNum)) ||
        // 2. nominal class and all instances belong to a single class
        (data.classAttribute().isNominal() &&
         Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
                  Utils.sum(m_ClassProbs))) ||
        // 3. numeric class and the variance is below minVariance, which by
        //    default is 0.001 of the original variance, as the caller shows
        (data.classAttribute().isNumeric() &&
         ((priorVar / totalWeight) < minVariance)) ||
        // 4. maximum depth reached
        ((m_MaxDepth >= 0) && (depth >= maxDepth))) {

        // Make this node a leaf
        m_Attribute = -1;
        if (data.classAttribute().isNominal()) {
            // Store the nominal class distribution
            m_Distribution = new double[m_ClassProbs.length];
            for (int i = 0; i < m_ClassProbs.length; i++) {
                m_Distribution[i] = m_ClassProbs[i];
            }
            Utils.normalize(m_ClassProbs);
        } else {
            // Store the numeric "distribution": variance plus weight
            m_Distribution = new double[2];
            m_Distribution[0] = priorVar;
            m_Distribution[1] = totalWeight;
        }
        sortedIndices[0] = null;
        weights[0] = null;
        return;
    }

    // Now search for the split point
    double[] vals = new double[data.numAttributes()];             // information gain per attribute
    double[][][] dists = new double[data.numAttributes()][0][0];  // per-attribute class distributions
    double[][] props = new double[data.numAttributes()][0];       // class probabilities derived from dists
    double[][] totalSubsetWeights = new double[data.numAttributes()][0]; // weight of each subset per attribute
    double[] splits = new double[data.numAttributes()];           // split point per attribute (NaN for nominal)
    if (data.classAttribute().isNominal()) {

        // First, the case of a nominal class attribute
        for (int i = 0; i < data.numAttributes(); i++) {
            if (i != data.classIndex()) {
                // get the split point, probabilities, and distribution
                splits[i] = distribution(props, dists, i, sortedIndices[0][i],
                        weights[0][i], totalSubsetWeights, data);
                // get the information gain
                vals[i] = gain(dists[i], priorVal(dists[i]));
            }
        }
    } else {

        // A numeric class uses no information gain. (Why not? Because
        // entropy is only defined for nominal classes.) An aside: why
        // isn't this if/else inside the loop?
        for (int i = 0; i < data.numAttributes(); i++) {
            if (i != data.classIndex()) {
                splits[i] = numericDistribution(props, dists, i,
                        sortedIndices[0][i], weights[0][i],
                        totalSubsetWeights, data, vals);
            }
        }
    }

    // Pick the attribute with the largest information gain as the split
    m_Attribute = Utils.maxIndex(vals);
    int numAttVals = dists[m_Attribute].length;

    // A subset only counts as valid if its weight is at least minNum
    int count = 0;
    for (int i = 0; i < numAttVals; i++) {
        if (totalSubsetWeights[m_Attribute][i] >= minNum) {
            count++;
        }
        if (count > 1) {
            break;
        }
    }

    // A split is only valid if there are at least two valid subsets
    if (Utils.gr(vals[m_Attribute], 0) && (count > 1)) {

        // Set split point, proportions, and temp arrays
        m_SplitPoint = splits[m_Attribute];
        m_Prop = props[m_Attribute];
        double[][] attSubsetDists = dists[m_Attribute];
        double[] attTotalSubsetWeights = totalSubsetWeights[m_Attribute];

        // Release memory
        vals = null;
        dists = null;
        props = null;
        totalSubsetWeights = null;
        splits = null;

        // Compute the sorted indices for each subset
        int[][][][] subsetIndices =
                new int[numAttVals][1][data.numAttributes()][0];
        double[][][][] subsetWeights =
                new double[numAttVals][1][data.numAttributes()][0];
        splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint,
                sortedIndices[0], weights[0], data);

        // Release memory
        sortedIndices[0] = null;
        weights[0] = null;

        m_Successors = new Tree[numAttVals];
        for (int i = 0; i < numAttVals; i++) {
            m_Successors[i] = new Tree(); // build the child node
            m_Successors[i].buildTree(subsetIndices[i], subsetWeights[i],
                    data, attTotalSubsetWeights[i], attSubsetDists[i],
                    header, minNum, minVariance, depth + 1, maxDepth);

            // Release memory again
            attSubsetDists[i] = null;
        }
    } else {

        // Without two valid subsets, this node simply becomes a leaf
        m_Attribute = -1;
        sortedIndices[0] = null;
        weights[0] = null;
    }

    // Store the distribution for the later classification step (used, of
    // course, only when neither pruning nor backfitting takes place)
    if (data.classAttribute().isNominal()) {
        m_Distribution = new double[m_ClassProbs.length];
        for (int i = 0; i < m_ClassProbs.length; i++) {
            m_Distribution[i] = m_ClassProbs[i];
        }
        Utils.normalize(m_ClassProbs);
    } else {
        m_Distribution = new double[2];
        m_Distribution[0] = priorVar;
        m_Distribution[1] = totalWeight;
    }
}
```
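The gain/priorVal pair used for nominal classes above boils down to standard entropy arithmetic: priorVal is the entropy of the class distribution before the split, and gain subtracts the weighted entropy of each subset. A self-contained sketch of that computation (plain Java under that assumption, not the actual Weka ContingencyTables code):

```java
// Standalone sketch of entropy-based information gain over a
// dists[bag][class] table of weighted class counts; illustrative only.
public class GainSketch {

    // Entropy of one row of class counts (natural log; any base works
    // for comparing gains, since the base only scales all values)
    static double entropy(double[] counts) {
        double total = 0, e = 0;
        for (double c : counts) total += c;
        for (double c : counts) {
            if (c > 0) e -= (c / total) * Math.log(c / total);
        }
        return e;
    }

    // Information gain of a split described by per-subset class counts
    public static double gain(double[][] dists) {
        int numClasses = dists[0].length;
        double[] merged = new double[numClasses];
        double total = 0;
        for (double[] bag : dists) {
            for (int c = 0; c < numClasses; c++) merged[c] += bag[c];
        }
        for (double c : merged) total += c;
        // weighted entropy after the split
        double after = 0;
        for (double[] bag : dists) {
            double bagTotal = 0;
            for (double c : bag) bagTotal += c;
            if (bagTotal > 0) after += (bagTotal / total) * entropy(bag);
        }
        // prior entropy minus weighted post-split entropy
        return entropy(merged) - after;
    }
}
```

A perfect two-way split of a balanced two-class node, {{2, 0}, {0, 2}}, yields gain ln 2, while a useless split {{1, 1}, {1, 1}} yields gain 0; buildTree's Utils.gr(vals[m_Attribute], 0) check is exactly what rejects the latter.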