首页 > 代码库 > 基于的朴素贝叶斯的文本分类(附完整代码(spark/java)

基于的朴素贝叶斯的文本分类(附完整代码(spark/java)

本文主要包括以下内容:
1)模型训练数据生成(demo)
2 ) 模型训练(spark+java),数据存储在hdfs上
3)预测数据生成(demo)
4)使用生成的模型进行文本分类。

一、训练数据生成

spark mllib模型训练的输入数据格式通常有两种,一种叫做 LIBSVM 格式,样式如下:
label index1:value1 index2:value2
label为类别标签,indexX为特征向量索引下标,value为对应的那维的取值。
另一种格式样式如下:
label f1,f2,f3,…,fn
fx为特征取值
两种格式的文件,分别可以通过方法:
org.apache.spark.mllib.util.MLUtils.loadLibSVMFile
org.apache.spark.mllib.util.MLUtils.loadLabeledData
读取。
我们这里采用第一种格式。

现在开始正式生成这种格式的数据文件。在模型训练阶段,会直接从这个文件中读取数据训练。这个后面会讲到。
我们这里假设对于文本,我们已经提取了关键词作为特征。特征列表如下:
features = [w1,w2,w3,…,wn]
同时,文本的主题(类别)集为:
topics = [t1,t2,…tm]
然后有很多的经过简单处理(分词,去停用词等)得到训练数据,每行的格式大概如下:
t2 w1,w23,w34,w1,…
我们直接将词频当做特征的取值。下面是生成libsvm格式的python代码,仅供参考:

    for info in result:
        sstr = ""
        topic_name = info[0] #主题名
        content = str(info[1]).split() #处理后的文本内容(词列表)
        index = topics.index(channel_name)
        sstr += str(index)
        features_val = ""
        for i,word in enumerate(features):
            freq = content.count(word)
            if freq:
                features_val += " "
                features_val += str(i+1) + ":" + str(freq)
        if not features_val:continue
        sstr += features_val
        #bayes_data.write(sstr+"\n")
        print n
        n += 1
    hdfs_client.write_list(sstr_lst,BAYES_DATA_PATH)#写到hdfs指定路径        
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

这样,训练数据就算生成完成了。

二、模型训练

这步就简单了。spark官网上有例子,直接拿来用就行了。现在贴出略做调整后的java代码:

public static void training(JavaSparkContext jsc){
        String path = "data/libsvm_data.txt";

         JavaRDD<LabeledPoint> inputData = http://www.mamicode.com/MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD();"hljs-keyword">new double[]{0.6, 0.4}, 12345);
         JavaRDD<LabeledPoint> training = tmp[0]; // training set
         JavaRDD<LabeledPoint> test = tmp[1]; // test set
         final NaiveBayesModel model = NaiveBayes.train(training.rdd());
         JavaPairRDD<Double, Double> predictionAndLabel =
           test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
             @Override
             public Tuple2<Double, Double> call(LabeledPoint p) {
               return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
             }
           });
         double accuracy = predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
           @Override
           public Boolean call(Tuple2<Double, Double> pl) {
             return pl._1().equals(pl._2());
           }
         }).count() / (double) test.count();
        //System.out.println(accuracy);
         // Save and load model
         model.save(jsc.sc(), "target/tmp/NaiveBayesModel");
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

如何执行这个任务?
具体完整的代码大家可以从spark的项目上找,或者看本文最后贴出的补充部分代码,基本也就全了。
把代码所在的工程打包(jar),比如打包为XX.jar。
然后执行命令:
spark-submit –class “yourclass” –master yarn XX.jar
就可以了。打包,已经spark-submit命令就不要详细讲了吧?
模型训练完成,可以打印模型的评测结果(准确率),并且将模型保存到hdfs上。注意上面的两个路径都是指hdfs上的路径。

三、使用模型对文本进行分类

拿来展示分类的文本数据生成过程就不介绍了。和模型数据生成一样。基本就是对你的文章分词等,然后转换成libsvm格式的文件,放到hdfs上。下面直接上分类的代码:

public static void predict(JavaSparkContext jsc){
        NaiveBayesModel sameModel = NaiveBayesModel.load(jsc.sc(), "target/tmp/NaiveBayesModel");
        String path = "/data/pred_data.txt";
        JavaRDD<String> rdd = jsc.textFile(path);

        for(String features:rdd.collect()){
            //System.out.println(features);
            String[] feature_str_lst = features.split(",");
            double[] feature_lst = new double[feature_str_lst.length];
            for(int i = 0;i<feature_str_lst.length;i++){
                feature_lst[i] = Double.parseDouble(feature_str_lst[i]);
            }
            System.out.println(sameModel.predict(Vectors.dense(feature_lst)));

        }

    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

执行后的结果正常情况你会看到如下的输出:(框起来的都是预测的类别)

技术分享

还不会,没有执行完整个demo?
代码main()方法的也贴给你们:

public static void main(String[] args){
        SparkConf sparkConf = new SparkConf().setAppName("JavaNaiveBayesExample");
         JavaSparkContext jsc = new JavaSparkContext(sparkConf);
         //training(jsc);
         predict(jsc);
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

好了。讲完了,是不是很简单。实现很简单,算法原理也很简单。有兴趣就自己去研究吧。

基于的朴素贝叶斯的文本分类(附完整代码(spark/java)