首页 > 代码库 > 文本分类——NaiveBayes
文本分类——NaiveBayes
前面文章已经介绍了朴素贝叶斯算法的原理,这里基于NavieBayes算法对newsgroup文本进行分类測试。
文中代码參考:http://blog.csdn.net/jiangliqing1234/article/details/39642757
主要内容例如以下:
1、newsgroup数据集介绍
数据下载地址:http://download.csdn.net/detail/hjy321686/8057761。
文本中包括20个不同的新闻组,除当中少数文本属于多个新闻组以外,其余的文档都仅仅属于一个新闻组。
2、newsgroup数据预处理
要对文本进行分类,首先要对其进行预处理,预处理主要步骤例如以下:
step1:英文词法分析,取出数字、连字符、标点符号、特殊字符,全部大写字母转换成小写,可用正則表達式:String res[] = line.split("[^a-zA-Z]");
step2:去停用词。过滤对别无价值的词
step3:词根还原stemmer,基于Porter算法
预处理类例如以下:
package com.datamine.NaiveBayes; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.FileWriter; import java.util.ArrayList; /** * Newsgroup文档预处理 * step1:英文词法分析,取出数字、连字符、标点符号、特殊字符,全部大写字母转换成小写。可用正則表達式:String res[] = line.split("[^a-zA-Z]"); * step2:去停用词,过滤对分类无价值的词 * step3:词根还原stemmer。基于Porter算法 * @author Administrator * */ public class DataPreProcess { private static ArrayList<String> stopWordsArray = new ArrayList<String>(); /** * 输入文件的路径。处理数据 * @param srcDir 文件文件夹的绝对路径 * @param desDir 清洗后的文件路径 * @throws Exception */ public void doProcess(String srcDir) throws Exception{ File fileDir = new File(srcDir); if(!fileDir.exists()){ System.out.println("文件不存在!"); return ; } String subStrDir = srcDir.substring(srcDir.lastIndexOf('/')); String dirTarget = srcDir+"/../../processedSample"+subStrDir; File fileTarget = new File(dirTarget); if(!fileTarget.exists()){ //注意processedSample须要先建立文件夹建出来,否则会报错,由于母文件夹不存在 boolean mkdir = fileTarget.mkdir(); } File[] srcFiles = fileDir.listFiles(); for(int i =0 ;i <srcFiles.length;i++){ String fileFullName = srcFiles[i].getCanonicalPath(); //CanonicalPath不可是全路径,并且把..或者.这种符号解析出来。String fileShortName = srcFiles[i].getName(); //文件名称 if(!new File(fileFullName).isDirectory()){ //确认子文件名称不是文件夹,假设是能够再次递归调用 System.out.println("開始预处理:"+fileFullName); StringBuilder stringBuilder = new StringBuilder(); stringBuilder.append(dirTarget+"/"+fileShortName); createProcessFile(fileFullName,stringBuilder.toString()); }else{ fileFullName = fileFullName.replace("\\", "/"); doProcess(fileFullName); } } } /** * 进行文本预处理生成目标文件 * @param srcDir 源文件文件文件夹的绝对路径 * @param targetDir 生成目标文件的绝对路径 * @throws Exception */ private void createProcessFile(String srcDir, String targetDir) throws Exception { FileReader srcFileReader = new FileReader(srcDir); FileWriter targetFileWriter = new FileWriter(targetDir); BufferedReader srcFileBR = new BufferedReader(srcFileReader); String line,resLine; while((line = srcFileBR.readLine()) != null){ resLine = lineProcess(line); if(!resLine.isEmpty()){ //按行写。一行写一个单词 String[] tempStr = resLine.split(" "); for(int i =0; i<tempStr.length ;i++){ if(!tempStr[i].isEmpty()) targetFileWriter.append(tempStr[i]+"\n"); } } } targetFileWriter.flush(); targetFileWriter.close(); srcFileReader.close(); srcFileBR.close(); } /** * 对每行字符串进行处理,主要是词法分析、去停用词和stemming(去除时态) * @param line 待处理的一行字符串 * @param stopWordsArray 停用词数组 * @return String 处理好的一行字符串,是由处理好的单词又一次生成,以空格为分隔符 */ private String lineProcess(String line) { /* * step1 * 英文词法分析,去除数字、连字符、标点符号、特殊字符, * 全部大写字符转换成小写,能够考虑使用正則表達式 */ String res[] = line.split("[^a-zA-Z]"); //step2 去停用词,大写转换成小写 //step3 Stemmer.run() String resString = new String(); for(int i=0;i<res.length;i++){ if(!res[i].isEmpty() && !stopWordsArray.contains(res[i].toLowerCase())) resString += " " + Stemmer.run(res[i].toLowerCase()) + " "; } return resString; } /** * 用stopWordsArray构造停用词的ArrayList容器 * @param stopwordsPath * @throws Exception */ private static void stopWordsToArray(String stopwordsPath) throws Exception { FileReader stopWordsReader = new FileReader(stopwordsPath); BufferedReader stopWordsBR = new BufferedReader(stopWordsReader); String stopWordsLine = null; //用stopWordsArray构造停用词的ArrayList容器 while((stopWordsLine = stopWordsBR.readLine()) != null){ if(!stopWordsLine.isEmpty()) stopWordsArray.add(stopWordsLine); } stopWordsReader.close(); stopWordsBR.close(); } public static void main(String[] args) throws Exception{ DataPreProcess dataPrePro = new DataPreProcess(); String srcDir = "E:/DataMiningSample/orginSample"; String stopwordsPath = "E:/DataMiningSample/stopwords.txt"; stopWordsToArray(stopwordsPath); dataPrePro.doProcess(srcDir); } }
对于step3中的Porter算法能够网上下载,这里我基于其之上加入了一个run()方法。
/** * Stemmer中接口,将传入的word进行词根还原 * @param word 传入单词 * @return result 处理后的单词 */ public static String run(String word){ Stemmer s = new Stemmer(); char[] ch = word.toCharArray(); for (int c = 0; c < ch.length; c++) s.add(ch[c]); s.stem(); { String u; u = s.toString(); //System.out.print(u); return u; } }
3、特征项选择
方法一:保留全部词作为特征词
方法二:选取出现频率大于某一个数(3或者其它)的词作为特征词
方法三:计算每一个词的权重tf*idf,依据权重来选取特征词
本文中选取方法二。
4、文本向量化
因为本文中。特征词选择採用的是方法二,能够不用对文本进行向量化,可是统计特征词出现的次数方法写在ComputeWordsVector类中,为了程序执行这里还是把文本向量化的代码贴出来。后面使用KNN算法的时候也是要用到此类的。
package com.datamine.NaiveBayes; import java.io.*; import java.util.*; /** * 计算文档的属性向量。将全部文档向量化 * @author Administrator */ public class ComputeWordsVector { /** * 计算文档的TF属性向量。TFPerDocMap * 计算TF*IDF * @param strDir 处理好的newsgroup文件文件夹的绝对路径 * @param trainSamplePercent 训练样本集占每一个类目的比例 * @param indexOfSample 測试例子集的起始的測试例子编号 凝视:通过这个參数能够将文本分成训练和測试两部分 * @param iDFPerWordMap 每一个词的IDF权值属性向量 * @param wordMap 属性词典map * @throws IOException */ public void computeTFMultiIDF(String strDir,double trainSamplePercent,int indexOfSample, Map<String, Double> iDFPerWordMap,Map<String,Double> wordMap) throws IOException{ File fileDir = new File(strDir); String word; SortedMap<String,Double> TFPerDocMap = new TreeMap<String, Double>(); //注意能够用两个写文件,一个专门写測试例子,一个专门写训练例子,用sampleType的值来表示 String trainFileDir = "E:/DataMiningSample/docVector/wordTFIDFMapTrainSample"+indexOfSample; String testFileDir = "E:/DataMiningSample/docVector/wordTFIDFMapTestSample"+indexOfSample; FileWriter tsTrainWriter = new FileWriter(new File(trainFileDir)); //往训练文件里写 FileWriter tsTestWriter = new FileWriter(new File(testFileDir)); //往測试文件里写 FileWriter tsWriter = null; File[] sampleDir = fileDir.listFiles(); for(int i = 0;i<sampleDir.length;i++){ String cateShortName = sampleDir[i].getName(); System.out.println("開始计算: " + cateShortName); File[] sample = sampleDir[i].listFiles(); //測试例子集起始文件序号 double testBeginIndex = indexOfSample*(sample.length*(1-trainSamplePercent)); //測试例子集的结束文件序号 double testEndIndex = (indexOfSample+1)*(sample.length*(1-trainSamplePercent)); System.out.println("文件名称_文件数 :" + sampleDir[i].getCanonicalPath()+"_"+sample.length); System.out.println("训练数:"+sample.length*trainSamplePercent + " 測试文本開始下标:"+ testBeginIndex+" 測试文本结束下标:"+testEndIndex); for(int j =0;j<sample.length; j++){ //计算TF,即每一个词在该文件里出现的频率 TFPerDocMap.clear(); FileReader samReader = new FileReader(sample[j]); BufferedReader samBR = new BufferedReader(samReader); String fileShortName = sample[j].getName(); Double wordSumPerDoc = 0.0;//计算每篇文档的总字数 while((word = samBR.readLine()) != null){ if(!word.isEmpty() && wordMap.containsKey(word)){ wordSumPerDoc++; if(TFPerDocMap.containsKey(word)) TFPerDocMap.put(word, TFPerDocMap.get(word)+1); else TFPerDocMap.put(word, 1.0); } } samBR.close(); /* * 遍历 TFPerDocMap,除以文档的总词数wordSumPerDoc 则得到TF * TF*IDF得到终于的特征权值,并输出到文件 * 注意:測试例子和训练例子写入的文件不同 */ if(j >= testBeginIndex && j <= testEndIndex) tsWriter = tsTestWriter; else tsWriter = tsTrainWriter; Double wordWeight; Set<Map.Entry<String, Double>> tempTF = TFPerDocMap.entrySet(); for(Iterator<Map.Entry<String, Double>> mt = tempTF.iterator();mt.hasNext();){ Map.Entry<String, Double> me = mt.next(); //因为计算IDF很耗时,3万多个词的属性词典初步预计须要25个小时,先尝试觉得全部词的IDF都是1的情况 //wordWeight = (me.getValue() / wordSumPerDoc) * iDFPerWordMap.get(me.getKey()); wordWeight = (me.getValue() / wordSumPerDoc) * 1.0; TFPerDocMap.put(me.getKey(), wordWeight); } tsWriter.append(cateShortName + " "); tsWriter.append(fileShortName + " "); Set<Map.Entry<String, Double>> tempTF2 = TFPerDocMap.entrySet(); for(Iterator<Map.Entry<String, Double>> mt = tempTF2.iterator();mt.hasNext();){ Map.Entry<String, Double> me = mt.next(); tsWriter.append(me.getKey() + " " + me.getValue()+" "); } tsWriter.append("\n"); tsWriter.flush(); } } tsTrainWriter.close(); tsTestWriter.close(); tsWriter.close(); } /** * 统计每一个词的总出现次数。返回出现次数大于3词的词汇构成终于的属性词典 * @param strDir 处理好的newsgroup文件文件夹的绝对路径 * @param wordMap 记录出现的每一个词构成的属性词典 * @return newWordMap 返回出现次数大于3次的词汇构成终于的属性词典 * @throws IOException */ public SortedMap<String, Double> countWords(String strDir, Map<String, Double> wordMap) throws IOException { File sampleFile = new File(strDir); File[] sample = sampleFile.listFiles(); String word; for(int i =0 ;i < sample.length;i++){ if(!sample[i].isDirectory()){ FileReader samReader = new FileReader(sample[i]); BufferedReader samBR = new BufferedReader(samReader); while((word = samBR.readLine()) != null){ if(!word.isEmpty() && wordMap.containsKey(word)) wordMap.put(word, wordMap.get(word)+1); else wordMap.put(word, 1.0); } samBR.close(); }else{ countWords(sample[i].getCanonicalPath(),wordMap); } } /* * 仅仅返回出现次数大于3的单词 * 这里为了简单,应该独立一个函数。避免多次执行 */ SortedMap<String,Double> newWordMap = new TreeMap<String, Double>(); Set<Map.Entry<String, Double>> allWords = wordMap.entrySet(); for(Iterator<Map.Entry<String, Double>> it = allWords.iterator();it.hasNext();){ Map.Entry<String, Double> me = it.next(); if(me.getValue() > 2) newWordMap.put(me.getKey(), me.getValue()); } System.out.println("newWordMap "+ newWordMap.size()); return newWordMap; } /** * 打印属性词典,到allDicWordCountMap.txt中 * @param wordMap 属性词典 * @throws IOException */ public void printWordMap(Map<String, Double> wordMap) throws IOException{ System.out.println("printWordMap:"); int countLine = 0; File outPutFile = new File("E:/DataMiningSample/docVector/allDicWordCountMap.txt"); FileWriter outPutFileWriter = new FileWriter(outPutFile); Set<Map.Entry<String, Double>> allWords = wordMap.entrySet(); for(Iterator<Map.Entry<String, Double>> it = allWords.iterator();it.hasNext();){ Map.Entry<String, Double> me = it.next(); outPutFileWriter.write(me.getKey()+" "+me.getValue()+"\n"); countLine++; } outPutFileWriter.close(); System.out.println("WordMap size : " + countLine); } /** * 词w在整个文档集合中的逆向文档频率idf (Inverse Document Frequency), * 即文档总数n与词w所出现文件数docs(w, D)比值的对数: idf = log(n / docs(w, D)) * 计算IDF。即属性词典中每一个词在多少个文档中出现过 * @param strDir 处理好的newsgroup文件文件夹的绝对路径 * @param wordMap 属性词典 * @return 单词的IDFMap * @throws IOException */ public SortedMap<String,Double> computeIDF(String strDir,Map<String, Double> wordMap) throws IOException{ File fileDir = new File(strDir); String word; SortedMap<String,Double> IDFPerWordMap = new TreeMap<String, Double>(); Set<Map.Entry<String, Double>> wordMapSet = wordMap.entrySet(); for(Iterator<Map.Entry<String, Double>> it = wordMapSet.iterator();it.hasNext();){ Map.Entry<String, Double> pe = it.next(); Double countDoc = 0.0; //出现字典词的文本数 Double sumDoc = 0.0; //文本总数 String dicWord = pe.getKey(); File[] sampleDir = fileDir.listFiles(); for(int i =0;i<sampleDir.length;i++){ File[] sample = sampleDir[i].listFiles(); for(int j = 0;j<sample.length;j++){ sumDoc++; //统计文本数 FileReader samReader = new FileReader(sample[j]); BufferedReader samBR = new BufferedReader(samReader); boolean isExist = false; while((word = samBR.readLine()) != null){ if(!word.isEmpty() && word.equals(dicWord)){ isExist = true; break; } } if(isExist) countDoc++; samBR.close(); } } //计算单词的IDF //double IDF = Math.log(sumDoc / countDoc) / Math.log(10); double IDF = Math.log(sumDoc / countDoc); IDFPerWordMap.put(dicWord, IDF); } return IDFPerWordMap; } public static void main(String[] args) throws IOException { ComputeWordsVector wordsVector = new ComputeWordsVector(); String strDir = "E:\\DataMiningSample\\processedSample"; Map<String, Double> wordMap = new TreeMap<String, Double>(); //属性词典 Map<String, Double> newWordMap = new TreeMap<String, Double>(); newWordMap = wordsVector.countWords(strDir,wordMap); //wordsVector.printWordMap(newWordMap); //wordsVector.computeIDF(strDir, newWordMap); double trainSamplePercent = 0.8; int indexOfSample = 1; Map<String, Double> iDFPerWordMap = null; wordsVector.computeTFMultiIDF(strDir, trainSamplePercent, indexOfSample, iDFPerWordMap, newWordMap); //test(); } public static void test(){ double sumDoc = 18828.0; double countDoc = 229.0; double IDF1 = Math.log(sumDoc / countDoc) / Math.log(10); double IDF2 = Math.log(sumDoc / countDoc) ; System.out.println(IDF1); System.out.println(IDF2); System.out.println(Math.log(10)); } }
5、对文本分为測试集和训练集
按指定的比例(0.9或者0.8)对整个文本进行划分。測试集和训练集
package com.datamine.NaiveBayes; import java.io.*; import java.util.*; public class CreateTrainAndTestSample { void filterSpecialWords() throws IOException{ String word; ComputeWordsVector cwv = new ComputeWordsVector(); String fileDir = "E:\\DataMiningSample\\processedSample"; SortedMap<String, Double> wordMap = new TreeMap<String, Double>(); wordMap = cwv.countWords(fileDir, wordMap); cwv.printWordMap(wordMap); //把wordMap输出到文件 File[] sampleDir = new File(fileDir).listFiles(); for(int i = 0;i<sampleDir.length;i++){ File[] sample = sampleDir[i].listFiles(); String targetDir = "E:/DataMiningSample/processedSampleOnlySpecial/"+sampleDir[i].getName(); File targetDirFile = new File(targetDir); if(!targetDirFile.exists()){ targetDirFile.mkdir(); } for(int j = 0; j<sample.length;j++){ String fileShortName = sample[j].getName(); targetDir = "E:/DataMiningSample/processedSampleOnlySpecial/"+sampleDir[i].getName()+"/"+fileShortName; FileWriter tgWriter = new FileWriter(targetDir); FileReader samReader = new FileReader(sample[j]); BufferedReader samBR = new BufferedReader(samReader); while((word = samBR.readLine()) != null){ if(wordMap.containsKey(word)) tgWriter.append(word+"\n"); } tgWriter.flush(); tgWriter.close(); samBR.close(); } } } /** * 创建训练集和測试集 * @param fileDir 预处理好的文件路径 E:\DataMiningSample\processedSampleOnlySpecial * @param trainSamplePercent 训练集占的百分比0.8 * @param indexOfSample 一个測试集计算规则 1 * @param classifyResultFile 測试例子正确类目记录文件 * @throws IOException */ void createTestSample(String fileDir,double trainSamplePercent,int indexOfSample,String classifyResultFile) throws IOException{ String word,targetDir; FileWriter crWriter = new FileWriter(classifyResultFile);//測试例子正确类目记录文件 File[] sampleDir = new File(fileDir).listFiles(); for(int i =0;i<sampleDir.length;i++){ File[] sample = sampleDir[i].listFiles(); double testBeginIndex = indexOfSample*(sample.length*(1-trainSamplePercent)); double testEndIndex = (indexOfSample + 1)*(sample.length*(1-trainSamplePercent)); for(int j = 0;j<sample.length;j++){ FileReader samReader = new FileReader(sample[j]); BufferedReader samBR = new BufferedReader(samReader); String fileShortName = sample[j].getName(); String subFileName = fileShortName; if(j > testBeginIndex && j < testEndIndex){ targetDir = "E:/DataMiningSample/TestSample"+indexOfSample+"/"+sampleDir[i].getName(); crWriter.append(subFileName + " "+sampleDir[i].getName()+"\n"); }else{ targetDir = "E:/DataMiningSample/TrainSample"+indexOfSample+"/"+sampleDir[i].getName(); } targetDir = targetDir.replace("\\", "/"); File trainSamFile = new File(targetDir); if(!trainSamFile.exists()){ trainSamFile.mkdir(); } targetDir += "/" + subFileName; FileWriter tsWriter = new FileWriter(new File(targetDir)); while((word = samBR.readLine()) != null) tsWriter.append(word+"\n"); tsWriter.flush(); tsWriter.close(); samBR.close(); } } crWriter.close(); } public static void main(String[] args) throws IOException { CreateTrainAndTestSample test = new CreateTrainAndTestSample(); String fileDir = "E:/DataMiningSample/processedSampleOnlySpecial"; double trainSamplePercent=0.8; int indexOfSample=1; String classifyResultFile="E:/DataMiningSample/classifyResult"; test.createTestSample(fileDir, trainSamplePercent, indexOfSample, classifyResultFile); //test.filterSpecialWords(); } }
6、朴素贝叶斯算法描写叙述和实现
类条件概率P(tk|c)=(类c下单词tk在各个文档中出现过的次数之和+1)/(类c下单词总数+训练样本中不反复特征词总数)
先验概率P(c)=类c下的单词总数/整个训练样本的单词总数
(2)伯努利模型(Bernoulli model) –以文件为粒度,或者说是採用二项分布模型,伯努利实验即N次独立反复随机实验,仅仅考虑事件发生/不发生,所以每一个单词的表示变量是布尔型的
类条件概率P(tk|c)=(类c下包括单词tk的文件数+1)/(类c下文件总数+2)
先验概率P(c)=类c下文件总数/整个训练样本的文件总数
本分类器选用多元分布模型计算。依据《Introduction to Information Retrieval 》,多元分布模型计算准确率更高
(2) 用交叉验证法做十次分类实验,对准确率取平均值
(3) 依据正确类目文件和分类结果文计算混淆矩阵而且输出
(4) Map<String,Double> cateWordsProb key为“类目_单词”, value为该类目下该单词的出现次数。避免反复计算
package com.datamine.NaiveBayes; import java.io.BufferedReader; import java.io.File; import java.io.FileNotFoundException; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; import java.math.BigDecimal; import java.util.Iterator; import java.util.Map; import java.util.Set; import java.util.SortedSet; import java.util.TreeMap; import java.util.TreeSet; import java.util.Vector; /** * 利用朴素贝叶斯算法对newsgroup文档集做分类,採用十组交叉測试取平均值 * 採用多项式模型 * 类条件概率 P(tk|c)=(类c下 单词tk 在各个文档中出现过的次数之和 + 1)/(类c下单词的总数 + 训练集总单词数) * @author Administrator */ public class NaiveBayesianClassifier { /** * 用朴素贝叶斯算法对測试文档集分类 * @param trainDir 训练文档集文件夹 * @param testDir 測试文档集文件夹 * @param classifyResultFileNew 分类结果文件路径 * @throws Exception */ private void doProcess(String trainDir,String testDir, String classifyResultFileNew) throws Exception{ //保存训练集中每一个类别的总词数 <文件夹名。单词总数> category Map<String,Double> cateWordsNum = new TreeMap<String, Double>(); //保存训练样本中每一个类别中每一个属性词的出现次数 <类目_单词,数目> Map<String,Double> cateWordsProb = new TreeMap<String, Double>(); cateWordsNum = getCateWordsNum(trainDir); cateWordsProb = getCateWordsProb(trainDir); double totalWordsNum = 0.0;//记录全部训练集的总词数 Set<Map.Entry<String, Double>> cateWordsNumSet = cateWordsNum.entrySet(); for(Iterator<Map.Entry<String, Double>> it = cateWordsNumSet.iterator();it.hasNext();){ Map.Entry<String, Double> me = it.next(); totalWordsNum += me.getValue(); } //以下開始读取測试例子做分类 Vector<String> testFileWords = new Vector<String>(); //測试样本全部词的容器 String word; File[] testDirFiles = new File(testDir).listFiles(); FileWriter crWriter = new FileWriter(classifyResultFileNew); for(int i =0;i<testDirFiles.length;i++){ File[] testSample = testDirFiles[i].listFiles(); for(int j =0;j<testSample.length;j++){ testFileWords.clear(); FileReader spReader = new FileReader(testSample[j]); BufferedReader spBR = new BufferedReader(spReader); while((word = spBR.readLine()) != null){ testFileWords.add(word); } spBR.close(); //以下分别计算该測试例子属于二十个类别的概率 File[] trainDirFiles = new File(trainDir).listFiles(); BigDecimal maxP = new BigDecimal(0); String bestCate = null; for(int k =0; k < trainDirFiles.length; k++){ BigDecimal p = computeCateProb(trainDirFiles[k],testFileWords,cateWordsNum,totalWordsNum,cateWordsProb); if( k == 0){ maxP = p; bestCate = trainDirFiles[k].getName(); continue; } if(p.compareTo(maxP) == 1){ maxP = p; bestCate = trainDirFiles[k].getName(); } } crWriter.append(testSample[j].getName() + " " + bestCate + "\n"); crWriter.flush(); } } crWriter.close(); } /** * 类条件概率 P(tk|c)=(类c下 单词tk 在各个文档中出现过的次数之和 + 1)/(类c下单词的总数 + 训练集中总单词数) * 计算某一个測试样本数据某个类别的概率 使用多项式模型 * @param trainFile 该类别全部的训练样本所在的文件夹 * @param testFileWords 该測试样本中的全部词构成的容器 * @param cateWordsNum 记录每一个文件夹下单词的总数 * @param totalWordsNum 全部训练样本的单词的总数 * @param cateWordsProb 记录每一个文件夹中出现单词和次数 * @return 返回该測试样本在该类别中的概率 */ private BigDecimal computeCateProb(File trainFile, Vector<String> testFileWords, Map<String, Double> cateWordsNum, double totalWordsNum, Map<String, Double> cateWordsProb) { BigDecimal probability = new BigDecimal(1); double wordNumInCate = cateWordsNum.get(trainFile.getName()); BigDecimal wordNumInCateBD = new BigDecimal(wordNumInCate); BigDecimal totalWordsNumBD = new BigDecimal(totalWordsNum); for(Iterator<String> it = testFileWords.iterator();it.hasNext();){ String me = it.next(); String key = trainFile.getName()+"_"+me; double testFileWordNumInCate; if(cateWordsProb.containsKey(key)) testFileWordNumInCate = cateWordsProb.get(key); else testFileWordNumInCate = 0.0; BigDecimal testFileWordNumInCateBD = new BigDecimal(testFileWordNumInCate); BigDecimal xcProb = (testFileWordNumInCateBD.add(new BigDecimal(0.0001))) .divide(wordNumInCateBD.add(totalWordsNumBD), 10, BigDecimal.ROUND_CEILING); probability = probability.multiply(xcProb); } // P = P(tk|c)*P(C) BigDecimal result = probability.multiply(wordNumInCateBD.divide(totalWordsNumBD,10, BigDecimal.ROUND_CEILING)); return result; } /** * 统计某个类训练样本中每一个单词出现的次数 * @param trainDir 训练样本集文件夹 * @return cateWordsProb 用"类目_单词"来索引map,value就是该类目下该单词出现的次数 * @throws Exception */ private Map<String, Double> getCateWordsProb(String trainDir) throws Exception { Map<String,Double> cateWordsProb = new TreeMap<String, Double>(); File sampleFile = new File(trainDir); File[] sampleDir = sampleFile.listFiles(); String word; for(int i =0;i < sampleDir.length;i++){ File[] sample = sampleDir[i].listFiles(); for(int j =0;j<sample.length;j++){ FileReader samReader = new FileReader(sample[j]); BufferedReader samBR = new BufferedReader(samReader); while((word = samBR.readLine()) != null){ String key = sampleDir[i].getName()+"_"+word; if(cateWordsProb.containsKey(key)) cateWordsProb.put(key, cateWordsProb.get(key)+1); else cateWordsProb.put(key, 1.0); } samBR.close(); } } return cateWordsProb; } /** * 获得每一个类目下的单词总数 * @param trainDir 训练文档集文件夹 * @return cateWordsNum <文件夹名,单词总数>的map * @throws IOException */ private Map<String, Double> getCateWordsNum(String trainDir) throws IOException { Map<String, Double> cateWordsNum = new TreeMap<String, Double>(); File[] sampleDir = new File(trainDir).listFiles(); for(int i =0;i<sampleDir.length;i++){ double count = 0; File[] sample = sampleDir[i].listFiles(); for(int j =0;j<sample.length;j++){ FileReader spReader = new FileReader(sample[j]); BufferedReader spBR = new BufferedReader(spReader); while(spBR.readLine() != null){ count++; } spBR.close(); } cateWordsNum.put(sampleDir[i].getName(), count); } return cateWordsNum; } /** * 依据正确类目文件和分类结果文件统计出准确率 * @param classifyRightCate 正确类目文件 <文件名称。类别文件夹名> * @param classifyResultFileNew 分类结果文件 <文件名称,类别文件夹名> * @return 分类的准确率 * @throws Exception */ public double computeAccuracy(String classifyRightCate, String classifyResultFileNew) throws Exception { Map<String,String> rightCate = new TreeMap<String, String>(); Map<String,String> resultCate = new TreeMap<String,String>(); rightCate = getMapFromResultFile(classifyRightCate); resultCate = getMapFromResultFile(classifyResultFileNew); Set<Map.Entry<String, String>> resCateSet = resultCate.entrySet(); double rightCount = 0.0; for(Iterator<Map.Entry<String, String>> it = resCateSet.iterator();it.hasNext();){ Map.Entry<String, String> me = it.next(); if(me.getValue().equals(rightCate.get(me.getKey()))) rightCount++; } computerConfusionMatrix(rightCate,resultCate); return rightCount / resultCate.size(); } /** * 依据正确类目文件和分类结果文件计算混淆矩阵并输出 * @param rightCate 正确类目map * @param resultCate 分类结果相应map */ private void computerConfusionMatrix(Map<String, String> rightCate, Map<String, String> resultCate) { int[][] confusionMatrix = new int[20][20]; //首先求出类目相应的数组索引 SortedSet<String> cateNames = new TreeSet<String>(); Set<Map.Entry<String, String>> rightCateSet = rightCate.entrySet(); for(Iterator<Map.Entry<String, String>> it = rightCateSet.iterator();it.hasNext();){ Map.Entry<String, String> me = it.next(); cateNames.add(me.getValue()); } cateNames.add("rec.sport.baseball");//防止数少一个类目 String[] cateNamesArray = cateNames.toArray(new String[0]); Map<String,Integer> cateNamesToIndex = new TreeMap<String, Integer>(); for(int i =0;i<cateNamesArray.length;i++){ cateNamesToIndex.put(cateNamesArray[i], i); } for(Iterator<Map.Entry<String, String>> it = rightCateSet.iterator();it.hasNext();){ Map.Entry<String, String> me = it.next(); confusionMatrix[cateNamesToIndex.get(me.getValue())][cateNamesToIndex.get(resultCate.get(me.getKey()))]++; } //输出混淆矩阵 double[] hangSum = new double[20]; System.out.print(" "); for(int i=0;i<20;i++){ System.out.printf("%-6d",i); } System.out.println("准确率"); for(int i =0;i<20;i++){ System.out.printf("%-6d",i); for(int j = 0;j<20;j++){ System.out.printf("%-6d",confusionMatrix[i][j]); hangSum[i] += confusionMatrix[i][j]; } System.out.printf("%-6f\n",confusionMatrix[i][i]/hangSum[i]); } System.out.println(); } /** * 从结果文件里读取Map * @param file 类目文件 * @return Map<String,String> 由<文件名称,类目名>保存的map * @throws Exception */ private Map<String, String> getMapFromResultFile(String file) throws Exception { File crFile = new File(file); FileReader crReader = new FileReader(crFile); BufferedReader crBR = new BufferedReader(crReader); Map<String,String> res = new TreeMap<String, String>(); String[] s; String line; while((line = crBR.readLine()) != null){ s = line.split(" "); res.put(s[0], s[1]); } return res; } public static void main(String[] args) throws Exception { CreateTrainAndTestSample ctt = new CreateTrainAndTestSample(); NaiveBayesianClassifier nbClassifier = new NaiveBayesianClassifier(); //依据包括非特征词的文档集生成仅仅包括特征词的文档集到processedSampleOnlySpecial文件夹下 ctt.filterSpecialWords(); double[] accuracyOfEveryExp = new double[10]; double accuracyAvg,sum = 0; for(int i =0;i<10;i++){//用交叉验证法做十次分类实验。对准确率取平均值 String TrainDir = "E:/DataMiningSample/TrainSample"+i; String TestDir = "E:/DataMiningSample/TestSample"+i; String classifyRightCate = "E:/DataMiningSample/classifyRightCate"+i+".txt"; String classifyResultFileNew = "E:/DataMiningSample/classifyResultNew"+i+".txt"; ctt.createTestSample("E:/DataMiningSample/processedSampleOnlySpecial", 0.8, i, classifyRightCate); nbClassifier.doProcess(TrainDir, TestDir, classifyResultFileNew); accuracyOfEveryExp[i] = nbClassifier.computeAccuracy(classifyRightCate,classifyResultFileNew); System.out.println("The accuracy for Naive Bayesian Classifier in "+i+"th Exp is :" + accuracyOfEveryExp[i]); } for(int i =0;i<10;i++) sum += accuracyOfEveryExp[i]; accuracyAvg = sum/10; System.out.println("The average accuracy for Naive Bayesian Classifier in all Exps is :" + accuracyAvg); } }
7、实验结果与说明
这里仅仅列出第一次运行的结果:
这里使用的多项式模型是经过改进的计算方法:改进多项式模型的类条件概率的计算公式,改进为 类条件概率P(tk|c)=(类c下单词tk在各个文档中出现过的次数之和+0.001)/(类c下单词总数+训练样本中不反复特征词总数),分子当tk没有出现时,仅仅加0.001,这样更加精确的描写叙述的词的统计分布规律
8、算法改进
为了进一步提高朴素贝叶斯算法的分类能够进行例如以下改进:
1、优化特征词选取的方法,如方法三,或者更优方法
2、改进多项式模型的类条件概率的计算公式(上面已经实现)
文本分类——NaiveBayes