Apache Spark Source Code Walkthrough, Part 22: The L-BFGS Quasi-Newton Method as Implemented in Spark MLlib
Reposting is welcome; please credit the source, 徽沪一郎.
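Overview
This article briefly reviews where the quasi-Newton method L-BFGS comes from, then walks through the source code of its implementation in Spark MLlib.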
Quasi-Newton Methods
Mathematical Principles
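As a brief recap in standard textbook notation (the symbols $s_k$, $y_k$, $\rho_k$, $H_k$, $\alpha_k$ and the history size $m$ are assumptions of this sketch), a quasi-Newton method replaces the inverse Hessian in Newton's step with an approximation $H_k$ that is built from gradient differences alone:

$$
x_{k+1} = x_k - \alpha_k H_k \nabla f(x_k), \qquad
s_k = x_{k+1} - x_k, \qquad
y_k = \nabla f(x_{k+1}) - \nabla f(x_k)
$$

The approximation is required to satisfy the secant condition $H_{k+1} y_k = s_k$, and the BFGS update that achieves this is

$$
H_{k+1} = \left(I - \rho_k s_k y_k^{T}\right) H_k \left(I - \rho_k y_k s_k^{T}\right) + \rho_k s_k s_k^{T},
\qquad \rho_k = \frac{1}{y_k^{T} s_k}
$$

L-BFGS never stores $H_k$ explicitly. It keeps only the most recent $m$ pairs $(s_i, y_i)$ and reconstructs the product $H_k \nabla f(x_k)$ from them, so memory grows linearly in the number of weights rather than quadratically. In the code discussed here, $m$ appears to correspond to the numCorrections parameter.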
Code Implementation
The regularization method used with the L-BFGS algorithm is SquaredL2Updater.
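To illustrate what squared-L2 regularization contributes to the objective, here is a minimal pure-Scala sketch. It assumes the usual penalty of (regParam / 2) * ||w||^2 and works on plain arrays; the object and method names (L2RegSketch, regValue, regGradient, regularized) are hypothetical and are not part of MLlib's Updater API.

```scala
object L2RegSketch {
  // Penalty term: (regParam / 2) * sum of squared weights.
  def regValue(weights: Array[Double], regParam: Double): Double =
    0.5 * regParam * weights.map(w => w * w).sum

  // Gradient of the penalty term with respect to the weights: regParam * w.
  def regGradient(weights: Array[Double], regParam: Double): Array[Double] =
    weights.map(_ * regParam)

  // Adds the penalty and its gradient to an unregularized loss and gradient.
  def regularized(
      loss: Double,
      grad: Array[Double],
      weights: Array[Double],
      regParam: Double): (Double, Array[Double]) = {
    val rg = regGradient(weights, regParam)
    val totalGrad = Array.tabulate(grad.length)(i => grad(i) + rg(i))
    (loss + regValue(weights, regParam), totalGrad)
  }
}
```

Because the penalty is smooth, it can simply be folded into the value and gradient that the optimizer sees, which is why it combines naturally with L-BFGS.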
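The implementation relies on BreezeLBFGS (MLlib's import alias for breeze.optimize.LBFGS) from breeze, a member project of scalanlp. MLlib itself defines the DiffFunction that BreezeLBFGS requires.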
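To make the division of labour concrete (the caller supplies a function returning a value and a gradient, and the optimizer does the rest), below is a minimal pure-Scala sketch of L-BFGS. It is not Breeze's implementation: the names LbfgsSketch, twoLoop and minimize are hypothetical, vectors are plain arrays, and the line search is a simple backtracking Armijo search rather than the more careful search a production library would typically use. The objective has the same shape as a DiffFunction's calculate method: it maps a point to a (value, gradient) pair.

```scala
object LbfgsSketch {
  type Vec = Array[Double]

  def dot(a: Vec, b: Vec): Double = a.indices.map(i => a(i) * b(i)).sum

  // Returns y + alpha * x as a new array.
  def axpy(alpha: Double, x: Vec, y: Vec): Vec =
    Array.tabulate(y.length)(i => y(i) + alpha * x(i))

  // Two-loop recursion: approximates (inverse Hessian) * grad from the
  // stored (s, y) pairs, which are ordered oldest first.
  def twoLoop(grad: Vec, history: Vector[(Vec, Vec)]): Vec = {
    var q = grad.clone()
    // First loop runs from the newest pair to the oldest.
    val coeffs = history.reverse.map { case (s, y) =>
      val rho = 1.0 / dot(y, s)
      val alpha = rho * dot(s, q)
      q = axpy(-alpha, y, q)
      (alpha, rho)
    }
    // Initial scaling taken from the newest pair; identity if no history yet.
    val gamma = history.lastOption
      .map { case (s, y) => dot(s, y) / dot(y, y) }
      .getOrElse(1.0)
    var r = q.map(_ * gamma)
    // Second loop runs from the oldest pair to the newest.
    history.zip(coeffs.reverse).foreach { case ((s, y), (alpha, rho)) =>
      val beta = rho * dot(y, r)
      r = axpy(alpha - beta, s, r)
    }
    r
  }

  // Minimizes f, keeping at most m correction pairs; returns the final point
  // and the objective value recorded at every accepted iterate.
  def minimize(
      f: Vec => (Double, Vec),
      x0: Vec,
      m: Int,
      maxIter: Int,
      tol: Double): (Vec, Vector[Double]) = {
    var x = x0.clone()
    val first = f(x)
    var fx = first._1
    var g = first._2
    var history = Vector.empty[(Vec, Vec)]
    var losses = Vector(fx)
    var iter = 0
    while (iter < maxIter && math.sqrt(dot(g, g)) > tol) {
      val d = twoLoop(g, history).map(v => -v) // search direction
      val slope = dot(g, d)
      var t = 1.0
      var xNew = axpy(t, d, x)
      var res = f(xNew)
      // Backtracking until the Armijo sufficient-decrease condition holds.
      while (res._1 > fx + 1e-4 * t * slope && t > 1e-12) {
        t *= 0.5
        xNew = axpy(t, d, x)
        res = f(xNew)
      }
      val s = axpy(-1.0, x, xNew)    // step taken
      val y = axpy(-1.0, g, res._2)  // change in gradient
      // Keep the pair only if it preserves positive curvature.
      if (dot(s, y) > 1e-12) history = (history :+ ((s, y))).takeRight(m)
      x = xNew; fx = res._1; g = res._2
      losses = losses :+ fx
      iter += 1
    }
    (x, losses)
  }
}
```

The parameter m plays the role that numCorrections plays in the MLlib code: a larger value retains more curvature information at the cost of more memory and more work per iteration.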
The source of the runLBFGS function is as follows:
def runLBFGS(
    data: RDD[(Double, Vector)],
    gradient: Gradient,
    updater: Updater,
    numCorrections: Int,
    convergenceTol: Double,
    maxNumIterations: Int,
    regParam: Double,
    initialWeights: Vector): (Vector, Array[Double]) = {

  val lossHistory = new ArrayBuffer[Double](maxNumIterations)

  val numExamples = data.count()

  val costFun = new CostFun(data, gradient, updater, regParam, numExamples)

  val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)

  val states =
    lbfgs.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)

  /**
   * NOTE: lossSum and loss is computed using the weights from the previous iteration
   * and regVal is the regularization value computed in the previous iteration as well.
   */
  var state = states.next()
  while(states.hasNext) {
    lossHistory.append(state.value)
    state = states.next()
  }
  lossHistory.append(state.value)
  val weights = Vectors.fromBreeze(state.x)

  logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format(
    lossHistory.takeRight(10).mkString(", ")))

  (weights, lossHistory.toArray)
}
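The loop over states in runLBFGS is easy to misread, so here is a small pure-Scala sketch of the same draining pattern. The State case class and the collect method are hypothetical stand-ins for the optimizer's state type and are not Breeze or MLlib names. The point is that every state's value is recorded exactly once, with the last one appended after the loop, and that the final weights come from the last state.

```scala
import scala.collection.mutable.ArrayBuffer

object LossHistorySketch {
  // Hypothetical stand-in for an optimizer state: current point and objective value.
  final case class State(x: Double, value: Double)

  // Drains the iterator, recording each state's value and returning the last point.
  // Like the original loop, it assumes the iterator yields at least one state.
  def collect(states: Iterator[State]): (Double, Array[Double]) = {
    val lossHistory = ArrayBuffer.empty[Double]
    var state = states.next()
    while (states.hasNext) {
      lossHistory.append(state.value)
      state = states.next()
    }
    // The final state was fetched but not yet recorded inside the loop.
    lossHistory.append(state.value)
    (state.x, lossHistory.toArray)
  }

  // Toy state sequence: a few fixed-step descent iterations on (x - 3)^2.
  def demoStates(n: Int): Iterator[State] =
    Iterator.iterate(State(0.0, 9.0)) { s =>
      val nx = s.x - 0.5 * (s.x - 3.0)
      State(nx, (nx - 3.0) * (nx - 3.0))
    }.take(n)
}
```

Calling LossHistorySketch.collect(LossHistorySketch.demoStates(4)) returns the last point together with one loss per state, which mirrors the (weights, lossHistory.toArray) pair returned by runLBFGS.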
The CostFun class is the heart of the implementation:
private class CostFun(
    data: RDD[(Double, Vector)],
    gradient: Gradient,
    updater: Updater,
    regParam: Double,
    numExamples: Long) extends DiffFunction[BDV[Double]] {

  private var i = 0

  override def calculate(weights: BDV[Double]) = {
    // Have a local copy to avoid the serialization of CostFun object which is not serializable.
    val localData = data
    ...