首页 > 代码库 > 决策树归纳(ID3属性选择度量)Java实现

决策树归纳(ID3属性选择度量)Java实现

一般的决策树归纳框架见之前的博文:http://blog.csdn.net/zhyoulun/article/details/41978381


ID3属性选择度量原理

ID3使用信息增益作为属性选择度量。该度量基于香农在研究消息的值或”信息内容“的信息论方面的先驱工作。该结点N代表或存放分区D的元组。选择具有最高信息增益的属性作为结点N的分裂属性。该属性使结果分区中对元祖分类所需要的信息量最小,并反映这些分区中的最小随机性或”不纯性“。这种方法使得对一个对象分类所需要的期望测试数目最小,并确保找到一颗简单的(但不必是最简单的)树。

对D中的元组分类所需要的期望信息由下式给出,

技术分享

其中pi是D忠任意元组属于类Ci的非零概率。使用以2为底的对数函数是因为信息用二进制编码。Info(D)是识别D中元组的类标号所需要的平均信息量。注意,此时我们所有的信息只是每个类的元组所占的百分比。

现在假设我们要按照某属性A划分D中的元组,其中属性A根据训练数据的观测具有v个不同的值{a1,a2,...av}。可以用属性A将D划分为v个分区或子集{D1,D2,...,Dv},其中Dj包含D中的元组,它们的A值为aj。这些分区对应于从节点N生长出来的分支。理想情况下,我们希望该划分产生元组的准确分类。即希望每个分区都是纯的(实际情况多半是不纯的,如分区可能包含来自不同类的元组)。在此划分之后,为了得到准确的分类,我们还需要多少信息?这个量由下式度量:

技术分享

其中|Dj|/|D|充当第j个分区的权重。Info_A(D)是基于按A划分对D的元组分类所需要的期望值信息需要的期望信息越小,分区的纯度越高

信息增益定义为原来的信息需求(仅基于类比例)与新的信息需求(对A划分后)之前的差。即

技术分享

换言之,Gain(A)告诉我们通过A上的划分我们得到了多少。它是知道A的值而导致的信息需求的期望减少。选择具有最高信息增益Gain(A)的属性A作为结点N的分裂属性。


以下为例子。


数据

data.txt

youth,high,no,fair,no
youth,high,no,excellent,no
middle_aged,high,no,fair,yes
senior,medium,no,fair,yes
senior,low,yes,fair,yes
senior,low,yes,excellent,no
middle_aged,low,yes,excellent,yes
youth,medium,no,fair,no
youth,low,yes,fair,yes
senior,medium,yes,fair,yes
youth,medium,yes,excellent,yes
middle_aged,medium,no,excellent,yes
middle_aged,high,yes,fair,yes
senior,medium,no,excellent,no


attr.txt

age,income,student,credit_rating,buys_computer


运算结果

age(1:youth; 2:middle_aged; 3:senior; )
	credit_rating(1:fair; 2:excellent; )
		leaf:no()
		leaf:yes()
	leaf:yes()
	student(1:no; 2:yes; )
		leaf:no()
		leaf:yes()


技术分享


最后附上java代码

DecisionTree.java

package com.zhyoulun.decision;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Map;

/**
 * 负责数据的读入和写出,以及生成决策树
 * 
 * @author zhyoulun
 *
 */
public class DecisionTree
{
	private ArrayList<ArrayList<String>> allDatas;
	private ArrayList<String> allAttributes;
	
