首页 > 代码库 > Spark机器学习系列之13: 支持向量机SVM

Spark机器学习系列之13: 支持向量机SVM

基本公式推导

理论部分:SVM涉及的理论知识太多太繁杂了,大家直接看:
支持向量机通俗导论(理解SVM的三层境界) http://blog.csdn.net/v_july_v/article/details/7624837

下面摘抄一小部分内容(不考虑推导细节的话,基本上能理解SVM方法推导的整个流程),对偶问题(包括KKT条件)在SVM起到很重要的作用,如果对此不很了解,则难以理解SVM推导过程。关于对偶分析,可以参考我的另一篇文章:http://blog.csdn.net/qq_34531825/article/details/52872819) 。

技术分享
我们用一个超平面划分图中对图中的两类数据进行分类,超平面写成f(x)=wTx+b=0<script type="math/tex" id="MathJax-Element-757">f(x)=w^Tx+b=0</script>,在线性可分的情况下,我们能找到一些支持向量,满足|wTx+b|=1<script type="math/tex" id="MathJax-Element-758">|w^Tx+b|=1</script>

对一个数据点进行分类,当超平面离数据点的“间隔”越大,分类的确信度(confidence)也越大。所以,为了使得分类的确信度尽量高,需要让所选择的超平面能够最大化这个“间隔”值。这个间隔如下图中的Gap2<script type="math/tex" id="MathJax-Element-759">\frac{Gap}{ 2}</script>所示。
技术分享

xi<script type="math/tex" id="MathJax-Element-760">x_i</script>点离超平面方程wTx+b=0<script type="math/tex" id="MathJax-Element-761">w^Tx+b=0</script>的距离为:

γ=|wTxi+b|||w||=|f(xi)|||w||
<script type="math/tex; mode=display" id="MathJax-Element-762">\gamma=\frac{|w^Tx_i+b| }{||w||}=\frac{|f(x_i)| }{||w||}</script>
定义函数
yi=1      f(x)>0,<script type="math/tex" id="MathJax-Element-763">y_i=1 \ \ \ \ \ \ 当f(x)>0,</script>
yi=?1   f(x)<0,<script type="math/tex" id="MathJax-Element-764">y_i=-1\ \ \ 当f(x)<0,</script>

需要求解的目标函数及约束条件为:

max  1||w||
<script type="math/tex; mode=display" id="MathJax-Element-765">max\ \ \frac{1}{||w||}</script>
s.t.   yi(wTxi+b)1   i=1,2...n
<script type="math/tex; mode=display" id="MathJax-Element-766">s.t. \ \ \ y_i(w^Tx_i+b)\geq 1\ \ \ i=1,2...n</script>
等价于下面的凸二次规划问题:
min  12||w||2
<script type="math/tex; mode=display" id="MathJax-Element-767">min\ \ \frac{1}{2}||w||^2</script>
s.t.   yi(wTxi+b)1   i=1,2...n
<script type="math/tex; mode=display" id="MathJax-Element-768">s.t. \ \ \ y_i(w^Tx_i+b)\geq 1\ \ \ i=1,2...n </script>

