
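Apache Spark Source Code Walkthrough 22 -- The L-BFGS Quasi-Newton Implementation in Spark MLlib

Reposting is welcome; please credit the source, 徽沪一郎.

Overview

This article briefly reviews where the quasi-Newton method L-BFGS comes from, then walks through the source code of its implementation in Spark MLlib.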

Quasi-Newton methods

Mathematical principles
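As a brief reminder of the standard textbook formulation (not taken from the MLlib source), the quasi-Newton idea can be summarised as follows. Here B_k denotes an approximation of the Hessian and H_k an approximation of its inverse; both symbols are notation chosen for this sketch.

```latex
% Newton's method uses the exact Hessian:
x_{k+1} = x_k - \left[\nabla^2 f(x_k)\right]^{-1} \nabla f(x_k)

% Quasi-Newton methods replace it with an approximation B_{k+1}
% that satisfies the secant condition:
B_{k+1} s_k = y_k, \qquad
s_k = x_{k+1} - x_k, \qquad
y_k = \nabla f(x_{k+1}) - \nabla f(x_k)

% BFGS updates the inverse approximation H_k = B_k^{-1} directly:
H_{k+1} = \left(I - \rho_k s_k y_k^{\top}\right) H_k \left(I - \rho_k y_k s_k^{\top}\right)
          + \rho_k s_k s_k^{\top},
\qquad \rho_k = \frac{1}{y_k^{\top} s_k}
```

L-BFGS never stores H_k explicitly. It keeps only the m most recent pairs (s_k, y_k) and reconstructs the product of H_k with the gradient from them; m is the value passed as numCorrections in the code below.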

 

Code implementation

The regularization method used with the L-BFGS algorithm is SquaredL2Updater.
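To make the effect of L2 regularization concrete, here is a minimal sketch in plain Scala. It assumes the usual definition R(w) = 1/2 * ||w||^2 scaled by regParam; the object and method names are hypothetical and this is not the SquaredL2Updater code itself.

```scala
object L2RegularizationSketch {
  /** Regularization value added to the loss: 0.5 * regParam * ||w||^2. */
  def regValue(weights: Array[Double], regParam: Double): Double =
    0.5 * regParam * weights.map(w => w * w).sum

  /** Gradient of the regularization term with respect to w: regParam * w. */
  def regGradient(weights: Array[Double], regParam: Double): Array[Double] =
    weights.map(_ * regParam)
}
```

Both terms grow with the size of the weights, so the optimizer is pushed toward smaller weights as regParam increases.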

The implementation relies on the LBFGS class (imported as BreezeLBFGS) from Breeze, a member project of ScalaNLP; MLlib defines the DiffFunction that BreezeLBFGS requires.
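The contract of a DiffFunction is simple: given a point, return the objective value together with its gradient at that point. The sketch below illustrates that contract in plain Scala without depending on Breeze. SimpleDiffFunction, Quadratic and numericGradient are hypothetical names introduced for illustration, and arrays stand in for Breeze vectors.

```scala
object DiffFunctionSketch {
  /** Hypothetical stand-in for the contract: calculate returns (f(x), gradient of f at x). */
  trait SimpleDiffFunction {
    def calculate(x: Array[Double]): (Double, Array[Double])
  }

  /** f(x) = sum_i (x_i - 3)^2, whose gradient is 2 * (x - 3). */
  object Quadratic extends SimpleDiffFunction {
    def calculate(x: Array[Double]): (Double, Array[Double]) = {
      val diff = x.map(_ - 3.0)
      (diff.map(d => d * d).sum, diff.map(_ * 2.0))
    }
  }

  /** Central finite-difference estimate of the gradient, useful for checking calculate. */
  def numericGradient(
      f: SimpleDiffFunction,
      x: Array[Double],
      eps: Double = 1e-6): Array[Double] =
    x.indices.map { i =>
      val plus = x.updated(i, x(i) + eps)
      val minus = x.updated(i, x(i) - eps)
      (f.calculate(plus)._1 - f.calculate(minus)._1) / (2 * eps)
    }.toArray
}
```

Comparing the analytic gradient against numericGradient is a cheap way to catch a mismatch between value and gradient, which would otherwise show up only as poor convergence of the optimizer.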



The source of the runLBFGS function is as follows:

def runLBFGS(
    data: RDD[(Double, Vector)],
    gradient: Gradient,
    updater: Updater,
    numCorrections: Int,
    convergenceTol: Double,
    maxNumIterations: Int,
    regParam: Double,
    initialWeights: Vector): (Vector, Array[Double]) = {
  val lossHistory = new ArrayBuffer[Double](maxNumIterations)
  val numExamples = data.count()
  val costFun =
    new CostFun(data, gradient, updater, regParam, numExamples)
  val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
  val states =
    lbfgs.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
  /**
   * NOTE: lossSum and loss is computed using the weights from the previous iteration
   * and regVal is the regularization value computed in the previous iteration as well.
   */
  var state = states.next()
  while (states.hasNext) {
    lossHistory.append(state.value)
    state = states.next()
  }
  lossHistory.append(state.value)
  val weights = Vectors.fromBreeze(state.x)
  logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format(
    lossHistory.takeRight(10).mkString(", ")))
  (weights, lossHistory.toArray)
}
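The loop over states in runLBFGS records the value of every state and still holds the last state once the iterator is exhausted. The sketch below reproduces just that pattern in plain Scala. State, states and drain are hypothetical names, and plain gradient descent on f(x) = (x - 3)^2 stands in for the Breeze optimizer.

```scala
import scala.collection.mutable.ArrayBuffer

object LossHistorySketch {
  /** Hypothetical stand-in for an optimizer state: current point and objective value. */
  final case class State(x: Double, value: Double)

  /** Gradient descent on f(x) = (x - 3)^2, exposed lazily as an iterator of states. */
  def states(x0: Double, stepSize: Double, numIterations: Int): Iterator[State] =
    Iterator.iterate(x0)(x => x - stepSize * 2.0 * (x - 3.0))
      .take(numIterations + 1)
      .map(x => State(x, (x - 3.0) * (x - 3.0)))

  /** Drains the iterator with the same loop shape as runLBFGS. */
  def drain(it: Iterator[State]): (State, Array[Double]) = {
    val lossHistory = new ArrayBuffer[Double]()
    var state = it.next()
    while (it.hasNext) {
      lossHistory.append(state.value)
      state = it.next()
    }
    // The final state is not appended inside the loop, so record it here.
    lossHistory.append(state.value)
    (state, lossHistory.toArray)
  }
}
```

Because the iterator is lazy, each call to next() performs one optimization step; nothing is computed until the loop asks for it.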

The CostFun class is the core of the implementation:

private class CostFun(
    data: RDD[(Double, Vector)],
    gradient: Gradient,
    updater: Updater,
    regParam: Double,
    numExamples: Long) extends DiffFunction[BDV[Double]] {
  private var i = 0
  override def calculate(weights: BDV[Double]) = {
    // Have a local copy to avoid the serialization of CostFun object which is not serializable.
    val localData = data
    // ...
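Whatever the loss, a cost function handed to L-BFGS has to return the objective value and its gradient at the given weights. The following is a minimal sketch of that computation under stated assumptions: a local Seq replaces the RDD, the loss is least squares, and the penalty is L2. CostFunSketch is a hypothetical name and this is not the MLlib code.

```scala
object CostFunSketch {
  /** Returns (loss, gradient) for mean squared-error loss plus an L2 penalty. */
  def calculate(
      data: Seq[(Double, Array[Double])],
      weights: Array[Double],
      regParam: Double): (Double, Array[Double]) = {
    val n = weights.length
    // Sum per-example loss and gradient; on an RDD this would be a distributed aggregation.
    val (lossSum, gradSum) = data.foldLeft((0.0, new Array[Double](n))) {
      case ((loss, grad), (label, features)) =>
        val diff = features.zip(weights).map { case (f, w) => f * w }.sum - label
        (loss + 0.5 * diff * diff, grad.zip(features).map { case (g, f) => g + diff * f })
    }
    val numExamples = data.size.toDouble
    // L2 penalty: 0.5 * regParam * ||w||^2, with gradient regParam * w.
    val regVal = 0.5 * regParam * weights.map(w => w * w).sum
    val gradient = gradSum.zip(weights).map { case (g, w) => g / numExamples + regParam * w }
    (lossSum / numExamples + regVal, gradient)
  }
}
```

Averaging over the number of examples keeps the data term on the same scale regardless of dataset size, so regParam keeps a consistent meaning as the data grows.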