	/**
	 * 从文件中读取所有相关数据
	 * @param dataFilePath
	 * @param attrFilePath
	 */
	public DecisionTree(String dataFilePath,String attrFilePath)
	{
		super();
		
		try
		{
			this.allDatas = new ArrayList<>();
			this.allAttributes = new ArrayList<>();
			
			InputStreamReader inputStreamReader = new InputStreamReader(new FileInputStream(new File(dataFilePath)));
			BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
			String line = null;
			while((line=bufferedReader.readLine())!=null)
			{
				String[] strings = line.split(",");
				ArrayList<String> data = http://www.mamicode.com/new ArrayList<>();>


CriterionID3.java

package com.zhyoulun.decision;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * ID3,选择分裂准则
 * 
 * @author zhyoulun
 *
 */
public class CriterionID3
{
	private ArrayList<ArrayList<String>> datas;
	private ArrayList<String> attributes;
	
	private Map<String, ArrayList<ArrayList<String>>> subDatasMap;
	
	/**
	 * 计算所有的信息增益,获取最大的一项作为分裂属性
	 * @return
	 */
	public int attributeSelectionMethod()
	{
		double gain = -1.0;
		int maxIndex = 0;
		for(int i=0;i<this.attributes.size()-1;i++)
		{
			double tempGain = this.calcGain(i);
			if(tempGain>gain)
			{
				gain = tempGain;
				maxIndex = i;
			}
		}
		
		return maxIndex;
	}
	
	/**
	 * 计算 Gain(age)=Info(D)-Info_age(D) 等
	 * @param index
	 * @return
	 */
	/**
	 * @param index
	 * @param isCalcSubDatasMap
	 * @return
	 */
	private double calcGain(int index)
	{
		double result = 0;
		
		//计算Info(D)
		int lastIndex = datas.get(0).size()-1;
		ArrayList<String> valueSet = DecisionTree.getValueSet(this.datas,lastIndex);
		for(String value:valueSet)
		{
			int count = 0;
			for(int i=0;i<datas.size();i++)
			{
				if(datas.get(i).get(lastIndex).equals(value))
					count++;
			}
			
			result += -(1.0*count/datas.size())*Math.log(1.0*count/datas.size())/Math.log(2);
//			System.out.println(result);
		}
//		System.out.println("==========");
		
		//计算Info_a(D)
		valueSet = DecisionTree.getValueSet(this.datas,index);
		
//		for(String temp:valueSet)
//			System.out.println(temp);
//		System.out.println("==========");
		
		for(String value:valueSet)
		{	
			ArrayList<ArrayList<String>> subDatas = new ArrayList<>();
			for(int i=0;i<datas.size();i++)
			{
				if(datas.get(i).get(index).equals(value))
					subDatas.add(datas.get(i));
			}
			
			
			
			
//			for(ArrayList<String> temp:subDatas)
//			{
//				for(String temp2:temp)
//					System.out.print(temp2+" ");
//				System.out.println();
//			}
			
			ArrayList<String> subValueSet = DecisionTree.getValueSet(subDatas, lastIndex);
			
			
//			System.out.print("subValueSet:");
//			for(String temp:subValueSet)
//				System.out.print(temp+" ");
//			System.out.println();
			
			
			for(String subValue:subValueSet)
			{
//				System.out.println("+++++++++++++++");
//				System.out.println(subValue);
				int count = 0;
				for(int i=0;i<subDatas.size();i++)
				{
					if(subDatas.get(i).get(lastIndex).equals(subValue))
						count++;
				}
//				System.out.println(count);
				result += -1.0*subDatas.size()/datas.size()*(-(1.0*count/subDatas.size())*Math.log(1.0*count/subDatas.size())/Math.log(2));
//				System.out.println(result);
			}
			
		}
		
		return result;
		
	}



	public CriterionID3(ArrayList<ArrayList<String>> datas,
			ArrayList<String> attributes)
	{
		super();
		this.datas = datas;
		this.attributes = attributes;
	}



	public ArrayList<ArrayList<String>> getDatas()
	{
		return datas;
	}



	public void setDatas(ArrayList<ArrayList<String>> datas)
	{
		this.datas = datas;
	}



	public ArrayList<String> getAttributes()
	{
		return attributes;
	}



	public void setAttributes(ArrayList<String> attributes)
	{
		this.attributes = attributes;
	}

	public Map<String, ArrayList<ArrayList<String>>> getSubDatasMap(int index)
	{
		ArrayList<String> valueSet = DecisionTree.getValueSet(this.datas, index);
		this.subDatasMap = new HashMap<String, ArrayList<ArrayList<String>>>();
		
		for(String value:valueSet)
		{
			ArrayList<ArrayList<String>> subDatas = new ArrayList<>();
			for(int i=0;i<this.datas.size();i++)
			{
				if(this.datas.get(i).get(index).equals(value))
					subDatas.add(this.datas.get(i));
			}
			for(int i=0;i<subDatas.size();i++)
			{
				subDatas.get(i).remove(index);
			}
			this.subDatasMap.put(value, subDatas);
		}
		
		return subDatasMap;
	}

	public void setSubDatasMap(Map<String, ArrayList<ArrayList<String>>> subDatasMap)
	{
		this.subDatasMap = subDatasMap;
	}
	
	
}



TreeNode.java

package com.zhyoulun.decision;

import java.util.ArrayList;

public class TreeNode
{
	private String name; 								// 该结点的名称(分裂属性)
	private ArrayList<String> rules; 				// 结点的分裂规则(假设均为离散值)
//	private ArrayList<ArrayList<String>> datas; 	// 划分到该结点的训练元组(datas.get(i)表示一个训练元组)
//	private ArrayList<String> candidateAttributes; // 划分到该结点的候选属性(与训练元组的个数一致)
	private ArrayList<TreeNode> children; 			// 子结点

	public TreeNode()
	{
		this.name = "";
		this.rules = new ArrayList<String>();
		this.children = new ArrayList<TreeNode>();
//		this.datas = null;
//		this.candidateAttributes = null;
	}

	public String getName()
	{
		return name;
	}

	public void setName(String name)
	{
		this.name = name;
	}

	public ArrayList<String> getRules()
	{
		return rules;
	}

	public void setRules(ArrayList<String> rules)
	{
		this.rules = rules;
	}

	public ArrayList<TreeNode> getChildren()
	{
		return children;
	}

	public void setChildren(ArrayList<TreeNode> children)
	{
		this.children = children;
	}

//	public ArrayList<ArrayList<String>> getDatas()
//	{
//		return datas;
//	}
//
//	public void setDatas(ArrayList<ArrayList<String>> datas)
//	{
//		this.datas = datas;
//	}
//
//	public ArrayList<String> getCandidateAttributes()
//	{
//		return candidateAttributes;
//	}
//
//	public void setCandidateAttributes(ArrayList<String> candidateAttributes)
//	{
//		this.candidateAttributes = candidateAttributes;
//	}
	
	
}



参考:《数据挖掘概念与技术(第3版)》

转载请注明出处:

决策树归纳(ID3属性选择度量)Java实现