
The ID3 Algorithm

I. A Brief Introduction to ID3

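ID3 was originally developed by Ross Quinlan at the University of Sydney. He first presented ID3 in 1975 in a book, Machine Learning. ID3 builds on the Concept Learning System (CLS) algorithm.

ID3 is a decision-tree algorithm. A decision tree consists of decision nodes, branches, and leaves. The topmost node is the root; each branch leads either to another decision node or to a leaf. Each decision node represents a question or decision, usually a test on one attribute of the object being classified, and each leaf represents one possible classification result. Traversing the tree from top to bottom, a test is applied at every node, and the outcome of that test determines which branch to follow, until a leaf is reached. This traversal is how a decision tree classifies an object: the values of a few variables determine the class it belongs to.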

II. Foundations of ID3: Information Theory

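ID3 is grounded in information theory. It uses information entropy and information gain as its measures for classifying data. Some basic concepts from information theory follow: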

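Definition 1: If there are n equally probable messages, the probability p of each message is 1/n, and the amount of information conveyed by one message is -log2(1/n) = log2(n).

Definition 2: If there are n messages with probability distribution P = (p1, p2, ..., pn), the amount of information conveyed by this distribution is called the entropy of P, written I(P) = -p1*log2(p1) - p2*log2(p2) - ... - pn*log2(pn).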
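Definitions 1 and 2 can be checked numerically. The following is a minimal Python sketch; the function name `entropy` is chosen here for illustration, and zero probabilities are skipped on the usual convention that 0*log2(0) = 0.

```python
from math import log2


def entropy(probs):
    """Entropy I(P), in bits, of a distribution P = (p1, ..., pn).

    Terms with p == 0 are skipped, following the convention 0 * log2(0) = 0.
    """
    return -sum(p * log2(p) for p in probs if p > 0)


# Definition 1: one of four equally probable messages carries -log2(1/4) bits,
# and the entropy of the uniform distribution over them has the same value.
print(-log2(1 / 4))          # 2.0
print(entropy([1 / 4] * 4))  # 2.0
# A fair coin carries exactly one bit.
print(entropy([0.5, 0.5]))   # 1.0
```

The more uneven the distribution, the lower the entropy; a distribution that puts all its mass on one message has entropy 0.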

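Definition 3: If a set of records T is partitioned by the value of the class attribute into disjoint classes C1, C2, ..., Ck, then the information needed to identify the class of an element of T is Info(T) = I(P), where P is the probability distribution of C1, C2, ..., Ck, that is, P = (|C1|/|T|, |C2|/|T|, ..., |Ck|/|T|).

Definition 4: If we first partition T by the value of a non-class attribute X into subsets T1, T2, ..., Tn, then the information needed to identify the class of an element of T is the weighted average of the Info(Ti): Info(X,T) = sum over i = 1..n of (|Ti|/|T|) * Info(Ti).

Definition 5: Information gain is the difference between two quantities: the information needed to identify the class of an element of T, and the information still needed once the value of attribute X is known. The formula is Gain(X,T) = Info(T) - Info(X,T).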

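ID3 computes the information gain of every attribute and selects the attribute with the highest gain as the test attribute for the given set. It creates a node for the selected attribute, labels the node with that attribute, creates one branch for each value of the attribute, and then repeats the information-gain computation on each branch.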
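Definitions 3 to 5 and the selection rule translate almost directly into code. The sketch below is illustrative only: the function names (`info`, `info_x`, `gain`, `best_attribute`) and the four-row toy dataset are assumptions made for this example, not part of the programs in section IV.

```python
from collections import Counter, defaultdict
from math import log2


def info(labels):
    """Info(T): entropy of the class distribution of a list of labels."""
    total = len(labels)
    return -sum(c / total * log2(c / total) for c in Counter(labels).values())


def info_x(rows, attr, target):
    """Info(X,T): weighted average of Info(Ti) over the subsets induced by attr."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[attr]].append(row[target])
    return sum(len(g) / len(rows) * info(g) for g in groups.values())


def gain(rows, attr, target):
    """Gain(X,T) = Info(T) - Info(X,T)."""
    return info([row[target] for row in rows]) - info_x(rows, attr, target)


def best_attribute(rows, attrs, target):
    """The attribute ID3 would test at this node: the one with the highest gain."""
    return max(attrs, key=lambda a: gain(rows, a, target))


# Hypothetical toy data: "windy" determines "play" completely, "warm" says nothing.
rows = [
    {"windy": "yes", "warm": "yes", "play": "no"},
    {"windy": "yes", "warm": "no", "play": "no"},
    {"windy": "no", "warm": "yes", "play": "yes"},
    {"windy": "no", "warm": "no", "play": "yes"},
]
print(best_attribute(rows, ["windy", "warm"], "play"))  # windy
```

Here `windy` removes all uncertainty about `play` (gain of one bit), while `warm` removes none, so `windy` is chosen as the test attribute.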

III. A Worked Example of the ID3 Steps

An example of the ID3 algorithm follows:

RID  age          income  student  credit_rating  buy_computer
1    youth        high    no       fair           no
2    youth        high    no       excellent      no
3    middle_aged  high    no       fair           yes
4    senior       medium  no       fair           yes
5    senior       low     yes      fair           yes
6    senior       low     yes      excellent      no
7    middle_aged  low     yes      excellent      yes
8    youth        medium  no       fair           no
9    youth        low     yes      fair           yes
10   senior       medium  yes      fair           yes
11   youth        medium  yes      excellent      yes
12   middle_aged  medium  no       excellent      yes
13   middle_aged  high    yes      fair           yes
14   senior       medium  no       excellent      no

 

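The dataset has 14 records. The reference attributes are age (youth[5], middle_aged[4], senior[5]), income (high[4], medium[6], low[4]), student (no[7], yes[7]), and credit_rating (fair[8], excellent[6]). The target attribute is buy_computer (no[5], yes[9]). The goal is a model that infers the value of buy_computer from age, income, student, and credit_rating. Let D be the initial dataset and A the list of reference attributes. The computation proceeds as follows: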

Step 1: Compute the entropy of the target attribute over dataset D: Info(buy_computer) = -(5/14)*log2(5/14) - (9/14)*log2(9/14) = 0.940
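The arithmetic of step 1 can be reproduced in a few lines of Python. This is only a check of the number above; the variable names are chosen for illustration.

```python
from math import log2

# Class counts of buy_computer in D, taken from the 14-row table.
counts = {"no": 5, "yes": 9}
total = sum(counts.values())

# Info(buy_computer) = -sum(p * log2(p)) over the two class probabilities.
info_d = -sum(c / total * log2(c / total) for c in counts.values())
print(f"{info_d:.3f}")  # 0.940
```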

Step 2: For every attribute in A, compute the entropy of buy_computer over D given that the attribute's value is known, that is, the conditional entropy.

  age attribute: youth (no[3], yes[2]), middle_aged (no[0], yes[4]), senior (no[2], yes[3]). First compute the entropy for youth, middle_aged, and senior separately.

    Info_age(buy_computer|youth) = -(3/5)*log2(3/5) - (2/5)*log2(2/5) = 0.971

    Info_age(buy_computer|middle_aged) = -(4/4)*log2(4/4) - (0/4)*log2(0/4) = 0, taking 0*log2(0) = 0 by convention

    Info_age(buy_computer|senior) = -(2/5)*log2(2/5) - (3/5)*log2(3/5) = 0.971

  Then Info_age(buy_computer) = (5/14)*0.971 + (4/14)*0 + (5/14)*0.971 = 0.694

  Likewise: Info_income(buy_computer) = 0.911; Info_student(buy_computer) = 0.789; Info_credit_rating(buy_computer) = 0.892.
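The four conditional entropies can be recomputed from per-value class counts. In the sketch below the (no, yes) counts were tallied by hand from the 14-row table, and the helper names are assumptions for this example. Because the text rounds intermediate values, the results printed here can differ from the figures above by about 0.001.

```python
from math import log2

# (no, yes) counts of buy_computer for each attribute value, tallied from the table.
class_counts = {
    "age": {"youth": (3, 2), "middle_aged": (0, 4), "senior": (2, 3)},
    "income": {"high": (2, 2), "medium": (2, 4), "low": (1, 3)},
    "student": {"no": (4, 3), "yes": (1, 6)},
    "credit_rating": {"fair": (2, 6), "excellent": (3, 3)},
}


def entropy_of_counts(counts):
    """Entropy of a class distribution given as raw counts; zero counts are skipped."""
    total = sum(counts)
    return -sum(c / total * log2(c / total) for c in counts if c > 0)


def conditional_entropy(value_counts):
    """Average of the per-value entropies, weighted by how many records take each value."""
    n = sum(sum(c) for c in value_counts.values())
    return sum(sum(c) / n * entropy_of_counts(c) for c in value_counts.values())


for attr, value_counts in class_counts.items():
    print(attr, round(conditional_entropy(value_counts), 3))
```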

Step 3: Compute the information gain. The larger the gain, the more of the target attribute's entropy is removed by that reference attribute, so the higher that attribute should sit in the decision tree. The results are:

  Gain(age, buy_computer) = Info(buy_computer) - Info_age(buy_computer) = 0.940 - 0.694 = 0.246

  Gain(income, buy_computer) = Info(buy_computer) - Info_income(buy_computer) = 0.940 - 0.911 = 0.029

  Gain(student, buy_computer) = Info(buy_computer) - Info_student(buy_computer) = 0.940 - 0.789 = 0.151

  Gain(credit_rating, buy_computer) = Info(buy_computer) - Info_credit_rating(buy_computer) = 0.940 - 0.892 = 0.048
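Selecting the split attribute is then a matter of subtracting and taking the maximum. The short sketch below uses the rounded values quoted in the text as its inputs; the variable names are illustrative.

```python
# Rounded values from steps 1 and 2.
info_d = 0.940
conditional = {"age": 0.694, "income": 0.911, "student": 0.789, "credit_rating": 0.892}

# Gain(X) = Info(buy_computer) - Info_X(buy_computer) for each attribute X.
gains = {attr: round(info_d - value, 3) for attr, value in conditional.items()}
best = max(gains, key=gains.get)

print(gains)
print(best)  # age
```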

Step 4: Choose the attribute with the largest information gain as the current node, here age, and split the initial dataset D by the value of age as follows.

  1. When age is youth, the subset is D1:

RID  income  student  credit_rating  buy_computer
1    high    no       fair           no
2    high    no       excellent      no
8    medium  no       fair           no
9    low     yes      fair           yes
11   medium  yes      excellent      yes

  2. When age is middle_aged, the subset is D2:

RID  income  student  credit_rating  buy_computer
3    high    no       fair           yes
7    low     yes      excellent      yes
12   medium  no       excellent      yes
13   high    yes      fair           yes

  3. When age is senior, the subset is D3:

RID  income  student  credit_rating  buy_computer
4    medium  no       fair           yes
5    low     yes      fair           yes
6    low     yes      excellent      no
10   medium  yes      fair           yes
14   medium  no       excellent      no
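The split in step 4 is a plain group-by on the chosen attribute. A minimal sketch follows, assuming the 14 records are stored as tuples in the column order of the table; `COLUMNS` and `partition` are names invented for this example.

```python
from collections import defaultdict

COLUMNS = ("RID", "age", "income", "student", "credit_rating", "buy_computer")
D = [
    (1, "youth", "high", "no", "fair", "no"),
    (2, "youth", "high", "no", "excellent", "no"),
    (3, "middle_aged", "high", "no", "fair", "yes"),
    (4, "senior", "medium", "no", "fair", "yes"),
    (5, "senior", "low", "yes", "fair", "yes"),
    (6, "senior", "low", "yes", "excellent", "no"),
    (7, "middle_aged", "low", "yes", "excellent", "yes"),
    (8, "youth", "medium", "no", "fair", "no"),
    (9, "youth", "low", "yes", "fair", "yes"),
    (10, "senior", "medium", "yes", "fair", "yes"),
    (11, "youth", "medium", "yes", "excellent", "yes"),
    (12, "middle_aged", "medium", "no", "excellent", "yes"),
    (13, "middle_aged", "high", "yes", "fair", "yes"),
    (14, "senior", "medium", "no", "excellent", "no"),
]


def partition(rows, column):
    """Group rows by the value in `column`, dropping that column from each row."""
    index = COLUMNS.index(column)
    subsets = defaultdict(list)
    for row in rows:
        subsets[row[index]].append(row[:index] + row[index + 1:])
    return dict(subsets)


subsets = partition(D, "age")
for value, rows in subsets.items():
    print(value, [row[0] for row in rows])
# youth [1, 2, 8, 9, 11]
# middle_aged [3, 7, 12, 13]
# senior [4, 5, 6, 10, 14]
```

The three lists of record IDs correspond to D1, D2, and D3.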

 

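Step 5: Remove the selected attribute (age) from the attribute list A. Then, for each subset Di produced in step 4, repeat from step 1 using the reduced list A. The recursion stops when either condition holds:

  1. The target attribute has only one value within a subset, as is the case here when age is middle_aged.
  2. One value of the target attribute reaches a chosen threshold share of the subset. For example, the recursion can stop as soon as some value of buy_computer accounts for 90% or more of the subset.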
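Both stopping conditions can be expressed as one purity test, since a subset with a single target value has a purity of 100%. The sketch below is one way to write it; the name `leaf_label`, the default threshold of 0.9, and the assumption that `labels` is non-empty are choices made for this example.

```python
from collections import Counter


def leaf_label(labels, purity_threshold=0.9):
    """Return the majority label if the subset should become a leaf, else None.

    Assumes `labels` is a non-empty list of target-attribute values.
    """
    label, count = Counter(labels).most_common(1)[0]
    if count / len(labels) >= purity_threshold:
        return label
    return None


# Condition 1: only one target value (the middle_aged subset D2).
print(leaf_label(["yes", "yes", "yes", "yes"]))      # yes
# Condition 2: one value reaches the 90% threshold.
print(leaf_label(["yes"] * 9 + ["no"]))              # yes
# The youth subset D1 is 3 no / 2 yes, so it must be split further.
print(leaf_label(["no", "no", "no", "yes", "yes"]))  # None
```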

After several iterations, the result is a tree with the following structure:

[Figure: the resulting decision tree]

The resulting rules are:

IF age = middle_aged, THEN buy_computer = yes

IF age = youth AND student = yes, THEN buy_computer = yes

IF age = youth AND student = no, THEN buy_computer = no

IF age = senior AND credit_rating = excellent, THEN buy_computer = no

IF age = senior AND credit_rating = fair, THEN buy_computer = yes

So, if the instance is ("15", "youth", "medium", "yes", "fair"), the predicted value of buy_computer is "yes".
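The five rules describe a small tree, and applying them is a walk from the root to a leaf. The sketch below encodes the rules as a nested dictionary; this representation and the names `TREE` and `classify` are assumptions for illustration and differ from the Node classes used in section IV.

```python
# Inner nodes are dicts with the tested attribute and one branch per value;
# leaves are plain strings holding the predicted buy_computer value.
TREE = {
    "attribute": "age",
    "branches": {
        "middle_aged": "yes",
        "youth": {"attribute": "student", "branches": {"yes": "yes", "no": "no"}},
        "senior": {
            "attribute": "credit_rating",
            "branches": {"excellent": "no", "fair": "yes"},
        },
    },
}


def classify(tree, instance):
    """Follow the branch matching each tested attribute until a leaf is reached."""
    while isinstance(tree, dict):
        value = instance[tree["attribute"]]
        tree = tree["branches"].get(value)  # None if the value has no branch
    return tree


instance = {"RID": "15", "age": "youth", "income": "medium",
            "student": "yes", "credit_rating": "fair"}
print(classify(TREE, instance))  # yes
```

Note that income never appears in the tree: no rule needs it to separate the training records.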

 

IV. Implementing ID3

Implementations of ID3 in Python and Java follow:

Python program:

# -*- coding: utf-8 -*-
from math import log2


class Node:
    '''Represents a decision tree node.'''

    def __init__(self, parent=None, dataset=None):
        self.dataset = dataset  # training instances that fall on this node
        self.result = None      # class label (set on leaf nodes)
        self.attr = None        # index of the attribute this node splits on
        self.childs = {}        # subtrees: value of attr -> child Node
        self.parent = parent    # parent node


def entropy(props):
    '''Entropy, in bits, of a list or tuple of probabilities.'''
    if not isinstance(props, (tuple, list)):
        return None
    e = 0.0
    for p in props:
        if p != 0:
            e = e - p * log2(p)
    return e


def info_gain(D, A, T=-1, return_ratio=False):
    '''Information gain of feature A on training set D.

    g(D,A) = entropy(D) - entropy(D|A)
    T is the index of the target attribute; -1 means the last element of each tuple.
    '''
    if not isinstance(D, (set, list)):
        return None
    if type(A) is not int:
        return None
    C = {}    # count per class
    DA = {}   # count per value of feature A
    CDA = {}  # count per (class, value of A) combination
    for t in D:
        C[t[T]] = C.get(t[T], 0) + 1
        DA[t[A]] = DA.get(t[A], 0) + 1
        CDA[(t[T], t[A])] = CDA.get((t[T], t[A]), 0) + 1

    PC = [x / len(D) for x in C.values()]  # class probabilities
    entropy_D = entropy(PC)                # entropy of the target attribute

    PCDA = {}  # value of A -> class probabilities given that value
    for key, value in CDA.items():
        a = key[1]  # value of feature A
        PCDA.setdefault(a, []).append(value / DA[a])

    condition_entropy = 0.0
    for a, v in DA.items():
        p = v / len(D)
        condition_entropy += entropy(PCDA[a]) * p

    if return_ratio:
        # The gain is divided by entropy(D), as in the original program.
        return (entropy_D - condition_entropy) / entropy_D
    return entropy_D - condition_entropy


def get_result(D, T=-1):
    '''Return the most frequent value of target attribute T in dataset D.'''
    if not isinstance(D, (set, list)):
        return None
    if type(T) is not int:
        return None
    count = {}
    for t in D:
        count[t[T]] = count.get(t[T], 0) + 1
    result = None  # stays None if D is empty
    max_count = 0
    for key, value in count.items():
        if value > max_count:
            max_count = value
            result = key
    return result


def devide_set(D, A):
    '''Split dataset D into subsets keyed by the value of feature A.'''
    if not isinstance(D, (set, list)):
        return None
    if type(A) is not int:
        return None
    subset = {}
    for t in D:
        subset.setdefault(t[A], []).append(t)
    return subset


def build_tree(D, A, threshold=0.0001, T=-1, Tree=None, algo="ID3"):
    '''Build a decision tree from dataset D and feature-index set A.

    T is the index of the target attribute in each tuple.
    algo="ID3" ranks features by information gain; any other value uses the gain ratio.
    '''
    if Tree is not None and not isinstance(Tree, Node):
        return None
    if not isinstance(D, (set, list)):
        return None
    if type(A) is not set:
        return None

    if Tree is None:
        Tree = Node(None, D)
    subset = devide_set(D, T)
    if len(subset) <= 1:  # every instance has the same class: leaf
        for key in subset.keys():
            Tree.result = key
        return Tree
    if len(A) <= 0:  # no features left: majority vote
        Tree.result = get_result(D, T)
        return Tree
    use_gain_ratio = algo != "ID3"

    max_gain = 0
    attr_id = None
    for a in A:
        gain = info_gain(D, a, T, return_ratio=use_gain_ratio)
        if gain > max_gain:
            max_gain = gain
            attr_id = a  # feature with the largest gain so far
    if attr_id is None or max_gain < threshold:
        Tree.result = get_result(D, T)
        return Tree
    Tree.attr = attr_id
    subD = devide_set(D, attr_id)
    if isinstance(D, list):
        del D[:]  # drop intermediate data to free memory (this empties the caller's list)
    Tree.dataset = None
    # NOTE: A is shared by all sibling subtrees, so a feature used in one branch
    # is no longer available to branches processed later. Passing a copy
    # (A - {attr_id}) instead would restrict removal to the current path.
    A.discard(attr_id)
    for key in subD.keys():
        tree = Node(Tree, subD.get(key))
        Tree.childs[key] = tree
        build_tree(subD.get(key), A, threshold, T, tree, algo)
    return Tree


def print_brance(brance, target):
    '''Print one root-to-leaf path as "attr = value ∧ ... ∧ target = label".'''
    odd = 0
    for e in brance:
        print(e, '=' if odd == 0 else '∧', end=' ')
        odd = 1 - odd
    print("target =", target)


def print_tree(Tree, stack=None):
    if stack is None:
        stack = []
    if Tree is None:
        return
    if Tree.result is not None:
        print_brance(stack, Tree.result)
        return
    stack.append(Tree.attr)
    for key, value in Tree.childs.items():
        stack.append(key)
        print_tree(value, stack)
        stack.pop()
    stack.pop()


def classify(Tree, instance):
    if Tree is None:
        return None
    if Tree.result is not None:
        return Tree.result
    if instance[Tree.attr] in Tree.childs:
        return classify(Tree.childs[instance[Tree.attr]], instance)
    return None


# NOTE: several field values are blank in this listing; the values behind the
# output shown after the program cannot be recovered from it.
dataset = [
    ("青年", "", "", "一般", ""),
    ("青年", "", "", "", ""),
    ("青年", "", "", "", ""),
    ("青年", "", "", "一般", ""),
    ("青年", "", "", "一般", ""),
    ("中年", "", "", "一般", ""),
    ("中年", "", "", "", ""),
    ("老年", "", "", "非常好", ""),
    ("老年", "", "", "一般", ""),
    ("老年", "", "", "一般", ""),
    ("老年", "", "", "一般", ""),
    ("老年", "", "", "", ""),
    ("老年", "", "", "一般", ""),
    ("老年", "", "", "一般", ""),
    ("老年", "", "", "一般", ""),
]

s = set(range(0, len(dataset[0]) - 1))
s = set([0, 1, 3, 4])
T = build_tree(dataset, s)
print_tree(T)
print(classify(T, ("老年", "", "", "一般", "")))
print(classify(T, ("老年", "", "", "一般", "")))
print(classify(T, ("老年", "", "", "", "")))
print(classify(T, ("青年", "", "", "", "")))
print(classify(T, ("中年", "", "", "", "")))

 

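The training set of this Python program is not the example given above. In its data, 青年 / 中年 / 老年 mean youth / middle-aged / senior, 是 / 否 mean yes / no, and 一般 / 好 / 非常好 mean fair / good / very good. The output is:

0 = 青年 ∧ 1 = 否 ∧ target = 否
0 = 青年 ∧ 1 = 是 ∧ target = 是
0 = 中年 ∧ target = 否
0 = 老年 ∧ 3 = 好 ∧ target = 是
0 = 老年 ∧ 3 = 非常好 ∧ target = 是
0 = 老年 ∧ 3 = 一般 ∧ 4 = 否 ∧ target = 否
0 = 老年 ∧ 3 = 一般 ∧ 4 = 是 ∧ target = 是
否
是
是
是
否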

  

Java program. Its dataset is the example given above. The code and its output follow:

import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

public class ID3Tree {
    private List<String[]> datas;
    private List<Integer> attributes;
    private double threshold = 0.0001;
    private int targetIndex = 1;
    private Node tree;
    private Map<Integer, String> attributeMap;

    protected ID3Tree() {
        super();
    }

    public ID3Tree(List<String[]> datas, List<Integer> attributes, Map<Integer, String> attributeMap, int targetIndex) {
        this(datas, attributes, attributeMap, 0.0001, targetIndex, null);
    }

    public ID3Tree(List<String[]> datas, List<Integer> attributes, Map<Integer, String> attributeMap, double threshold, int targetIndex, Node tree) {
        super();
        this.datas = datas;
        this.attributes = attributes;
        this.attributeMap = attributeMap;
        this.threshold = threshold;
        this.targetIndex = targetIndex;
        this.tree = tree;
    }

    /**
     * A tree node.
     *
     * @author Gerry.Liu
     */
    class Node {
        private List<String[]> dataset; // training instances that fall on this node
        private String result; // class label (set on leaf nodes)
        private int attr; // index of the attribute this node splits on
        private Node parent; // parent node
        private Map<String, List<Node>> childs; // child nodes, keyed by attribute value

        public Node(List<String[]> datas, Node parent) {
            this.dataset = datas;
            this.parent = parent;
            this.childs = new HashMap<>();
        }
    }

    class KeyValue {
        private String first;
        private String second;

        public KeyValue(String first, String second) {
            super();
            this.first = first;
            this.second = second;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + getOuterType().hashCode();
            result = prime * result + ((first == null) ? 0 : first.hashCode());
            result = prime * result + ((second == null) ? 0 : second.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            if (obj == null)
                return false;
            if (getClass() != obj.getClass())
                return false;
            KeyValue other = (KeyValue) obj;
            if (!getOuterType().equals(other.getOuterType()))
                return false;
            if (first == null) {
                if (other.first != null)
                    return false;
            } else if (!first.equals(other.first))
                return false;
            if (second == null) {
                if (other.second != null)
                    return false;
            } else if (!second.equals(other.second))
                return false;
            return true;
        }

        private ID3Tree getOuterType() {
            return ID3Tree.this;
        }
    }

    /**
     * Computes the entropy of a list of probabilities:<br/>
     * entropy(p1,p2....pn) = -p1*log2(p1) -p2*log2(p2)-.....-pn*log2(pn)
     *
     * @param props
     * @return
     */
    private double entropy(List<Double> props) {
        if (props == null || props.isEmpty()) {
            return 0;
        } else {
            double result = 0;
            for (double p : props) {
                if (p > 0) {
                    result = result - p * Math.log(p) / Math.log(2);
                }
            }
            return result;
        }
    }

    /**
     * Converts counts to probabilities.
     *
     * @param totalRecords
     * @param counts
     * @return
     */
    private List<Double> calcProbability(int totalRecords, Collection<Integer> counts) {
        if (totalRecords == 0 || counts == null || counts.isEmpty()) {
            return null;
        }

        List<Double> result = new ArrayList<>();
        for (int count : counts) {
            result.add(1.0 * count / totalRecords);
        }
        return result;
    }

    /**
     * Information gain Gain(datas,attribute)<br/>
     * of feature attribute (A) on training dataset datas (D):<br/>
     * g(D,A) = entropy(D) - entropy(D|A)<br/>
     *
     * @param datas
     *            training dataset
     * @param attributeIndex
     *            index of the feature attribute
     * @param targetAttributeIndex
     *            index of the target attribute
     * @return
     */
    private double infoGain(List<String[]> datas, int attributeIndex, int targetAttributeIndex) {
        if (datas == null || datas.isEmpty()) {
            return 0;
        }

        Map<String, Integer> targetAttributeCountMap = new HashMap<String, Integer>(); // count per class (target value)
        Map<String, Integer> featureAttributesCountMap = new HashMap<>(); // count per feature value
        Map<KeyValue, Integer> tfAttributeCountMap = new HashMap<>(); // count per (feature value, class) combination

        for (String[] arrs : datas) {
            String tv = arrs[targetAttributeIndex];
            String fv = arrs[attributeIndex];
            if (targetAttributeCountMap.containsKey(tv)) {
                targetAttributeCountMap.put(tv, targetAttributeCountMap.get(tv) + 1);
            } else {
                targetAttributeCountMap.put(tv, 1);
            }
            if (featureAttributesCountMap.containsKey(fv)) {
                featureAttributesCountMap.put(fv, featureAttributesCountMap.get(fv) + 1);
            } else {
                featureAttributesCountMap.put(fv, 1);
            }
            KeyValue key = new KeyValue(fv, tv);
            if (tfAttributeCountMap.containsKey(key)) {
                tfAttributeCountMap.put(key, tfAttributeCountMap.get(key) + 1);
            } else {
                tfAttributeCountMap.put(key, 1);
            }
        }

        int totalDataSize = datas.size();
        // class probabilities
        List<Double> probabilitys = calcProbability(totalDataSize, targetAttributeCountMap.values());
        // entropy of the target attribute
        double entropyDatas = this.entropy(probabilitys);

        // conditional probabilities
        // first, the distribution of target values for each fixed feature value
        Map<String, List<Double>> pcda = new HashMap<>();
        for (Map.Entry<KeyValue, Integer> entry : tfAttributeCountMap.entrySet()) {
            String key = entry.getKey().first;
            double pca = 1.0 * entry.getValue() / featureAttributesCountMap.get(key);
            if (pcda.containsKey(key)) {
                pcda.get(key).add(pca);
            } else {
                List<Double> list = new ArrayList<Double>();
                list.add(pca);
                pcda.put(key, list);
            }
        }
        // second, the entropy for each feature value, averaged with weights
        double conditionEntropy = 0.0;
        for (Map.Entry<String, Integer> entry : featureAttributesCountMap.entrySet()) {
            double p = 1.0 * entry.getValue() / totalDataSize;
            double e = this.entropy(pcda.get(entry.getKey()));
            conditionEntropy += e * p;
        }
        return entropyDatas - conditionEntropy;
    }

    /**
     * Returns the most frequent value of the target attribute in the dataset.
     *
     * @param datas
     * @param targetAttributeIndex
     * @return
     */
    private String getResult(List<String[]> datas, int targetAttributeIndex) {
        if (datas == null || datas.isEmpty()) {
            return null;
        } else {
            String result = "";
            Map<String, Integer> countMap = new HashMap<>();
            for (String[] arr : datas) {
                String key = arr[targetAttributeIndex];
                if (countMap.containsKey(key)) {
                    countMap.put(key, countMap.get(key) + 1);
                } else {
                    countMap.put(key, 1);
                }
            }

            int maxCount = -1;
            for (Map.Entry<String, Integer> entry : countMap.entrySet()) {
                if (entry.getValue() > maxCount) {
                    maxCount = entry.getValue();
                    result = entry.getKey();
                }
            }
            return result;
        }
    }

    /**
     * Splits dataset D into subsets by the value of a feature attribute.
     *
     * @param datas
     *            dataset
     * @param attributeIndex
     *            index of the feature attribute
     * @return
     */
    private Map<String, List<String[]>> devideDatas(List<String[]> datas, int attributeIndex) {
        Map<String, List<String[]>> subdatas = new HashMap<>();
        if (datas != null && !datas.isEmpty()) {
            for (String[] arr : datas) {
                String key = arr[attributeIndex];
                if (subdatas.containsKey(key)) {
                    subdatas.get(key).add(arr);
                } else {
                    List<String[]> list = new ArrayList<>();
                    list.add(arr);
                    subdatas.put(key, list);
                }
            }
        }
        return subdatas;
    }

    /**
     * Prints the decision tree.
     *
     * @param tree
     * @param stock
     */
    private void printTree(Node tree, Deque<Object> stock) {
        if (tree == null) {
            return;
        }

        if (tree.result != null) {
            this.printBrance(stock, tree.result);
        } else {
            stock.push(this.attributeMap.get(tree.attr));
            for (Map.Entry<String, List<Node>> entry : tree.childs.entrySet()) {
                stock.push(entry.getKey());
                for (Node node : entry.getValue()) {
                    this.printTree(node, stock);
                }
                stock.pop();
            }
            stock.pop();
        }
    }

    /**
     * Prints one rule (root-to-leaf path) of the decision tree.
     *
     * @param stock
     * @param target
     */
    private void printBrance(Deque<Object> stock, String target) {
        StringBuffer sb = new StringBuffer();
        int odd = 0;
        // The deque iterates from the most recent push, so each item is inserted at the front.
        for (Object e : stock) {
            sb.insert(0, odd == 0 ? "^" : "=").insert(0, e);
            // sb.append(e).append(odd == 0 ? "=" : "^");
            odd = 1 - odd;
        }
        sb.append("target=").append(target);
        System.out.println(sb.toString());
    }

    /**
     * Builds a decision tree.
     *
     * @param datas
     * @param attributes
     * @param threshold
     * @param targetIndex
     * @param tree
     * @return
     */
    private Node buildTree(List<String[]> datas, List<Integer> attributes, double threshold, int targetIndex, Node tree) {
        if (tree == null) {
            tree = new Node(datas, null);
        }
        // split by target value; the result is empty or populated, never null
        Map<String, List<String[]>> subDatas = this.devideDatas(datas, targetIndex);
        if (subDatas.size() <= 1) {
            // there is at most one key here
            for (String key : subDatas.keySet()) {
                tree.result = key;
            }
        } else if (attributes == null || attributes.size() < 1) {
            // no features left: take the most frequent target value
            tree.result = this.getResult(datas, targetIndex);
        } else {
            double maxGain = 0;
            int attr = 0;

            for (int attribute : attributes) {
                double gain = this.infoGain(datas, attribute, targetIndex);
                if (gain > maxGain) {
                    maxGain = gain;
                    attr = attribute; // index with the largest information gain
                }
            }

            if (maxGain < threshold) {
                // gain is below the threshold: stop and make a leaf
                tree.result = this.getResult(datas, targetIndex);
            } else {
                // stopping condition not met: keep splitting
                tree.attr = attr;
                subDatas = this.devideDatas(datas, attr);
                tree.dataset = null;
                // NOTE: the attribute list is shared by all sibling subtrees, so an
                // attribute used in one branch is unavailable to branches processed later.
                attributes.remove(Integer.valueOf(attr));
                for (String key : subDatas.keySet()) {
                    Node childTree = new Node(subDatas.get(key), tree);
                    if (tree.childs.containsKey(key)) {
                        tree.childs.get(key).add(childTree);
                    } else {
                        List<Node> childs = new ArrayList<>();
                        childs.add(childTree);
                        tree.childs.put(key, childs);
                    }
                    this.buildTree(subDatas.get(key), attributes, threshold, targetIndex, childTree);
                }
            }
        }
        return tree;
    }

    /**
     * Returns the predicted value for an instance by following the decision rules.
     *
     * @param tree
     * @param instance
     * @return
     */
    private String classify(Node tree, String[] instance) {
        if (tree == null) {
            return null;
        }
        if (tree.result != null) {
            return tree.result;
        }
        if (tree.childs.containsKey(instance[tree.attr])) {
            List<Node> nodes = tree.childs.get(instance[tree.attr]);
            for (Node node : nodes) {
                return this.classify(node, instance);
            }
        }
        return null;
    }

    /**
     * Builds the decision tree.
     */
    public void buildTree() {
        this.tree = new Node(this.datas, null);
        this.buildTree(datas, attributes, threshold, targetIndex, tree);
    }

    /**
     * Prints the generated rules.
     */
    public void printTree() {
        this.printTree(this.tree, new LinkedList<>());
    }

    /**
     * Returns the predicted value for an instance.
     *
     * @param instance
     * @return
     */
    public String classify(String[] instance) {
        return this.classify(this.tree, instance);
    }

    public static void main(String[] args) {
        List<String[]> dataset = new ArrayList<>();
        dataset.add(new String[] { "1", "youth", "high", "no", "fair", "no" });
        dataset.add(new String[] { "2", "youth", "high", "no", "excellent", "no" });
        dataset.add(new String[] { "3", "middle_aged", "high", "no", "fair", "yes" });
        dataset.add(new String[] { "4", "senior", "medium", "no", "fair", "yes" });
        dataset.add(new String[] { "5", "senior", "low", "yes", "fair", "yes" });
        dataset.add(new String[] { "6", "senior", "low", "yes", "excellent", "no" });
        dataset.add(new String[] { "7", "middle_aged", "low", "yes", "excellent", "yes" });
        dataset.add(new String[] { "8", "youth", "medium", "no", "fair", "no" });
        dataset.add(new String[] { "9", "youth", "low", "yes", "fair", "yes" });
        dataset.add(new String[] { "10", "senior", "medium", "yes", "fair", "yes" });
        dataset.add(new String[] { "11", "youth", "medium", "yes", "excellent", "yes" });
        dataset.add(new String[] { "12", "middle_aged", "medium", "no", "excellent", "yes" });
        dataset.add(new String[] { "13", "middle_aged", "high", "yes", "fair", "yes" });
        dataset.add(new String[] { "14", "senior", "medium", "no", "excellent", "no" });

        List<Integer> attributes = new ArrayList<>();
        attributes.add(4);
        attributes.add(1);
        attributes.add(2);
        attributes.add(3);

        Map<Integer, String> attributeMap = new HashMap<>();
        attributeMap.put(1, "Age");
        attributeMap.put(2, "Income");
        attributeMap.put(3, "Student");
        attributeMap.put(4, "Credit_rating");

        int targetIndex = 5;

        String[] instance = new String[] { "15", "youth", "medium", "yes", "fair" };

        ID3Tree tree = new ID3Tree(dataset, attributes, attributeMap, targetIndex);
        System.out.println("start build the tree");
        tree.buildTree();
        System.out.println("completed build the tree, start print the tree");
        tree.printTree();
        System.out.println("start classify.....");
        String result = tree.classify(instance);
        System.out.println(result);
    }
}

Running the Java program prints:

Start build the tree.....
Completed build the tree, start print the tree.....
Age=youth^Student=yes^target=yes
Age=youth^Student=no^target=no
Age=middle_aged^target=yes
Age=senior^Credit_rating=excellent^target=no
Age=senior^Credit_rating=fair^target=yes
start classify.....
yes

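V. Shortcomings of ID3

ID3 runs relatively slowly and can only work on data that has been loaded into memory, so the datasets it can handle are small compared with those of other algorithms.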

 
