首页 > 代码库 > 朴素贝叶斯分类器(离散型)算法实现(一)

朴素贝叶斯分类器(离散型)算法实现(一)

 

1. 贝叶斯定理:    

   (1)   P(A^B) = P(A|B)P(B) = P(B|A)P(A) 

 由(1)得

   P(A|B) = P(B|A)*P(A)/[p(B)]

 

贝叶斯在最基本题型:

假定一个场景,在一所高中男女比例为4:6, 留长头发的有男学生有女学生, 我们设定女生都留长发 , 而男生中有10%的留长发,90%留短发.那么如果我们看到远处一个长发背影?请问是一只男学生的概率?

  分析:

    P(男|长发) = P(长发|男)*P(男)/[p(长发)] 

        = (1/10)*(4/10)/[(6+4*(1/10))/10]

        =1/16 =0.0625

   P(女|长发) =P(长发|女)*P(女)/[p(长发)]

                  =1*(6/10)/[(6+4*(1/10))/10]

                 =30/32 =15/16

 

再举一个列子:

某个医院早上收了六个门诊病人,如下表。

  症状  职业   疾病

  打喷嚏 护士   感冒 
  打喷嚏 农夫   过敏 
  头痛  建筑工人 脑震荡 
  头痛  建筑工人 感冒 
  打喷嚏 教师   感冒 
  头痛  教师   脑震荡

