首页 > 代码库 > 决策树归纳(ID3属性选择度量)Java实现
决策树归纳(ID3属性选择度量)Java实现
一般的决策树归纳框架见之前的博文:http://blog.csdn.net/zhyoulun/article/details/41978381
ID3属性选择度量原理
ID3使用信息增益作为属性选择度量。该度量基于香农在研究消息的值或”信息内容“的信息论方面的先驱工作。该结点N代表或存放分区D的元组。选择具有最高信息增益的属性作为结点N的分裂属性。该属性使结果分区中对元祖分类所需要的信息量最小,并反映这些分区中的最小随机性或”不纯性“。这种方法使得对一个对象分类所需要的期望测试数目最小,并确保找到一颗简单的(但不必是最简单的)树。
对D中的元组分类所需要的期望信息由下式给出,
其中pi是D忠任意元组属于类Ci的非零概率。使用以2为底的对数函数是因为信息用二进制编码。Info(D)是识别D中元组的类标号所需要的平均信息量。注意,此时我们所有的信息只是每个类的元组所占的百分比。
现在假设我们要按照某属性A划分D中的元组,其中属性A根据训练数据的观测具有v个不同的值{a1,a2,...av}。可以用属性A将D划分为v个分区或子集{D1,D2,...,Dv},其中Dj包含D中的元组,它们的A值为aj。这些分区对应于从节点N生长出来的分支。理想情况下,我们希望该划分产生元组的准确分类。即希望每个分区都是纯的(实际情况多半是不纯的,如分区可能包含来自不同类的元组)。在此划分之后,为了得到准确的分类,我们还需要多少信息?这个量由下式度量:
其中|Dj|/|D|充当第j个分区的权重。Info_A(D)是基于按A划分对D的元组分类所需要的期望值信息。需要的期望信息越小,分区的纯度越高。
信息增益定义为原来的信息需求(仅基于类比例)与新的信息需求(对A划分后)之前的差。即
换言之,Gain(A)告诉我们通过A上的划分我们得到了多少。它是知道A的值而导致的信息需求的期望减少。选择具有最高信息增益Gain(A)的属性A作为结点N的分裂属性。
以下为例子。
数据
data.txt
youth,high,no,fair,no youth,high,no,excellent,no middle_aged,high,no,fair,yes senior,medium,no,fair,yes senior,low,yes,fair,yes senior,low,yes,excellent,no middle_aged,low,yes,excellent,yes youth,medium,no,fair,no youth,low,yes,fair,yes senior,medium,yes,fair,yes youth,medium,yes,excellent,yes middle_aged,medium,no,excellent,yes middle_aged,high,yes,fair,yes senior,medium,no,excellent,no
attr.txt
age,income,student,credit_rating,buys_computer
运算结果
age(1:youth; 2:middle_aged; 3:senior; ) credit_rating(1:fair; 2:excellent; ) leaf:no() leaf:yes() leaf:yes() student(1:no; 2:yes; ) leaf:no() leaf:yes()
最后附上java代码
DecisionTree.java
package com.zhyoulun.decision; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.Map; /** * 负责数据的读入和写出,以及生成决策树 * * @author zhyoulun * */ public class DecisionTree { private ArrayList<ArrayList<String>> allDatas; private ArrayList<String> allAttributes; /** * 从文件中读取所有相关数据 * @param dataFilePath * @param attrFilePath */ public DecisionTree(String dataFilePath,String attrFilePath) { super(); try { this.allDatas = new ArrayList<>(); this.allAttributes = new ArrayList<>(); InputStreamReader inputStreamReader = new InputStreamReader(new FileInputStream(new File(dataFilePath))); BufferedReader bufferedReader = new BufferedReader(inputStreamReader); String line = null; while((line=bufferedReader.readLine())!=null) { String[] strings = line.split(","); ArrayList<String> data = http://www.mamicode.com/new ArrayList<>();>
CriterionID3.java
package com.zhyoulun.decision; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; /** * ID3,选择分裂准则 * * @author zhyoulun * */ public class CriterionID3 { private ArrayList<ArrayList<String>> datas; private ArrayList<String> attributes; private Map<String, ArrayList<ArrayList<String>>> subDatasMap; /** * 计算所有的信息增益,获取最大的一项作为分裂属性 * @return */ public int attributeSelectionMethod() { double gain = -1.0; int maxIndex = 0; for(int i=0;i<this.attributes.size()-1;i++) { double tempGain = this.calcGain(i); if(tempGain>gain) { gain = tempGain; maxIndex = i; } } return maxIndex; } /** * 计算 Gain(age)=Info(D)-Info_age(D) 等 * @param index * @return */ /** * @param index * @param isCalcSubDatasMap * @return */ private double calcGain(int index) { double result = 0; //计算Info(D) int lastIndex = datas.get(0).size()-1; ArrayList<String> valueSet = DecisionTree.getValueSet(this.datas,lastIndex); for(String value:valueSet) { int count = 0; for(int i=0;i<datas.size();i++) { if(datas.get(i).get(lastIndex).equals(value)) count++; } result += -(1.0*count/datas.size())*Math.log(1.0*count/datas.size())/Math.log(2); // System.out.println(result); } // System.out.println("=========="); //计算Info_a(D) valueSet = DecisionTree.getValueSet(this.datas,index); // for(String temp:valueSet) // System.out.println(temp); // System.out.println("=========="); for(String value:valueSet) { ArrayList<ArrayList<String>> subDatas = new ArrayList<>(); for(int i=0;i<datas.size();i++) { if(datas.get(i).get(index).equals(value)) subDatas.add(datas.get(i)); } // for(ArrayList<String> temp:subDatas) // { // for(String temp2:temp) // System.out.print(temp2+" "); // System.out.println(); // } ArrayList<String> subValueSet = DecisionTree.getValueSet(subDatas, lastIndex); // System.out.print("subValueSet:"); // for(String temp:subValueSet) // System.out.print(temp+" "); // System.out.println(); for(String subValue:subValueSet) { // System.out.println("+++++++++++++++"); // System.out.println(subValue); int count = 0; for(int i=0;i<subDatas.size();i++) { if(subDatas.get(i).get(lastIndex).equals(subValue)) count++; } // System.out.println(count); result += -1.0*subDatas.size()/datas.size()*(-(1.0*count/subDatas.size())*Math.log(1.0*count/subDatas.size())/Math.log(2)); // System.out.println(result); } } return result; } public CriterionID3(ArrayList<ArrayList<String>> datas, ArrayList<String> attributes) { super(); this.datas = datas; this.attributes = attributes; } public ArrayList<ArrayList<String>> getDatas() { return datas; } public void setDatas(ArrayList<ArrayList<String>> datas) { this.datas = datas; } public ArrayList<String> getAttributes() { return attributes; } public void setAttributes(ArrayList<String> attributes) { this.attributes = attributes; } public Map<String, ArrayList<ArrayList<String>>> getSubDatasMap(int index) { ArrayList<String> valueSet = DecisionTree.getValueSet(this.datas, index); this.subDatasMap = new HashMap<String, ArrayList<ArrayList<String>>>(); for(String value:valueSet) { ArrayList<ArrayList<String>> subDatas = new ArrayList<>(); for(int i=0;i<this.datas.size();i++) { if(this.datas.get(i).get(index).equals(value)) subDatas.add(this.datas.get(i)); } for(int i=0;i<subDatas.size();i++) { subDatas.get(i).remove(index); } this.subDatasMap.put(value, subDatas); } return subDatasMap; } public void setSubDatasMap(Map<String, ArrayList<ArrayList<String>>> subDatasMap) { this.subDatasMap = subDatasMap; } }
TreeNode.java
package com.zhyoulun.decision; import java.util.ArrayList; public class TreeNode { private String name; // 该结点的名称(分裂属性) private ArrayList<String> rules; // 结点的分裂规则(假设均为离散值) // private ArrayList<ArrayList<String>> datas; // 划分到该结点的训练元组(datas.get(i)表示一个训练元组) // private ArrayList<String> candidateAttributes; // 划分到该结点的候选属性(与训练元组的个数一致) private ArrayList<TreeNode> children; // 子结点 public TreeNode() { this.name = ""; this.rules = new ArrayList<String>(); this.children = new ArrayList<TreeNode>(); // this.datas = null; // this.candidateAttributes = null; } public String getName() { return name; } public void setName(String name) { this.name = name; } public ArrayList<String> getRules() { return rules; } public void setRules(ArrayList<String> rules) { this.rules = rules; } public ArrayList<TreeNode> getChildren() { return children; } public void setChildren(ArrayList<TreeNode> children) { this.children = children; } // public ArrayList<ArrayList<String>> getDatas() // { // return datas; // } // // public void setDatas(ArrayList<ArrayList<String>> datas) // { // this.datas = datas; // } // // public ArrayList<String> getCandidateAttributes() // { // return candidateAttributes; // } // // public void setCandidateAttributes(ArrayList<String> candidateAttributes) // { // this.candidateAttributes = candidateAttributes; // } }
参考:《数据挖掘概念与技术(第3版)》
转载请注明出处:
决策树归纳(ID3属性选择度量)Java实现