首页 > 代码库 > 朴素贝叶斯分类器(离散型)算法实现(一)
朴素贝叶斯分类器(离散型)算法实现(一)
1. 贝叶斯定理:
(1) P(A^B) = P(A|B)P(B) = P(B|A)P(A)
由(1)得
P(A|B) = P(B|A)*P(A)/[p(B)]
贝叶斯在最基本题型:
假定一个场景,在一所高中男女比例为4:6, 留长头发的有男学生有女学生, 我们设定女生都留长发 , 而男生中有10%的留长发,90%留短发.那么如果我们看到远处一个长发背影?请问是一只男学生的概率?
分析:
P(男|长发) = P(长发|男)*P(男)/[p(长发)]
= (1/10)*(4/10)/[(6+4*(1/10))/10]
=1/16 =0.0625
P(女|长发) =P(长发|女)*P(女)/[p(长发)]
=1*(6/10)/[(6+4*(1/10))/10]
=30/32 =15/16
再举一个列子:
某个医院早上收了六个门诊病人,如下表。
症状 职业 疾病
打喷嚏 护士 感冒
打喷嚏 农夫 过敏
头痛 建筑工人 脑震荡
头痛 建筑工人 感冒
打喷嚏 教师 感冒
头痛 教师 脑震荡
现在又来了第七个病人,是一个打喷嚏的建筑工人。请问他患上感冒的概率有多大?(来源: http://www.ruanyifeng.com/blog/2013/12/naive_bayes_classifier.html)
Java代码实现:
1 /** 2 * ********************************************************* 3 * <p/> 4 * Author: XiJun.Gong 5 * Date: 2016-08-31 20:36 6 * Version: default 1.0.0 7 * Class description: 8 * <p>特征库</p> 9 * <p/>10 * *********************************************************11 */12 13 public class FeaturePoint {14 15 private String key;16 private double p;17 18 public FeaturePoint(String key) {19 this(key, 1);20 }21 22 public FeaturePoint(String key, double p) {23 this.key = key;24 this.p = p;25 }26 27 public String getKey() {28 return key;29 }30 31 public void setKey(String key) {32 this.key = key;33 }34 35 public double getP() {36 return p;37 }38 39 public void setP(double p) {40 this.p = p;41 }42 }
1 import com.google.common.collect.ArrayListMultimap; 2 import com.google.common.collect.Multimap; 3 4 import java.util.Collection; 5 import java.util.List; 6 7 /** 8 * ********************************************************* 9 * <p/>10 * Author: XiJun.Gong11 * Date: 2016-08-31 15:4812 * Version: default 1.0.013 * Class description:14 * <p/>15 * *********************************************************16 */17 18 public class Bayes {19 private static Multimap<String, FeaturePoint> map = ArrayListMultimap.create();20 21 /*喂数据*/22 public void input(List<String> labels) {23 24 for (String key : labels) {25 Collection<FeaturePoint> features = map.get(key);26 for (String value : labels) {27 if (features == null || features.size() < 1) {28 map.put(key, new FeaturePoint(value));29 continue;30 }31 boolean tag = false;32 for (FeaturePoint feature : features) {33 if (feature.getKey().equals(value)) {34 Double num = feature.getP() + 1;35 map.remove(key, feature);36 map.put(key, new FeaturePoint(value, num));37 tag = true;38 break;39 }40 }41 if (!tag)42 map.put(key, new FeaturePoint(value));43 }44 }45 }46 47 /*构造模型*/48 public void excute(List<String> labels) {49 // excute(labels, null);50 }51 52 /*构造模型*/53 public Double excute(final List<String> labels, final String judge, Integer dataSize) {54 55 Double denominator = 1d; //分母56 Double numerator = 1d; //分子57 Double coughNum = 0d;58 /*选择相关性分子*/59 Collection<FeaturePoint> featurePoints = map.get(judge);60 for (FeaturePoint featurePoint : featurePoints) {61 if (judge.equals(featurePoint.getKey())) {62 coughNum = featurePoint.getP();63 denominator *= (featurePoint.getP() / dataSize);64 break;65 }66 }67 68 Integer size = featurePoints.size() - 1; //容量69 for (String label : labels) {70 for (FeaturePoint featurePoint : featurePoints) {71 if (label.equals(featurePoint.getKey())) {72 denominator *= (featurePoint.getP() / coughNum);73 for (FeaturePoint feature : map.get(label)) {74 if (label.equals(feature.getKey())) {75 numerator *= (feature.getP() / dataSize);76 }77 }78 }79 }80 }81 82 return denominator / numerator;83 }84 85 }
1 import com.google.common.collect.Lists; 2 3 import java.util.List; 4 import java.util.Scanner; 5 6 /** 7 * ********************************************************* 8 * <p/> 9 * Author: XiJun.Gong10 * Date: 2016-09-01 14:5811 * Version: default 1.0.012 * Class description:13 * <p/>14 * *********************************************************15 */16 public class Main {17 18 public static void main(String args[]) {19 20 Scanner scanner = new Scanner(System.in);21 Integer size = scanner.nextInt();22 Integer row = scanner.nextInt();23 Bayes bayes = new Bayes();24 while (scanner.hasNext()) {25 26 for (int ro = 0; ro < row; ro++) {27 List<String> list = Lists.newArrayList();28 for (int i = 0; i < size; i++) {29 list.add(scanner.next());30 }31 bayes.input(list);32 }33 List<String> list = Lists.newArrayList();34 for (int i = 0; i < size - 1; i++) {35 list.add(scanner.next());36 }37 String judge = scanner.next();38 System.out.println(bayes.excute(list, judge,row));39 ;40 }41 42 }43 }
pom.xml包
<dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> <version>3.8.1</version> <scope>test</scope> </dependency> <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> <version>18.0</version> </dependency>
结果:
1 3 62 打喷嚏 护士 感冒 3 打喷嚏 农夫 过敏 4 头痛 建筑工人 脑震荡 5 头痛 建筑工人 感冒 6 打喷嚏 教师 感冒 7 头痛 教师 脑震荡8 打喷嚏 建筑工人 感冒9 0.6666666666666666
1 3 62 打喷嚏 护士 感冒 3 打喷嚏 农夫 过敏 4 头痛 建筑工人 脑震荡 5 头痛 建筑工人 感冒 6 打喷嚏 教师 感冒 7 头痛 教师 脑震荡8 打喷嚏 护士 感冒 9 1.3333333333333333
1 2 50 2 男 长发 3 男 短发 4 男 短发 5 男 短发 6 男 短发 7 男 短发 8 男 短发 9 男 短发10 男 短发11 男 短发12 男 短发13 男 短发14 男 短发15 男 短发16 男 短发17 男 短发18 男 短发19 男 短发20 男 短发21 男 长发22 女 长发23 女 长发24 女 长发25 女 长发26 女 长发27 女 长发28 女 长发29 女 长发30 女 长发31 女 长发32 女 长发33 女 长发34 女 长发35 女 长发36 女 长发37 女 长发38 女 长发39 女 长发40 女 长发41 女 长发42 女 长发43 女 长发44 女 长发45 女 长发46 女 长发47 女 长发48 女 长发49 女 长发50 女 长发51 女 长发52 53 长发 男54 0.06250000000000001
1 2 50 2 男 长发 3 男 短发 4 男 短发 5 男 短发 6 男 短发 7 男 短发 8 男 短发 9 男 短发10 男 短发11 男 短发12 男 短发13 男 短发14 男 短发15 男 短发16 男 短发17 男 短发18 男 短发19 男 短发20 男 短发21 男 长发22 女 长发23 女 长发24 女 长发25 女 长发26 女 长发27 女 长发28 女 长发29 女 长发30 女 长发31 女 长发32 女 长发33 女 长发34 女 长发35 女 长发36 女 长发37 女 长发38 女 长发39 女 长发40 女 长发41 女 长发42 女 长发43 女 长发44 女 长发45 女 长发46 女 长发47 女 长发48 女 长发49 女 长发50 女 长发51 女 长发52 长发 女53 0.9375
利用贝叶斯进行分类?
1 import com.google.common.collect.ArrayListMultimap; 2 import com.google.common.collect.Lists; 3 import com.google.common.collect.Multimap; 4 5 import java.util.Collection; 6 import java.util.List; 7 8 /** 9 * ********************************************************* 10 * <p/> 11 * Author: XiJun.Gong 12 * Date: 2016-08-31 15:48 13 * Version: default 1.0.0 14 * Class description: 15 * <p/> 16 * ********************************************************* 17 */ 18 19 public class Bayes { 20 private Multimap<String, FeaturePoint> map = null; 21 private List<String> featurePool = null; 22 23 public Bayes() { 24 map = ArrayListMultimap.create(); 25 featurePool = Lists.newArrayList(); 26 } 27 28 public void add(String label) { 29 featurePool.add(label); 30 } 31 32 /*喂数据*/ 33 public void input(List<String> labels) { 34 35 for (String key : labels) { 36 Collection<FeaturePoint> features = map.get(key); 37 for (String value : labels) { 38 if (features == null || features.size() < 1) { 39 map.put(key, new FeaturePoint(value)); 40 continue; 41 } 42 boolean tag = false; 43 for (FeaturePoint feature : features) { 44 if (feature.getKey().equals(value)) { 45 Double num = feature.getP() + 1; 46 map.remove(key, feature); 47 map.put(key, new FeaturePoint(value, num)); 48 tag = true; 49 break; 50 } 51 } 52 if (!tag) 53 map.put(key, new FeaturePoint(value)); 54 } 55 } 56 } 57 58 /*最符合那个分类*/ 59 public String excute(List<String> labels, Integer dataSize) { 60 61 Double max = -999999999d; 62 String max_obj = null; 63 List<Double> ans = Lists.newArrayList(); 64 for (String label : featurePool) { 65 Double p = excute(labels, label, dataSize); 66 ans.add(p); 67 if (max < p) { 68 max_obj = label; 69 max = p; 70 } 71 } 72 return max_obj; 73 } 74 75 /*构造模型*/ 76 public Double excute(final List<String> labels, final String judge, Integer dataSize) { 77 78 Double denominator = 1d; //分母 79 Double numerator = 1d; //分子 80 Double coughNum = 0d; 81 /*选择相关性分子*/ 82 Collection<FeaturePoint> featurePoints = map.get(judge); 83 for (FeaturePoint featurePoint : featurePoints) { 84 if (judge.equals(featurePoint.getKey())) { 85 coughNum = featurePoint.getP(); 86 denominator *= (featurePoint.getP() / dataSize); 87 break; 88 } 89 } 90 /*O(n^3)*/ 91 Integer size = featurePoints.size() - 1; //容量 92 for (String label : labels) { 93 for (FeaturePoint featurePoint : featurePoints) { 94 if (label.equals(featurePoint.getKey())) { 95 denominator *= (featurePoint.getP() / coughNum); 96 for (FeaturePoint feature : map.get(label)) { 97 if (label.equals(feature.getKey())) { 98 numerator *= (feature.getP() / dataSize); 99 }100 }101 }102 }103 }104 105 return denominator / numerator;106 }107 108 }
1 import com.google.common.collect.Lists; 2 3 import java.util.List; 4 import java.util.Scanner; 5 6 /** 7 * ********************************************************* 8 * <p/> 9 * Author: XiJun.Gong10 * Date: 2016-09-01 14:5811 * Version: default 1.0.012 * Class description:13 * <p/>14 * *********************************************************15 */16 public class Main {17 18 public static void main(String args[]) {19 20 Scanner scanner = new Scanner(System.in);21 Integer size = scanner.nextInt();22 Integer row = scanner.nextInt();23 Integer category = scanner.nextInt();24 while (scanner.hasNext()) {25 Bayes bayes = new Bayes();26 for (int ro = 0; ro < row; ro++) {27 List<String> list = Lists.newArrayList();28 for (int i = 0; i < size; i++) {29 list.add(scanner.next());30 }31 bayes.input(list);32 }33 List<String> list = Lists.newArrayList();34 for (int i = 0; i < size - 1; i++) {35 list.add(scanner.next());36 }37 for (int i = 0; i < category; i++) {38 bayes.add(scanner.next());39 }40 System.out.println(bayes.excute(list, row));41 }42 43 }44 }
结果:
1 2 50 2 2 男 长发 3 男 短发 4 男 短发 5 男 短发 6 男 短发 7 男 短发 8 男 短发 9 男 短发10 男 短发11 男 短发12 男 短发13 男 短发14 男 短发15 男 短发16 男 短发17 男 短发18 男 短发19 男 短发20 男 短发21 男 长发22 女 长发23 女 长发24 女 长发25 女 长发26 女 长发27 女 长发28 女 长发29 女 长发30 女 长发31 女 长发32 女 长发33 女 长发34 女 长发35 女 长发36 女 长发37 女 长发38 女 长发39 女 长发40 女 长发41 女 长发42 女 长发43 女 长发44 女 长发45 女 长发46 女 长发47 女 长发48 女 长发49 女 长发50 女 长发51 女 长发52 长发53 男 女54 女
1 2 50 2 2 男 长发 3 男 短发 4 男 短发 5 男 短发 6 男 短发 7 男 短发 8 男 短发 9 男 短发10 男 短发11 男 短发12 男 短发13 男 短发14 男 短发15 男 短发16 男 短发17 男 短发18 男 短发19 男 短发20 男 短发21 男 长发22 女 长发23 女 长发24 女 长发25 女 长发26 女 长发27 女 长发28 女 长发29 女 长发30 女 长发31 女 长发32 女 长发33 女 长发34 女 长发35 女 长发36 女 长发37 女 长发38 女 长发39 女 长发40 女 长发41 女 长发42 女 长发43 女 长发44 女 长发45 女 长发46 女 长发47 女 长发48 女 长发49 女 长发50 女 长发51 女 长发52 短发53 男 女54 男
朴素贝叶斯分类器(离散型)算法实现(一)
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。