All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.spark.examples.ml.BreezeLBFGS.scala Maven / Gradle / Ivy

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the BSD 3-Clause License (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/BSD-3-Clause
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */

package com.tencent.angel.spark.examples.ml

import scala.collection.mutable.ArrayBuffer

import breeze.linalg.DenseVector
import breeze.optimize.LBFGS
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD

import com.tencent.angel.spark.PSContext
import com.tencent.angel.spark.examples.util.{Logistic, PSExamples}
import com.tencent.angel.spark.models.vector.BreezePSVector

/**
 * There are two ways to update PSVectors from inside an RDD: RemotePSVector and
 * RDDFunction.psAggregate. This example runs Breeze's `optimize.LBFGS` three times on the same
 * generated training data: once with a plain Spark cost function (`Logistic.Cost`), and once
 * with each of the two spark-on-angel cost functions (`Logistic.PSCost` and
 * `Logistic.PSAggregateCost`). Every run prints its loss history, final weights and elapsed time.
 */
object BreezeLBFGS {
  import PSExamples._

  /** Convergence tolerance handed to Breeze's LBFGS when the caller does not supply one. */
  private val DefaultTol = 1e-5

  /**
   * Entry point.
   *
   * Parses `args` via `PSExamples.parseArgs`. Then, inside `runWithSparkContext`, it creates the
   * PS context and runs [[execute]] with the `PSExamples` settings `DIM`, `N`, `numSlices` and
   * `ITERATIONS`.
   */
  def main(args: Array[String]): Unit = {
    parseArgs(args)
    runWithSparkContext(this.getClass.getSimpleName) { sc =>
      PSContext.getOrCreate(sc)
      execute(DIM, N, numSlices, ITERATIONS)
    }
  }

  /**
   * Generates one training RDD with `Logistic.generateLRData` and runs the plain-Spark, PS and
   * PS-aggregate LBFGS variants on it in turn. After each run it prints the elapsed wall-clock
   * time in milliseconds.
   *
   * @param dim          length of the weight vector optimized by every run
   * @param sampleNum    sample count passed to `Logistic.generateLRData`
   * @param partitionNum partition count passed to `Logistic.generateLRData`
   * @param maxIter      maximum number of LBFGS iterations per run
   * @param m            LBFGS history size (number of correction pairs kept by Breeze)
   */
  def execute(dim: Int, sampleNum: Int, partitionNum: Int, maxIter: Int, m: Int = 10): Unit = {
    val trainData = Logistic.generateLRData(sampleNum, dim, partitionNum)

    val lbfgsMillis = elapsedMillis(runLBFGS(trainData, dim, m, maxIter))
    println(s"LBFGS time: $lbfgsMillis")

    val psMillis = elapsedMillis(runPsLBFGS(trainData, dim, m, maxIter))
    println(s"PS LBFGS time: $psMillis ")

    val psAggregateMillis = elapsedMillis(runPsAggregateLBFGS(trainData, dim, m, maxIter))
    println(s"PS aggregate LBFGS time: $psAggregateMillis ")
  }

  /**
   * Runs Breeze LBFGS with the plain Spark cost `Logistic.Cost`, starting from a zero
   * `DenseVector`, and prints the loss history and the final weights.
   *
   * @param tol convergence tolerance passed to Breeze's LBFGS (defaults to 1e-5)
   */
  def runLBFGS(
      trainData: RDD[(Vector, Double)],
      dim: Int,
      m: Int,
      maxIter: Int,
      tol: Double = DefaultTol): Unit = {
    val initWeight = new DenseVector[Double](dim)
    val lbfgs = new LBFGS[DenseVector[Double]](maxIter, m, tol)
    val states = lbfgs.iterations(Logistic.Cost(trainData), initWeight)

    val (lossHistory, finalWeight) = traceRun(states)(_.value, _.x)
    // Report a zero vector if the optimizer produced no state at all.
    val weight = finalWeight.getOrElse(new DenseVector[Double](dim))

    println(s"loss history: ${lossHistory.mkString(" ")}")
    println(s"weights: ${weight.toArray.mkString(" ")}")
  }

  /**
   * Runs Breeze LBFGS over PS-backed vectors with the cost `Logistic.PSCost`, and prints the
   * loss history and the final weights pulled from the PS.
   *
   * A model pool of dimension `dim` is created for the run. It is destroyed afterwards, even if
   * the optimization throws.
   *
   * @param tol convergence tolerance passed to Breeze's LBFGS (defaults to 1e-5)
   */
  def runPsLBFGS(
      trainData: RDD[(Vector, Double)],
      dim: Int,
      m: Int,
      maxIter: Int,
      tol: Double = DefaultTol): Unit = {
    val psContext = PSContext.getOrCreate()
    // NOTE(review): the second argument (20) is presumably the pool capacity — confirm.
    val pool = psContext.createModelPool(dim, 20)
    try {
      val initWeight = pool.createZero().mkBreeze()
      val lbfgs = new LBFGS[BreezePSVector](maxIter, m, tol)
      val states = lbfgs.iterations(Logistic.PSCost(trainData), initWeight)

      val (lossHistory, finalWeight) = traceRun(states)(_.value, _.x)
      // Fall back to the starting point instead of dereferencing null when no state was produced.
      val weight: BreezePSVector = finalWeight.getOrElse(initWeight)

      println(s"loss history: ${lossHistory.mkString(" ")}")
      println(s"weights: ${weight.toRemote.pull().mkString(" ")}")
    } finally {
      psContext.destroyVectorPool(pool)
    }
  }

  /**
   * Runs Breeze LBFGS over PS-backed vectors with the cost `Logistic.PSAggregateCost`
   * (presumably the `psAggregate`-based variant; confirm in `Logistic`). Prints the loss history
   * and the final weights pulled from the PS.
   *
   * A model pool of dimension `dim` is created for the run. It is destroyed afterwards, even if
   * the optimization throws.
   *
   * @param tol convergence tolerance passed to Breeze's LBFGS (defaults to 1e-5)
   */
  def runPsAggregateLBFGS(
      trainData: RDD[(Vector, Double)],
      dim: Int,
      m: Int,
      maxIter: Int,
      tol: Double = DefaultTol): Unit = {
    val psContext = PSContext.getOrCreate()
    // NOTE(review): the second argument (20) is presumably the pool capacity — confirm.
    val pool = psContext.createModelPool(dim, 20)
    try {
      val initWeight = pool.createZero().mkBreeze()
      val lbfgs = new LBFGS[BreezePSVector](maxIter, m, tol)
      val states = lbfgs.iterations(Logistic.PSAggregateCost(trainData), initWeight)

      val (lossHistory, finalWeight) = traceRun(states)(_.value, _.x)
      // Fall back to the starting point instead of dereferencing null when no state was produced.
      val weight: BreezePSVector = finalWeight.getOrElse(initWeight)

      println(s"loss history: ${lossHistory.mkString(" ")}")
      println(s"weights: ${weight.toRemote.pull().mkString(" ")}")
    } finally {
      psContext.destroyVectorPool(pool)
    }
  }

  /**
   * Drains an optimizer state iterator.
   *
   * Returns the loss of every state in order, plus the weights of the last state (`None` if the
   * iterator was empty). Only the last state is retained, so earlier states are not kept alive.
   */
  private def traceRun[S, W](states: Iterator[S])(
      loss: S => Double,
      weight: S => W): (List[Double], Option[W]) = {
    val lossHistory = ArrayBuffer.empty[Double]
    var lastState = Option.empty[S]
    states.foreach { state =>
      lossHistory += loss(state)
      lastState = Some(state)
    }
    (lossHistory.toList, lastState.map(weight))
  }

  /**
   * Evaluates `body` and returns its wall-clock duration in milliseconds.
   *
   * Uses `System.nanoTime`, which is monotonic, unlike `currentTimeMillis`.
   */
  private def elapsedMillis(body: => Unit): Long = {
    val start = System.nanoTime()
    body
    (System.nanoTime() - start) / 1000000L
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy