首页 > 代码库 > Spark机器学习(4):朴素贝叶斯算法

Spark机器学习(4):朴素贝叶斯算法

1. 贝叶斯定理

条件概率公式:

技术分享

这个公式非常简单,就是计算在B发生的情况下,A发生的概率。但是很多时候,我们很容易知道P(A|B),需要计算的是P(B|A),这时就要用到贝叶斯定理:

技术分享

2. 朴素贝叶斯分类

朴素贝叶斯分类的推导过程就不详述了,其流程可以简单的用一张图来表示:

技术分享

 

举个简单的例子来说,下面这张表说明了各地区的人口构成:

技术分享

这个时候如果一个黑皮肤的人走过来(一个待分类项(0,0,1)),他是来自欧美,亚洲还是非洲呢?可以根据朴素贝叶斯分类进行计算:

欧美=0.30×0.90×0.20×0.40=0.0216

亚洲=0.95×0.10×0.05×0.40=0.0019

非洲=0.90×1.00×0.90×0.20=0.1620

即他来自非洲的可能性最大,来自欧美的可能性次之,来自亚洲的可能性最小,那么我们就判断他来自非洲,这和我们日常生活中的经验是一致的。

如果特征属性是连续值,则按照下面的公式计算:

技术分享

技术分享

3. MLlib的贝叶斯分类

直接上代码:

import org.apache.log4j.{Level, Logger}import org.apache.spark.mllib.classification.NaiveBayesimport org.apache.spark.mllib.linalg.Vectorsimport org.apache.spark.mllib.regression.LabeledPointimport org.apache.spark.{SparkConf, SparkContext}object NaiveBayesTest {  def main(args: Array[String]) {    // 设置运行环境    val conf = new SparkConf().setAppName("Naive Bayes Test")      .setMaster("spark://master:7077").setJars(Seq("E:\\Intellij\\Projects\\MachineLearning\\MachineLearning.jar"))    val sc = new SparkContext(conf)    Logger.getRootLogger.setLevel(Level.WARN)    // 读取样本数据并解析    val dataRDD = sc.textFile("hdfs://master:9000/ml/data/sample_naive_bayes_data.txt")    val parsedDataRDD = dataRDD.map { line =>      val parts = line.split(‘,‘)      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(‘ ‘).map(_.toDouble)))    }    // 样本数据划分,训练样本占0.8,测试样本占0.2    val dataParts = parsedDataRDD.randomSplit(Array(0.8, 0.2))    val trainRDD = dataParts(0)    val testRDD = dataParts(1)    // 建立贝叶斯分类模型并训练    val model = NaiveBayes.train(trainRDD, lambda = 1.0, modelType = "multinomial")    // 对测试样本进行测试    val predictionAndLabel = testRDD.map(p => (model.predict(p.features), p.label, p.features))    val showPredict = predictionAndLabel.take(50)    println("Prediction" + "\t" + "Label" + "\t" + "Data")    for (i <- 0 to showPredict.length - 1) {      println(showPredict(i)._1 + "\t" + showPredict(i)._2 + "\t" + showPredict(i)._3)    }    val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / testRDD.count()    println("Accuracy=" + accuracy)  }}

其中,NaiveBayes是贝叶斯分类伴生对象,train方法进行模型训练,三个参数分别是训练样本,平滑参数和模型类别。模型类别有两个:multinomial(多项式)和bernoulli(伯努利),这里使用的是multinomial。predict方法根据特征值进行判断分类。

运行结果:

技术分享

Spark机器学习(4):朴素贝叶斯算法