
com.tencent.angel.spark.examples.ml.BreezeSGD.scala Maven / Gradle / Ivy
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the BSD 3-Clause License (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSD-3-Clause
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.spark.examples.ml
import scala.collection.mutable.ArrayBuffer
import breeze.linalg.DenseVector
import breeze.optimize.StochasticGradientDescent
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD
import com.tencent.angel.spark.PSContext
import com.tencent.angel.spark.examples.util.{Logistic, PSExamples}
import com.tencent.angel.spark.models.vector.BreezePSVector
/**
* There are two ways to update PSVectors inside an RDD: RemotePSVector and RDDFunction.psAggregate.
* These are samples of running Breeze's optimize SGD on plain Spark, and then with each of the two
* spark-on-angel approaches.
*/
object BreezeSGD {
  import PSExamples._

  /**
   * Entry point: parses the example arguments, creates a PSContext on the example SparkContext
   * and runs all three SGD variants via [[execute]].
   *
   * The PSContext is stopped in a `finally` so it is released even when a run fails.
   */
  def main(args: Array[String]): Unit = {
    parseArgs(args)
    runWithSparkContext(this.getClass.getSimpleName) { sc =>
      PSContext.getOrCreate(sc)
      try {
        execute(DIM, N, numSlices, ITERATIONS)
      } finally {
        PSContext.stop()
      }
    }
  }

  /**
   * Generates one logistic-regression data set and runs plain Breeze SGD, PS SGD and
   * PS-aggregate SGD on it, printing the wall-clock time (ms) of each run.
   *
   * @param dim          feature dimension
   * @param sampleNum    number of generated samples
   * @param partitionNum number of RDD partitions for the generated data
   * @param maxIter      maximum number of SGD iterations
   * @param stepSize     SGD step size
   */
  def execute(
      dim: Int,
      sampleNum: Int,
      partitionNum: Int,
      maxIter: Int,
      stepSize: Double = 0.1): Unit = {
    val trainData = Logistic.generateLRData(sampleNum, dim, partitionNum)

    timed(ms => s"SGD time: $ms") {
      runSGD(trainData, dim, stepSize, maxIter)
    }
    timed(ms => s"PS SGD time: $ms ") {
      runPsSGD(trainData, dim, stepSize, maxIter)
    }
    timed(ms => s"PS aggregate SGD time: $ms ") {
      runPsAggregateSGD(trainData, dim, stepSize, maxIter)
    }
  }

  /**
   * Runs Breeze SGD with the weights held in a local DenseVector, starting from zero.
   * Prints the loss of every state and the final weights.
   */
  def runSGD(trainData: RDD[(Vector, Double)], dim: Int, stepSize: Double, maxIter: Int): Unit = {
    val initWeight = new DenseVector[Double](dim)
    val sgd = StochasticGradientDescent[DenseVector[Double]](stepSize, maxIter)
    val states = sgd.iterations(Logistic.Cost(trainData), initWeight)
    val (lossHistory, lastWeight) = collectTrace(states.map(s => (s.value, s.x)))
    // If the optimizer produced no state at all, report the (zero) starting point.
    val weight = lastWeight.getOrElse(initWeight)
    println(s"loss history: ${lossHistory.mkString(" ")}")
    println(s"weights: ${weight.toArray.mkString(" ")}")
  }

  /** Runs Breeze SGD with PS-backed weights, using `Logistic.PSCost` as the objective. */
  def runPsSGD(trainData: RDD[(Vector, Double)], dim: Int, stepSize: Double, maxIter: Int): Unit =
    runOnPs(Logistic.PSCost(trainData), dim, stepSize, maxIter)

  /** Runs Breeze SGD with PS-backed weights, using `Logistic.PSAggregateCost` as the objective. */
  def runPsAggregateSGD(
      trainData: RDD[(Vector, Double)],
      dim: Int,
      stepSize: Double,
      maxIter: Int): Unit =
    runOnPs(Logistic.PSAggregateCost(trainData), dim, stepSize, maxIter)

  /**
   * Shared driver for the PS variants. Creates a model pool, optimizes from a zero PS vector
   * and prints the loss history and the final weights pulled from the PS.
   *
   * The pool is destroyed in a `finally`, so it is released even if optimization fails.
   *
   * @param cost objective, passed by-name so it is built after the pool exists (same order as
   *             the original per-variant methods)
   */
  private def runOnPs(
      cost: => breeze.optimize.StochasticDiffFunction[BreezePSVector],
      dim: Int,
      stepSize: Double,
      maxIter: Int): Unit = {
    val context = PSContext.getOrCreate()
    // NOTE(review): meaning of the second argument (10) is defined by createModelPool — confirm.
    val pool = context.createModelPool(dim, 10)
    try {
      val initWeightPS = pool.createZero().mkBreeze()
      val sgd = StochasticGradientDescent[BreezePSVector](stepSize, maxIter)
      val states = sgd.iterations(cost, initWeightPS)
      val (lossHistory, lastWeight) = collectTrace(states.map(s => (s.value, s.x)))
      // Fall back to the starting vector instead of dereferencing null when there are no states.
      val weight = lastWeight.getOrElse(initWeightPS)
      println(s"loss history: ${lossHistory.mkString(" ")}")
      println(s"weights: ${weight.toRemote.pull().mkString(" ")}")
    } finally {
      context.destroyVectorPool(pool)
    }
  }

  /**
   * Drains an optimizer trace of `(loss, weights)` pairs.
   *
   * @return every loss in iteration order, and the weights of the last pair (None if empty)
   */
  private def collectTrace[T](trace: Iterator[(Double, T)]): (Array[Double], Option[T]) = {
    val losses = ArrayBuffer.empty[Double]
    var last: Option[T] = None
    trace.foreach { case (loss, x) =>
      losses += loss
      last = Some(x)
    }
    (losses.toArray, last)
  }

  /** Runs `body`, then prints `report(elapsedMillis)`. */
  private def timed(report: Long => String)(body: => Unit): Unit = {
    val start = System.currentTimeMillis()
    body
    println(report(System.currentTimeMillis() - start))
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy