首页 > 代码库 > Weka算法Classifier-tree-J48源码分析(二)ClassifierTree
Weka算法Classifier-tree-J48源码分析(二)ClassifierTree
一、问题
主要带着四个问题去研究J48的实现。
1、如何控制分类树的精度。
2、如何处理缺失的值(MissingValue)
3、如何对连续值进行离散化。
4、如何进行分类树的剪枝。
二、BuildClassifier
每一个分类器都会实现这个方法,传入一个Instances对象,在这个对象基础上进行来构建分类树。核心代码如下:
public void buildClassifier(Instances instances) throws Exception { ModelSelection modSelection; if (m_binarySplits) modSelection = new BinC45ModelSelection(m_minNumObj, instances); else modSelection = new C45ModelSelection(m_minNumObj, instances); if (!m_reducedErrorPruning) m_root = new C45PruneableClassifierTree(modSelection, !m_unpruned, m_CF, m_subtreeRaising, !m_noCleanup); else m_root = new PruneableClassifierTree(modSelection, !m_unpruned, m_numFolds, !m_noCleanup, m_Seed); m_root.buildClassifier(instances); if (m_binarySplits) { ((BinC45ModelSelection)modSelection).cleanup(); } else { ((C45ModelSelection)modSelection).cleanup(); } }可以看到这段代码逻辑非常清楚,首先根据是否是一个二分树(即每个节点只有是否两种选择)来构造一个ModelSelection,随后根据是否有m_reduceErrorPruning标志来构造相应的ClassifierTree,在这个tree上真正的构建模型,最后清理数据(主要是做释放指针的工作,防止Tree持有Instances指针导致GC不能在上层调用者想释放Instances的时候进行释放)。
三、C45PruneableClassifierTree
(1)该类也实现了BuildCClassifier方法来构建分类器,先看一下这个方法的主逻辑,代码如下:
public void buildClassifier(Instances data) throws Exception { // can classifier tree handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = http://www.mamicode.com/new Instances(data);>首先testWithFail是检测一下传入的data是否能用该分类器进行分类,比如C45只能对要分类的属性的取值是离散值的Instances进行分类,这个test就是检测诸如此类的逻辑。接着清理一下instances里面的无效行(相应分类属性为空的行)。
在此数据上调用buildTree进行构建分类树。
调用collapse()进行树的“坍塌”(这里我不太知道学名应该怎么翻译)
如果有需要,则进行prune()剪枝。
最后清理数据。
(2)按照这个顺序首先来看buildTree函数
public void buildTree(Instances data, boolean keepData) throws Exception { Instances [] localInstances; if (keepData) { m_train = data; } m_test = null; m_isLeaf = false; m_isEmpty = false; m_sons = null; m_localModel = m_toSelectModel.selectModel(data); if (m_localModel.numSubsets() > 1) { localInstances = m_localModel.split(data); data = http://www.mamicode.com/null;>该函数逻辑也比较简单(怎么都比较简单?!),首先根据传入参数来判断是否应该持有数据。然后根据m_toSelectModel来选择一个模型并把传入的数据集按相应的规则分成不同的subSet,这个selectModel是构造函数传入的,参见刚才描述的主流程。这一步如果对应上篇博客的算法描述,得到的subSet就是第10行的dv。
接着判断subSet的数量,如果只有一个,那么就是一个叶子节点,什么都不用做就返回了。
否则根据localModel将data分成不同的subInstances,接着为每一个subInstances建立新的ClassifierTree节点作为自己的孩子节点,并调用getNewTree函数来为每一个subInstances构造新的tree。
(3)采用DFS的方式接着去看一下getNewTree的逻辑
protected ClassifierTree getNewTree(Instances data) throws Exception { ClassifierTree newTree = new ClassifierTree(m_toSelectModel); newTree.buildTree(data, false); return newTree; }很简单,就是一个递归调用。(4)重新回到C45PruneableClassifierTree.buildClassifier方法,来研究一下其中的collapse函数。
/** * Collapses a tree to a node if training error doesn't increase. */ public final void collapse(){ double errorsOfSubtree; double errorsOfTree; int i; if (!m_isLeaf){ errorsOfSubtree = getTrainingErrors(); errorsOfTree = localModel().distribution().numIncorrect(); if (errorsOfSubtree >= errorsOfTree-1E-3){ // Free adjacent trees m_sons = null; m_isLeaf = true; // Get NoSplit Model for tree. m_localModel = new NoSplit(localModel().distribution()); }else for (i=0;i<m_sons.length;i++) son(i).collapse(); } }通过注释也可以看出,如果该节点的存在很多孩子节点,但这些孩子节点并不能提高这颗分类树的准确度,则把这些孩子节点删除。否则在每个孩子上递归的坍塌。通过collapse方法可以在不减少精度的前提下减少决策树的深度,进而提高效率。简单说一下如何估计当前的节点的错误,也就是localModel().distribution().numIncorrect();
首先获得当前训练集上的一个分布,然后找出该分布里数量最多的那个属性的数量,认为是“正确的”,则其余的就是错误的。
getTrainingError就是对每个孩子节点做上述操作,然后结果相加。
(5)再来看看prune()方法,也是C45PruneableClassifierTree的BuildClassifier中的最后一个步骤。
该函数比较长,我就直接把对这个函数的分析写在注释里了。
public void prune() throws Exception { double errorsLargestBranch;//这个树节点的孩子节点中,肯定有一个分到的数据最多,该值记录该孩子节点分类错误的用例数 double errorsLeaf;//如果该节点成为了叶子节点,则分类错误的用例数量 double errorsTree;//<span style="font-family: Arial, Helvetica, sans-serif;">该节点目前情况下,错误用例数量</span> int indexOfLargestBranch;//那个分到最多数据的孩子节点在son数组中的index C45PruneableClassifierTree largestBranch;//son[indexOfLargestBranch] int i; if (!m_isLeaf){//首先,如果是叶子节点,则先递归的队所有孩子几点进行prune()。 for (i=0;i<m_sons.length;i++) son(i).prune();//通过数据集的分布,很容易能找到indexOfLargetBranch indexOfLargestBranch = localModel().distribution().maxBag(); if (m_subtreeRaising) {//m_subtreeRaising是一个标志,代表可否使用该树的子树去替代该树,如果有了这个标志,就去计算最大的子树的错误数量//否则就简单的标Double.Max_Value//对于错误数量的估计不展开说了,简单来说依然是根据分布做一个统计(还要加一个基于m_CF的修正),如果不是叶子节点则递//归的进行统计。 errorsLargestBranch = son(indexOfLargestBranch). getEstimatedErrorsForBranch((Instances)m_train); } else { errorsLargestBranch = Double.MAX_VALUE; } //估计一下如果该节点成为了叶子节点,则错误数量大概有多少 errorsLeaf = getEstimatedErrorsForDistribution(localModel().distribution());//估计该节点目前情况下,错误用例数量。 errorsTree = getEstimatedErrors(); //Utils.smOrEq是smaller or equal即<=的意思 if (Utils.smOrEq(errorsLeaf,errorsTree+0.1) && Utils.smOrEq(errorsLeaf,errorsLargestBranch+0.1)){ // 如果当前节点作为叶子节点的错误量比整棵树都要低,并且当前节点比最大的子树的错误量也低,那么就把当前节点作//为叶子节点一定是一个最优的选择。 m_sons = null; m_isLeaf = true; // Get NoSplit Model for node. m_localModel = new NoSplit(localModel().distribution()); return;//直接返回 } // Decide if largest branch is better choice // than whole subtree. if (Utils.smOrEq(errorsLargestBranch,errorsTree+0.1)){//如果当前节点的错误用例数大于最大子树,则用最大子树替代当前节点。 largestBranch = son(indexOfLargestBranch); m_sons = largestBranch.m_sons; m_localModel = largestBranch.localModel(); m_isLeaf = largestBranch.m_isLeaf; newDistribution(m_train); prune(); } } }
一句话总结collapse和prune:prune或许会影响精度,collapse不会。
四、PruneableClassifierTree
在J48主流程里,根据m_reducedErrorPruning的不同会选择两个不同的ClassifierTree,刚才已经分析了一个,另外一个则是PruneeableClassifierTree。
(1)buildClassifier
public void buildClassifier(Instances data) throws Exception { // can classifier tree handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = http://www.mamicode.com/new Instances(data);>和C45PruneableClassifierTree不同的是,buildTree的时候除了传入训练集,还传入了测试集,除此之外,少了Collapse步骤,其余都一样。下面就看看传入了测试集的build和之前分析的build有什么不同之处。
(2)buildTree
public void buildTree(Instances train, Instances test, boolean keepData) throws Exception { Instances [] localTrain, localTest; int i; if (keepData) { m_train = train; } m_isLeaf = false; m_isEmpty = false; m_sons = null; m_localModel = m_toSelectModel.selectModel(train, test); m_test = new Distribution(test, m_localModel); if (m_localModel.numSubsets() > 1) { localTrain = m_localModel.split(train); localTest = m_localModel.split(test); train = test = null; m_sons = new ClassifierTree [m_localModel.numSubsets()]; for (i=0;i<m_sons.length;i++) { m_sons[i] = getNewTree(localTrain[i], localTest[i]); localTrain[i] = null; localTest[i] = null; } }else{ m_isLeaf = true; if (Utils.eq(train.sumOfWeights(), 0)) m_isEmpty = true; train = test = null; } }可以看到,代码基本一样,唯一不同的地方就是selectModel的时候会把test传进去,对于Model的实现会具体放到下篇博客中去讲述。而prune也更为简单,去掉了subTreeRasing的特性。
public void prune() throws Exception { if (!m_isLeaf) { // Prune all subtrees. for (int i = 0; i < m_sons.length; i++) son(i).prune(); // Decide if leaf is best choice. if (Utils.smOrEq(errorsForLeaf(),errorsForTree())) { // Free son Trees m_sons = null; m_isLeaf = true; // Get NoSplit Model for node. m_localModel = new NoSplit(localModel().distribution()); } } }五、总结
至此,对两种ClassifierTree的buildClassifier的分析差不多就结束了,总体上来讲,ClassifierTree是通过传入的Model来构建并维护分类树的结构,除此之外在构建完毕后会按照不同的逻辑进行剪枝。
对于篇开头提出的问题,目前可以回答问题4,简而言之就是根据已有数据集的分布,判断该树、该树的最大子树、以及该树作为叶子节点时的正确率,在此基础上进行剪枝。
下篇文章主要分析Model的实现,也就是如何根据属性把已有的数据集分解subInstances
Weka算法Classifier-tree-J48源码分析(二)ClassifierTree