现在又来了第七个病人,是一个打喷嚏的建筑工人。请问他患上感冒的概率有多大?(来源: http://www.ruanyifeng.com/blog/2013/12/naive_bayes_classifier.html)

Java代码实现:

 1 /** 2  * ********************************************************* 3  * <p/> 4  * Author:     XiJun.Gong 5  * Date:       2016-08-31 20:36 6  * Version:    default 1.0.0 7  * Class description: 8  * <p>特征库</p> 9  * <p/>10  * *********************************************************11  */12 13 public class FeaturePoint {14 15     private String key;16     private double p;17 18     public FeaturePoint(String key) {19         this(key, 1);20     }21 22     public FeaturePoint(String key, double p) {23         this.key = key;24         this.p = p;25     }26 27     public String getKey() {28         return key;29     }30 31     public void setKey(String key) {32         this.key = key;33     }34 35     public double getP() {36         return p;37     }38 39     public void setP(double p) {40         this.p = p;41     }42 }
 1 import com.google.common.collect.ArrayListMultimap; 2 import com.google.common.collect.Multimap; 3  4 import java.util.Collection; 5 import java.util.List; 6  7 /** 8  * ********************************************************* 9  * <p/>10  * Author:     XiJun.Gong11  * Date:       2016-08-31 15:4812  * Version:    default 1.0.013  * Class description:14  * <p/>15  * *********************************************************16  */17 18 public class Bayes {19     private static Multimap<String, FeaturePoint> map = ArrayListMultimap.create();20 21     /*喂数据*/22     public void input(List<String> labels) {23 24         for (String key : labels) {25             Collection<FeaturePoint> features = map.get(key);26             for (String value : labels) {27                 if (features == null || features.size() < 1) {28                     map.put(key, new FeaturePoint(value));29                     continue;30                 }31                 boolean tag = false;32                 for (FeaturePoint feature : features) {33                     if (feature.getKey().equals(value)) {34                         Double num = feature.getP() + 1;35                         map.remove(key, feature);36                         map.put(key, new FeaturePoint(value, num));37                         tag = true;38                         break;39                     }40                 }41                 if (!tag)42                     map.put(key, new FeaturePoint(value));43             }44         }45     }46 47     /*构造模型*/48     public void excute(List<String> labels) {49         //   excute(labels, null);50     }51 52     /*构造模型*/53     public Double excute(final List<String> labels, final String judge, Integer dataSize) {54 55         Double denominator = 1d;    //分母56         Double numerator = 1d;      //分子57         Double coughNum = 0d;58        /*选择相关性分子*/59         Collection<FeaturePoint> featurePoints = map.get(judge);60         for (FeaturePoint featurePoint : featurePoints) {61             if (judge.equals(featurePoint.getKey())) {62                 coughNum = featurePoint.getP();63                 denominator *= (featurePoint.getP() / dataSize);64                 break;65             }66         }67 68         Integer size = featurePoints.size() - 1; //容量69         for (String label : labels) {70             for (FeaturePoint featurePoint : featurePoints) {71                 if (label.equals(featurePoint.getKey())) {72                     denominator *= (featurePoint.getP() / coughNum);73                     for (FeaturePoint feature : map.get(label)) {74                         if (label.equals(feature.getKey())) {75                             numerator *= (feature.getP() / dataSize);76                         }77                     }78                 }79             }80         }81 82         return denominator / numerator;83     }84 85 }

 

 1 import com.google.common.collect.Lists; 2  3 import java.util.List; 4 import java.util.Scanner; 5  6 /** 7  * ********************************************************* 8  * <p/> 9  * Author:     XiJun.Gong10  * Date:       2016-09-01 14:5811  * Version:    default 1.0.012  * Class description:13  * <p/>14  * *********************************************************15  */16 public class Main {17 18     public static void main(String args[]) {19 20         Scanner scanner = new Scanner(System.in);21         Integer size = scanner.nextInt();22         Integer row = scanner.nextInt();23         Bayes bayes = new Bayes();24         while (scanner.hasNext()) {25 26             for (int ro = 0; ro < row; ro++) {27                 List<String> list = Lists.newArrayList();28                 for (int i = 0; i < size; i++) {29                     list.add(scanner.next());30                 }31                 bayes.input(list);32             }33             List<String> list = Lists.newArrayList();34             for (int i = 0; i < size - 1; i++) {35                 list.add(scanner.next());36             }37             String judge = scanner.next();38             System.out.println(bayes.excute(list, judge,row));39             ;40         }41 42     }43 }

pom.xml包

    <dependency>            <groupId>junit</groupId>            <artifactId>junit</artifactId>            <version>3.8.1</version>            <scope>test</scope>        </dependency>        <dependency>            <groupId>com.google.guava</groupId>            <artifactId>guava</artifactId>            <version>18.0</version>        </dependency>

结果:

1 3 62 打喷嚏 护士   感冒 3   打喷嚏 农夫   过敏 4   头痛  建筑工人 脑震荡 5   头痛  建筑工人 感冒 6   打喷嚏 教师   感冒 7   头痛  教师   脑震荡8 打喷嚏  建筑工人 感冒9 0.6666666666666666 
1 3 62   打喷嚏 护士   感冒 3   打喷嚏 农夫   过敏 4   头痛  建筑工人 脑震荡 5   头痛  建筑工人 感冒 6   打喷嚏 教师   感冒 7   头痛  教师   脑震荡8 打喷嚏 护士   感冒 9 1.3333333333333333

 

 1 2 50 2 男  长发 3 男  短发 4 男  短发 5 男  短发 6 男  短发 7 男  短发 8 男  短发 9 男  短发10 男  短发11 男  短发12 男  短发13 男  短发14 男  短发15 男  短发16 男  短发17 男  短发18 男  短发19 男  短发20 男  短发21 男  长发22 女  长发23 女  长发24 女  长发25 女  长发26 女  长发27 女  长发28 女  长发29 女  长发30 女  长发31 女  长发32 女  长发33 女  长发34 女  长发35 女  长发36 女  长发37 女  长发38 女  长发39 女  长发40 女  长发41 女  长发42 女  长发43 女  长发44 女  长发45 女  长发46 女  长发47 女  长发48 女  长发49 女  长发50 女  长发51 女  长发52             53 长发 男54 0.06250000000000001
技术分享
 1 2 50 2 男  长发 3 男  短发 4 男  短发 5 男  短发 6 男  短发 7 男  短发 8 男  短发 9 男  短发10 男  短发11 男  短发12 男  短发13 男  短发14 男  短发15 男  短发16 男  短发17 男  短发18 男  短发19 男  短发20 男  短发21 男  长发22 女  长发23 女  长发24 女  长发25 女  长发26 女  长发27 女  长发28 女  长发29 女  长发30 女  长发31 女  长发32 女  长发33 女  长发34 女  长发35 女  长发36 女  长发37 女  长发38 女  长发39 女  长发40 女  长发41 女  长发42 女  长发43 女  长发44 女  长发45 女  长发46 女  长发47 女  长发48 女  长发49 女  长发50 女  长发51 女  长发52 长发 女53 0.9375
View Code

 利用贝叶斯进行分类?

技术分享
  1 import com.google.common.collect.ArrayListMultimap;  2 import com.google.common.collect.Lists;  3 import com.google.common.collect.Multimap;  4   5 import java.util.Collection;  6 import java.util.List;  7   8 /**  9  * ********************************************************* 10  * <p/> 11  * Author:     XiJun.Gong 12  * Date:       2016-08-31 15:48 13  * Version:    default 1.0.0 14  * Class description: 15  * <p/> 16  * ********************************************************* 17  */ 18  19 public class Bayes { 20     private Multimap<String, FeaturePoint> map = null; 21     private List<String> featurePool = null; 22  23     public Bayes() { 24         map = ArrayListMultimap.create(); 25         featurePool = Lists.newArrayList(); 26     } 27  28     public void add(String label) { 29         featurePool.add(label); 30     } 31  32     /*喂数据*/ 33     public void input(List<String> labels) { 34  35         for (String key : labels) { 36             Collection<FeaturePoint> features = map.get(key); 37             for (String value : labels) { 38                 if (features == null || features.size() < 1) { 39                     map.put(key, new FeaturePoint(value)); 40                     continue; 41                 } 42                 boolean tag = false; 43                 for (FeaturePoint feature : features) { 44                     if (feature.getKey().equals(value)) { 45                         Double num = feature.getP() + 1; 46                         map.remove(key, feature); 47                         map.put(key, new FeaturePoint(value, num)); 48                         tag = true; 49                         break; 50                     } 51                 } 52                 if (!tag) 53                     map.put(key, new FeaturePoint(value)); 54             } 55         } 56     } 57  58     /*最符合那个分类*/ 59     public String excute(List<String> labels, Integer dataSize) { 60  61         Double max = -999999999d; 62         String max_obj = null; 63         List<Double> ans = Lists.newArrayList(); 64         for (String label : featurePool) { 65             Double p = excute(labels, label, dataSize); 66             ans.add(p); 67             if (max < p) { 68                 max_obj = label; 69                 max = p; 70             } 71         } 72         return max_obj; 73     } 74  75     /*构造模型*/ 76     public Double excute(final List<String> labels, final String judge, Integer dataSize) { 77  78         Double denominator = 1d;    //分母 79         Double numerator = 1d;      //分子 80         Double coughNum = 0d; 81        /*选择相关性分子*/ 82         Collection<FeaturePoint> featurePoints = map.get(judge); 83         for (FeaturePoint featurePoint : featurePoints) { 84             if (judge.equals(featurePoint.getKey())) { 85                 coughNum = featurePoint.getP(); 86                 denominator *= (featurePoint.getP() / dataSize); 87                 break; 88             } 89         } 90        /*O(n^3)*/ 91         Integer size = featurePoints.size() - 1; //容量 92         for (String label : labels) { 93             for (FeaturePoint featurePoint : featurePoints) { 94                 if (label.equals(featurePoint.getKey())) { 95                     denominator *= (featurePoint.getP() / coughNum); 96                     for (FeaturePoint feature : map.get(label)) { 97                         if (label.equals(feature.getKey())) { 98                             numerator *= (feature.getP() / dataSize); 99                         }100                     }101                 }102             }103         }104 105         return denominator / numerator;106     }107 108 }
View Code
技术分享
 1 import com.google.common.collect.Lists; 2  3 import java.util.List; 4 import java.util.Scanner; 5  6 /** 7  * ********************************************************* 8  * <p/> 9  * Author:     XiJun.Gong10  * Date:       2016-09-01 14:5811  * Version:    default 1.0.012  * Class description:13  * <p/>14  * *********************************************************15  */16 public class Main {17 18     public static void main(String args[]) {19 20         Scanner scanner = new Scanner(System.in);21         Integer size = scanner.nextInt();22         Integer row = scanner.nextInt();23         Integer category = scanner.nextInt();24         while (scanner.hasNext()) {25             Bayes bayes = new Bayes();26             for (int ro = 0; ro < row; ro++) {27                 List<String> list = Lists.newArrayList();28                 for (int i = 0; i < size; i++) {29                     list.add(scanner.next());30                 }31                 bayes.input(list);32             }33             List<String> list = Lists.newArrayList();34             for (int i = 0; i < size - 1; i++) {35                 list.add(scanner.next());36             }37             for (int i = 0; i < category; i++) {38                 bayes.add(scanner.next());39             }40             System.out.println(bayes.excute(list, row));41         }42 43     }44 }
View Code

结果:

技术分享
 1 2 50 2 2 男  长发 3 男  短发 4 男  短发 5 男  短发 6 男  短发 7 男  短发 8 男  短发 9 男  短发10 男  短发11 男  短发12 男  短发13 男  短发14 男  短发15 男  短发16 男  短发17 男  短发18 男  短发19 男  短发20 男  短发21 男  长发22 女  长发23 女  长发24 女  长发25 女  长发26 女  长发27 女  长发28 女  长发29 女  长发30 女  长发31 女  长发32 女  长发33 女  长发34 女  长发35 女  长发36 女  长发37 女  长发38 女  长发39 女  长发40 女  长发41 女  长发42 女  长发43 女  长发44 女  长发45 女  长发46 女  长发47 女  长发48 女  长发49 女  长发50 女  长发51 女  长发52 长发53 男 女54
View Code
技术分享
 1 2 50 2 2 男  长发 3 男  短发 4 男  短发 5 男  短发 6 男  短发 7 男  短发 8 男  短发 9 男  短发10 男  短发11 男  短发12 男  短发13 男  短发14 男  短发15 男  短发16 男  短发17 男  短发18 男  短发19 男  短发20 男  短发21 男  长发22 女  长发23 女  长发24 女  长发25 女  长发26 女  长发27 女  长发28 女  长发29 女  长发30 女  长发31 女  长发32 女  长发33 女  长发34 女  长发35 女  长发36 女  长发37 女  长发38 女  长发39 女  长发40 女  长发41 女  长发42 女  长发43 女  长发44 女  长发45 女  长发46 女  长发47 女  长发48 女  长发49 女  长发50 女  长发51 女  长发52 短发53 男 女54
View Code

 

朴素贝叶斯分类器(离散型)算法实现(一)