首页 > 代码库 > Weka算法Classifier-tree-J48源码分析(三)ModelSelection
Weka算法Classifier-tree-J48源码分析(三)ModelSelection
ModelSelection主要是用于选择合适的列对数据集进行分割,结合上一篇J48的主流程,发现用到的ModelSelection有 C45ModelSelection以及BinC45ModelSelection,先来分析C45ModelSelection。
一、C45ModelSelection
首先作为一个ModelSelection接口,实现的主要方法有两个,分别是selectModel(Instances)和selectionModel(Instances,Instances)。C45ModelSelection的后一个方法如下:
public final ClassifierSplitModel selectModel(Instances train, Instances test) { return selectModel(train); }可以看到就是忽略了test测试集直接调用selectModel方法而已,因此主要分词selectModel方法。
先放出整段代码,然后对该段代码进行分析:
public final ClassifierSplitModel selectModel(Instances data){ double minResult; double currentResult; C45Split [] currentModel; C45Split bestModel = null; NoSplit noSplitModel = null; double averageInfoGain = 0; int validModels = 0; boolean multiVal = true; Distribution checkDistribution; Attribute attribute; double sumOfWeights; int i; try{ // Check if all Instances belong to one class or if not // enough Instances to split. checkDistribution = new Distribution(data); noSplitModel = new NoSplit(checkDistribution); if (Utils.sm(checkDistribution.total(),2*m_minNoObj) || Utils.eq(checkDistribution.total(), checkDistribution.perClass(checkDistribution.maxClass()))) return noSplitModel; // Check if all attributes are nominal and have a // lot of values. if (m_allData != null) { Enumeration enu = data.enumerateAttributes(); while (enu.hasMoreElements()) { attribute = (Attribute) enu.nextElement(); if ((attribute.isNumeric()) || (Utils.sm((double)attribute.numValues(), (0.3*(double)m_allData.numInstances())))){ multiVal = false; break; } } } currentModel = new C45Split[data.numAttributes()]; sumOfWeights = data.sumOfWeights(); // For each attribute. for (i = 0; i < data.numAttributes(); i++){ // Apart from class attribute. if (i != (data).classIndex()){ // Get models for current attribute. currentModel[i] = new C45Split(i,m_minNoObj,sumOfWeights); currentModel[i].buildClassifier(data); // Check if useful split for current attribute // exists and check for enumerated attributes with // a lot of values. if (currentModel[i].checkModel()) if (m_allData != null) { if ((data.attribute(i).isNumeric()) || (multiVal || Utils.sm((double)data.attribute(i).numValues(), (0.3*(double)m_allData.numInstances())))){ averageInfoGain = averageInfoGain+currentModel[i].infoGain(); validModels++; } } else { averageInfoGain = averageInfoGain+currentModel[i].infoGain(); validModels++; } }else currentModel[i] = null; } // Check if any useful split was found. if (validModels == 0) return noSplitModel; averageInfoGain = averageInfoGain/(double)validModels; // Find "best" attribute to split on. minResult = 0; for (i=0;i<data.numAttributes();i++){ if ((i != (data).classIndex()) && (currentModel[i].checkModel())) // Use 1E-3 here to get a closer approximation to the original // implementation. if ((currentModel[i].infoGain() >= (averageInfoGain-1E-3)) && Utils.gr(currentModel[i].gainRatio(),minResult)){ bestModel = currentModel[i]; minResult = currentModel[i].gainRatio(); } } // Check if useful split was found. if (Utils.eq(minResult,0)) return noSplitModel; // Add all Instances with unknown values for the corresponding // attribute to the distribution for the model, so that // the complete distribution is stored with the model. bestModel.distribution(). addInstWithUnknown(data,bestModel.attIndex()); // Set the split point analogue to C45 if attribute numeric. if (m_allData != null) bestModel.setSplitPoint(m_allData); return bestModel; }catch(Exception e){ e.printStackTrace(); } return null; }第一部分,主要是对局部变量的一些定义。
double minResult;//最小的信息增益率 double currentResult;//当前信息增益率 C45Split [] currentModel;//存放所有未分类属性产生的模型 C45Split bestModel = null;//目前为止的最好模型 NoSplit noSplitModel = null;//代表不用分的模型 double averageInfoGain = 0;//各模型(currentModel)的平均信息增益 int validModels = 0;//是否存在有效模型 boolean multiVal = true;//是否多值 Distribution checkDistribution;//训练数据集的分布 Attribute attribute;//属性列集合 double sumOfWeights;//训练数据集的weight的和 int i;//循环变量
第二部分,递归出口。
checkDistribution = new Distribution(data); noSplitModel = new NoSplit(checkDistribution); if (Utils.sm(checkDistribution.total(),2*m_minNoObj) || Utils.eq(checkDistribution.total(), checkDistribution.perClass(checkDistribution.maxClass()))) return noSplitModel;可以看到,如果当前数据集数量小于2*m_minNoObj(这个值默认是2),或者当前数据集已经全在同一个分类中,就返回noSplitModel代表不用分,这就是整个C45分类树节点停止分裂的条件。
第三部分,判断是否是多值:
if (m_allData != null) { Enumeration enu = data.enumerateAttributes(); while (enu.hasMoreElements()) { attribute = (Attribute) enu.nextElement(); if ((attribute.isNumeric()) || (Utils.sm((double)attribute.numValues(), (0.3*(double)m_allData.numInstances())))){ multiVal = false; break; } } }如果属性中,任意一列是数值型,或者其取值的数量小于训练集数量*0.3,则不是多值,否则按多值处理。是否是多值影响到后面某些逻辑。
第四部分,对于每一列属性构造Spliter。
for (i = 0; i < data.numAttributes(); i++){ // Apart from class attribute. if (i != (data).classIndex()){ // Get models for current attribute. currentModel[i] = new C45Split(i,m_minNoObj,sumOfWeights); currentModel[i].buildClassifier(data); // Check if useful split for current attribute // exists and check for enumerated attributes with // a lot of values. if (currentModel[i].checkModel()) if (m_allData != null) { if ((data.attribute(i).isNumeric()) || (multiVal || Utils.sm((double)data.attribute(i).numValues(), (0.3*(double)m_allData.numInstances())))){ averageInfoGain = averageInfoGain+currentModel[i].infoGain(); validModels++; } } else { averageInfoGain = averageInfoGain+currentModel[i].infoGain(); validModels++; } }else currentModel[i] = null; }
对于每一列属性,如果不是存放分类的值得话,则构造C45Split对象,在该对象上进行分类,然后算出信息增益,相加到averageInfoGain上。对于C45Split的构造,稍后再看。
第五部分,选出最优模型。
if (validModels == 0) return noSplitModel; averageInfoGain = averageInfoGain/(double)validModels; // Find "best" attribute to split on. minResult = 0; for (i=0;i<data.numAttributes();i++){ if ((i != (data).classIndex()) && (currentModel[i].checkModel())) // Use 1E-3 here to get a closer approximation to the original // implementation. if ((currentModel[i].infoGain() >= (averageInfoGain-1E-3)) && Utils.gr(currentModel[i].gainRatio(),minResult)){ bestModel = currentModel[i]; minResult = currentModel[i].gainRatio(); }
如果存在有效模型,则选出有效模型。注意这个选出最优模型的逻辑,并不是单纯的选出gainRatio最大的,而是在基础上必须还要大于平均信息增益,这也是和传统的c45算法不一样的一点。
从上述过程来看,Weka在实现C45的时候做了一个小的变动,并没有从“还没有使用的”属性列中找出最合理的列最为分割属性,而是在“所有的列”中找出最合理的列作为分割属性,虽然这二者在结果上肯定是等价的(之前是有过的属性不和能有很好的信息增益率),但效率上个人对Weka的做法持保留意见。
二、C45Spliter
在ModelSelection中真正根据属性对训练集进行分割、计算信息增益和信息增益率的是C45Spliter,首先也从其buildClassifier方法入手进行分析。
public void buildClassifier(Instances trainInstances) throws Exception { // Initialize the remaining instance variables. m_numSubsets = 0; m_splitPoint = Double.MAX_VALUE; m_infoGain = 0; m_gainRatio = 0; // Different treatment for enumerated and numeric // attributes. if (trainInstances.attribute(m_attIndex).isNominal()) { m_complexityIndex = trainInstances.attribute(m_attIndex).numValues(); m_index = m_complexityIndex; handleEnumeratedAttribute(trainInstances); }else{ m_complexityIndex = 2; m_index = 0; trainInstances.sort(trainInstances.attribute(m_attIndex)); handleNumericAttribute(trainInstances); } }可以看到,对于枚举型和数值型的属性是分开处理的,枚举型调用handlEnumeratedAttribute,数值型调用handleNumericAttribute,值得注意的是,在处理数值型之前,按照相应列进行排序,同时设置m_complexityIndex也就是期望分裂的节点数设定为2。
首先来看枚举类型是如何处理的。
private void handleEnumeratedAttribute(Instances trainInstances) throws Exception { Instance instance; m_distribution = new Distribution(m_complexityIndex, trainInstances.numClasses()); // Only Instances with known values are relevant. Enumeration enu = trainInstances.enumerateInstances(); while (enu.hasMoreElements()) { instance = (Instance) enu.nextElement(); if (!instance.isMissing(m_attIndex)) m_distribution.add((int)instance.value(m_attIndex),instance); } // Check if minimum number of Instances in at least two // subsets. if (m_distribution.check(m_minNoObj)) { m_numSubsets = m_complexityIndex; m_infoGain = infoGainCrit. splitCritValue(m_distribution,m_sumOfWeights); m_gainRatio = gainRatioCrit.splitCritValue(m_distribution,m_sumOfWeights, m_infoGain); } }大概流程是新建一个分布,遍历所有instance,如果该instance对应的分裂的属性不为空的话,则放到不同的bag里,之后检查一下这个分布是否满足要求,要求就是最多允许有一个bag里的数据数量小于m_minNoObj,如果通过检查,就设置subset的数量,计算信息增益和信息增益率,否则subset默认会是0,上层调用checkModel就会返回false代表这是一个无效模型。
接下来看数值型是如何处理的:
private void handleNumericAttribute(Instances trainInstances) throws Exception { int firstMiss;//最后一个有效instance的下标 int next = 1;//下一个instance的index int last = 0;//当前instance的index int splitIndex = -1;//分裂点 double currentInfoGain;//当前信息增益 double defaultEnt;//分割之前的信息熵 double minSplit; Instance instance; int i;
//首先新建一个分布,数值型默认处理为2维分布,也就可以理解为小于某个值放到一个Bag里,其余的放到另外一个Bag里
m_distribution = new Distribution(2,trainInstances.numClasses()); Enumeration enu = trainInstances.enumerateInstances(); i = 0;
<pre name="code" class="cpp">//注意instances传入的时候是排好序的,这个排序保证了missingValue放在最后面,所以读到了missingValue其之后肯定都是miss//ingValue,换言之,firstMiss在循环之后代表了最后一个有效的instance的下标。while (enu.hasMoreElements()) { instance = (Instance) enu.nextElement(); if (instance.isMissing(m_attIndex))break; m_distribution.add(1,instance); i++; } firstMiss = i;//循环结束后,m_distribution里放入了所有的有效instance,并全放入了bag1里。
//minSplit是最后分类好每个Bag里最小的数据的量,也就是0.1*每个类的均值。 minSplit = 0.1*(m_distribution.total())/ ((double)trainInstances.numClasses()); if (Utils.smOrEq(minSplit,m_minNoObj)) minSplit = m_minNoObj; else if (Utils.gr(minSplit,25)) minSplit = 25; //如果有效数据总量不到2*minSplit,换言之无论怎么分均不能保证2个bag里的数量大于minSplit,就直接返回。 if (Utils.sm((double)firstMiss,2*minSplit)) return; //defaultEnt代表旧的信息熵,也就是对该属性进行分类之前,Indexclass对应的信息熵。 defaultEnt = infoGainCrit.oldEnt(m_distribution); while (next < firstMiss) { if (trainInstances.instance(next-1).value(m_attIndex)+1e-5 < trainInstances.instance(next).value(m_attIndex)) { <pre name="code" class="cpp">//Instances里的记录是升序排列的,加上这个条件默认把值相差很小的Instance就当做同一个instance处理了
//last代表当前,next代表下一个,默认next=1,last=0,所以shiftRange可以理解成把当前记录从bag1移动到bag0中
<span style="font-family: Arial, Helvetica, sans-serif;">//注意一开始初始化时候所有的都是在bag1里面的。 </span>m_distribution.shiftRange(1,0,trainInstances,last,next);if (Utils.grOrEq(m_distribution.perBag(0),minSplit) && //如果两个bag都满足最小数据集的数量minSplit Utils.grOrEq(m_distribution.perBag(1),minSplit)) { currentInfoGain = infoGainCrit. splitCritValue(m_distribution,m_sumOfWeights, //算一下信息增益 defaultEnt);
if (Utils.gr(currentInfoGain,m_infoGain)) { m_infoGain = currentInfoGain;//如果信息增益比当前最大的要大,则替换当前最大的值,并记录splitIndex splitIndex = next-1; } m_index++; } last = next; } next++; } if (m_index == 0) return; //执行到这里说明没找到一个合适的分裂点,直接返回。 // 计算最佳信息增益 m_infoGain = m_infoGain-(Utils.log2(m_index)/m_sumOfWeights); if (Utils.smOrEq(m_infoGain,0)) return; //如果信息增益是0也说明没找到合适的分裂点,直接返回。 //剩下的就是根据分裂点进行属性的划分。 m_numSubsets = 2; m_splitPoint = (trainInstances.instance(splitIndex+1).value(m_attIndex)+ trainInstances.instance(splitIndex).value(m_attIndex))/2; // In case we have a numerical precision problem we need to choose the // smaller value if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(m_attIndex)) { m_splitPoint = trainInstances.instance(splitIndex).value(m_attIndex); } // Restore distributioN for best split. m_distribution = new Distribution(2,trainInstances.numClasses()); m_distribution.addRange(0,trainInstances,0,splitIndex+1); m_distribution.addRange(1,trainInstances,splitIndex+1,firstMiss); // Compute modified gain ratio for best split. m_gainRatio = gainRatioCrit. splitCritValue(m_distribution,m_sumOfWeights, m_infoGain); }这个函数有点复杂,具体逻辑也写到代码注释里了。
三、BinC45ModelSelection
该函数只负责生成二元分类树的模型,selectModel方法和C45ModelSelection几乎一样,不在多说,不同点在于其使用BinC45Spliter而不是C45Spliter。
四、BinC45Spliter
handleNumericAttribute对于数值类型的属性处理和C45Spliter完全一样。下面只分析一下handleEnumeratedAttribute。
private void handleEnumeratedAttribute(Instances trainInstances) throws Exception { Distribution newDistribution,secondDistribution; int numAttValues; double currIG,currGR; Instance instance; int i; numAttValues = trainInstances.attribute(m_attIndex).numValues(); newDistribution = new Distribution(numAttValues, trainInstances.numClasses()); // Only Instances with known values are relevant. Enumeration enu = trainInstances.enumerateInstances(); while (enu.hasMoreElements()) { instance = (Instance) enu.nextElement(); if (!instance.isMissing(m_attIndex)) newDistribution.add((int)instance.value(m_attIndex),instance); } m_distribution = newDistribution; // For all values for (i = 0; i < numAttValues; i++){ if (Utils.grOrEq(newDistribution.perBag(i),m_minNoObj)){ secondDistribution = new Distribution(newDistribution,i); // Check if minimum number of Instances in the two // subsets. if (secondDistribution.check(m_minNoObj)){ m_numSubsets = 2; currIG = m_infoGainCrit.splitCritValue(secondDistribution, m_sumOfWeights); currGR = m_gainRatioCrit.splitCritValue(secondDistribution, m_sumOfWeights, currIG); if ((i == 0) || Utils.gr(currGR,m_gainRatio)){ m_gainRatio = currGR; m_infoGain = currIG; m_splitPoint = (double)i; m_distribution = secondDistribution; } } } }可以看出,上一段代码根据该属性的不同的取值,在已有分布基础上,建立一个新的分布secondeDistribution,
secondDistribution = new Distribution(newDistribution,i);该分布包含两列,属性下标为i的,其余的,在这个分布的基础上计算信息增益和信息增益率,并选出最优的。
换句话说,离散值分类的二元化处理就是选出其中一列当做一个branch,其余的当做另外一个branch。虽然从结构上来讲这肯定不是最优的选择,但简单易用就够了。
到这里基本分析完了J48的两个ModelSelection,下一篇文章将对classifierInstance过程进行分析,并给出一个简单的总结。
Weka算法Classifier-tree-J48源码分析(三)ModelSelection