可以利用通过拉格朗日对偶性(Lagrange Duality)变换到对偶变量 (dual variable) 的优化问题。(关于对偶分析,可以参考我的另一篇文章:http://blog.csdn.net/qq_34531825/article/details/52872819) 。

拉格朗日函数为:

L(w,b,α)=12||w||2?i=1nαi(yi(wTxi+b)?1)
<script type="math/tex; mode=display" id="MathJax-Element-769">L(w,b,\alpha)=\frac{1}{2}||w||^2-\sum_{i=1}^{n}\alpha_i(y_i(w^Tx_i+b)-1)</script>
原问题的对偶函数为:
θ(w,b,α)=infαi0{L(w,b,α)}
<script type="math/tex; mode=display" id="MathJax-Element-770">\theta(w,b,\alpha)=inf_{\alpha_i \geq 0}\{L(w,b,\alpha)\}</script>
对偶问题为:
min  θ((w,b,α))
<script type="math/tex; mode=display" id="MathJax-Element-771">min\ \ \theta((w,b,\alpha))</script>
αi0
<script type="math/tex; mode=display" id="MathJax-Element-772">\alpha_i\geq0</script>
按照KKT条件求解:
首先对w和b求偏导:
?L?w=0=>w=i=1nαiyixi
<script type="math/tex; mode=display" id="MathJax-Element-773">\frac{\partial L}{\partial w}=0=>w=\sum_{i=1}^{n}\alpha_iy_ix_i</script>
?L?b=0=>i=1nαiyi=0
<script type="math/tex; mode=display" id="MathJax-Element-774">\frac{\partial L}{\partial b}=0=>\sum_{i=1}^{n}\alpha_iy_i=0</script>
技术分享
问题转换为下公式,只包含变量α<script type="math/tex" id="MathJax-Element-775">\alpha</script>
技术分享

对于非线性数据,用核函数进行映射xi??(yi),yi??(yi)<script type="math/tex" id="MathJax-Element-776">x_i\mapsto \phi(y_i),y_i\mapsto \phi(y_i)</script>:
技术分享

技术分享

使用松弛变量处理 outliers 方法
虽然通过映射 将原始数据映射到高维空间之后,能够线性分隔的概率大大增加,但是对于某些情况还是很难处理。
例如可能并不是因为数据本身是非线性结构的,而只是因为数据有噪音。对于这种偏离正常位置很远的数据点,我们称之为 outlier ,在我们原来的 SVM 模型里,outlier 的存在有可能造成很大的影响,因为超平面本身就是只有少数几个 support vector 组成的,如果这些 support vector 里又存在 outlier 的话,其影响就很大了。
技术分享
现在考虑到outlier问题,引入松弛变量:
技术分享
原问题变成
技术分享
其中 C<script type="math/tex" id="MathJax-Element-777">C</script>是一个参数,用于控制目标函数中两项(“寻找 margin 最大的超平面”和“保证数据点偏差量最小”)之间的权重。C<script type="math/tex" id="MathJax-Element-778">C</script>是一个事先确定好的常量。
对偶问题:
技术分享
序列最小最优化SMO算法:
α={α1,α2,...,αn}<script type="math/tex" id="MathJax-Element-779">\alpha=\{\alpha_1, \alpha_2, ..., \alpha_n\}</script>上求上述目标函数的最小值。为了求解这些乘子,每次从中任意抽取两个乘子α1,α2<script type="math/tex" id="MathJax-Element-780">\alpha_1, \alpha_2</script>,然后固定α1,α2<script type="math/tex" id="MathJax-Element-781">\alpha_1, \alpha_2</script>以外的其它乘子α3,α4,...,αn<script type="math/tex" id="MathJax-Element-782">{\alpha_3, \alpha_4, ..., \alpha_n}</script>,使得目标函数只是关于α1,α2<script type="math/tex" id="MathJax-Element-783">\alpha_1, \alpha_2</script>的函数。这样,不断的从一堆乘子中任意抽取两个求解,不断的迭代求解子问题,最终达到求解原问题的目的。
技术分享

以上内容摘抄自:
支持向量机通俗导论(理解SVM的三层境界) http://blog.csdn.net/v_july_v/article/details/7624837
希望了解全面的过程,推荐参考原文。

Spark 优缺点分析

以下翻译自Scikit。
The advantages of support vector machines are:
(1)Effective in high dimensional spaces.在高维空间表现良好。
(2)Still effective in cases where number of dimensions is greater than the number of samples.在数据维度大于样本点数时候,依然可以起作用
(3)Uses a subset of training points in the decision function (called support vectors), so it is also memory efficient.仅仅使用训练数据的一个子集(支持向量),因此是内存友好型的算法。
(4)Versatile: different Kernel functions can be specified for the decision function. Common kernels are provided, but it is also possible to specify custom kernels.适应广,能解决多种情况下的分类问题,这是由于它支持不同类型的核函数,甚至支持自定义的核函数。
The disadvantages of support vector machines include:
(1)If the number of features is much greater than the number of samples, the method is likely to give poor performances.在数据维度(特征个数)多于样本数很多的时候,通常只能训练出一个表现很差的模型。
(2)SVMs do not directly provide probability estimates, these are calculated using an expensive five-fold cross-validation (see Scores and probabilities, below).SVM不支持直接进行概率估计,Scikit中使用很耗费资源的5折交叉检验来估计概率。

Spark Mllib

技术分享
技术分享
技术分享

import org.apache.spark.{SparkConf,SparkContext}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.classification.{SVMModel,SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.mllib.optimization.L1Updater
import org.apache.spark.mllib.optimization.SquaredL2Updater

object mySVM {
  def main(args:Array[String]){
    //屏蔽日志
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF) 
    val conf=new SparkConf().setMaster("local").setAppName("My App")
    val sc=new SparkContext(conf)

    // Load training data in LIBSVM format.
    val data = http://www.mamicode.com/MLUtils.loadLibSVMFile(sc, "/data/mllib/sample_libsvm_data.txt")
    //println(data.collect()(0))//检查数据   

    // Split data into training (60%) and test (40%).
    val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
    val training = splits(0).cache()
    val test = splits(1)

    // Run training algorithm to build the model   
    /*
     * stepSize: 迭代步长,默认为1.0
     * numIterations: 迭代次数,默认为100
     * regParam: 正则化参数,默认值为0.0
     * miniBatchFraction: 每次迭代参与计算的样本比例,默认为1.0
     * gradient:HingeGradient (),梯度下降;
     * updater:SquaredL2Updater (),正则化,L2范数;
     * optimizer:GradientDescent (gradient, updater),梯度下降最优化计算。
     */
    val svmAlg=new SVMWithSGD()
    svmAlg.optimizer
      .setNumIterations(100)
      .setRegParam(0.1)//正则化参数
      .setUpdater(new L1Updater)
    val modelL1=svmAlg.run(training)


    // Clear the default threshold.
    modelL1.clearThreshold()

   // Compute raw scores on the test set.
    val scoreAndLabels = test.map { point =>
     val score = modelL1.predict(point.features)     
     (score, point.label)//return score and label
    }   

   // Get evaluation metrics.
   val metrics = new BinaryClassificationMetrics(scoreAndLabels)
   val auROC = metrics.areaUnderROC() 
   println("Area under ROC = " + auROC)  

  }   

}

Python Scikit

Scikit对与SVM提供更多灵活的选择,可供学习实验。

参考文献
(1) 支持向量机通俗导论(理解SVM的三层境界) http://blog.csdn.net/v_july_v/article/details/7624837
(2)Spark MLlib SVM算法(源代码分析)http://www.itnose.net/detail/6267193.html
(3)Spark官网与Scikit官网
(4)LIBSVM: A Library for Support Vector Machines
Chih-Chung Chang and Chih-Jen Lin,Department of Computer Science
National Taiwan University, Taipei, Taiwan
(5)A Tutorial on Support Vector Regression Alex J. Smola? and Bernhard Scholkopf

<script type="text/javascript"> $(function () { $(‘pre.prettyprint code‘).each(function () { var lines = $(this).text().split(‘\n‘).length; var $numbering = $(‘
    ‘).addClass(‘pre-numbering‘).hide(); $(this).addClass(‘has-numbering‘).parent().append($numbering); for (i = 1; i <= lines; i++) { $numbering.append($(‘
  • ‘).text(i)); }; $numbering.fadeIn(1700); }); }); </script>

    Spark机器学习系列之13: 支持向量机SVM