
com.tencent.angel.spark.examples.ml.BreezeOWLQN.scala Maven / Gradle / Ivy
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the BSD 3-Clause License (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSD-3-Clause
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.spark.examples.ml
import scala.collection.mutable.ArrayBuffer
import breeze.linalg.DenseVector
import breeze.optimize.{OWLQN => BrzOWLQN}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD
import com.tencent.angel.spark.PSContext
import com.tencent.angel.spark.examples.util.{Logistic, PSExamples}
import com.tencent.angel.spark.ml.optim.OWLQN
import com.tencent.angel.spark.models.vector.BreezePSVector
/**
* There are two ways to update PSVectors from within an RDD: RemotePSVector and
* RDDFunction.psAggregate. This example runs Breeze's OWLQN (breeze.optimize) on plain Spark,
* and then runs OWLQN on spark-on-angel once with each of those two approaches.
*/
object BreezeOWLQN {
  import PSExamples._

  /** Convergence tolerance passed to every optimizer run in this example. */
  private val Tolerance = 1e-5

  /**
   * Entry point: parses the example arguments, creates the PS context on top of the
   * SparkContext, and runs all OWLQN variants.
   */
  def main(args: Array[String]): Unit = {
    parseArgs(args)
    runWithSparkContext(this.getClass.getSimpleName) { sc =>
      PSContext.getOrCreate(sc)
      execute(DIM, N, numSlices, ITERATIONS)
    }
  }

  /**
   * Generates logistic-regression training data once, then trains on it with Breeze OWLQN,
   * PS OWLQN and PS aggregate OWLQN in turn, printing the elapsed time (ms) of each run.
   *
   * @param dim          feature dimension, forwarded to data generation and the optimizers
   * @param sampleNum    sample count forwarded to `Logistic.generateLRData`
   * @param partitionNum partition count forwarded to `Logistic.generateLRData`
   * @param maxIter      maximum iteration count given to each optimizer
   * @param m            memory parameter `m` given to each optimizer's constructor
   */
  def execute(dim: Int, sampleNum: Int, partitionNum: Int, maxIter: Int, m: Int = 10): Unit = {
    val trainData = Logistic.generateLRData(sampleNum, dim, partitionNum)

    // run Breeze OWLQN on plain Spark
    val owlqnTime = elapsedMillis(runOWLQN(trainData, dim, m, maxIter))
    println(s"OWLQN time: $owlqnTime")

    // run PS OWLQN
    val psOwlqnTime = elapsedMillis(runPsOWLQN(trainData, dim, m, maxIter))
    println(s"PS OWLQN time: $psOwlqnTime ")

    // run PS aggregate OWLQN
    val psAggregateTime = elapsedMillis(runPsAggregateOWLQN(trainData, dim, m, maxIter))
    println(s"PS aggregate OWLQN time: $psAggregateTime ")
  }

  /**
   * Trains with Breeze's OWLQN (zero L1 regularization) using `Logistic.Cost`, then prints the
   * loss of every iteration and the final weights.
   */
  def runOWLQN(trainData: RDD[(Vector, Double)], dim: Int, m: Int, maxIter: Int): Unit = {
    val initWeight = new DenseVector[Double](dim)
    val l1reg = 0.0
    val owlqn = new BrzOWLQN[Int, DenseVector[Double]](maxIter, m, l1reg, Tolerance)
    val states = owlqn.iterations(Logistic.Cost(trainData), initWeight)
    val (lossHistory, lastState) = drainStates(states)(_.value)
    // If the optimizer yields no state at all, report an all-zero weight vector.
    val weight = lastState.map(_.x).getOrElse(new DenseVector[Double](dim))
    println(s"loss history: ${lossHistory.mkString(" ")}")
    println(s"weights: ${weight.toArray.mkString(" ")}")
  }

  /**
   * Trains with the spark-on-angel OWLQN using `Logistic.PSCost`, with PS-side zero initial
   * weights and a zero L1 vector, then prints the loss history and the final weights pulled
   * back from the PS.
   */
  def runPsOWLQN(trainData: RDD[(Vector, Double)], dim: Int, m: Int, maxIter: Int): Unit = {
    // NOTE(review): meaning of the second createModelPool argument (20) is not visible here.
    val pool = PSContext.getOrCreate().createModelPool(dim, 20)
    val initWeightPS = pool.createZero().mkBreeze()
    val l1reg = pool.createZero().mkBreeze()
    val owlqn = new OWLQN(maxIter, m, l1reg, Tolerance)
    val states = owlqn.iterations(Logistic.PSCost(trainData), initWeightPS)
    val (lossHistory, lastState) = drainStates(states)(_.value)
    // An empty iterator previously caused a NullPointerException here.
    val weights = lastState.map(_.x.toRemote.pull().mkString(" ")).getOrElse("")
    println(s"loss history: ${lossHistory.mkString(" ")}")
    println(s"weights: $weights")
  }

  /**
   * Same as [[runPsOWLQN]] but evaluates the objective with `Logistic.PSAggregateCost`.
   */
  def runPsAggregateOWLQN(
      trainData: RDD[(Vector, Double)], dim: Int, m: Int, maxIter: Int): Unit = {
    // NOTE(review): meaning of the second createModelPool argument (20) is not visible here.
    val pool = PSContext.getOrCreate().createModelPool(dim, 20)
    val initWeightPS = pool.createZero().mkBreeze()
    val l1reg = pool.createZero().mkBreeze()
    val owlqn = new OWLQN(maxIter, m, l1reg, Tolerance)
    val states = owlqn.iterations(Logistic.PSAggregateCost(trainData), initWeightPS)
    val (lossHistory, lastState) = drainStates(states)(_.value)
    // An empty iterator previously caused a NullPointerException here.
    val weights = lastState.map(_.x.toRemote.pull().mkString(" ")).getOrElse("")
    println(s"loss history: ${lossHistory.mkString(" ")}")
    println(s"weights: $weights")
  }

  /**
   * Consumes an optimizer's state iterator, recording the objective value of every state.
   *
   * @param states iterator returned by an optimizer's `iterations` method
   * @param loss   extracts the objective value from a state
   * @return the objective values in iteration order, and the last state seen
   *         (`None` when the iterator was empty)
   */
  private def drainStates[S](states: Iterator[S])(
      loss: S => Double): (Array[Double], Option[S]) = {
    val losses = new ArrayBuffer[Double]()
    var last: Option[S] = None
    states.foreach { state =>
      losses += loss(state)
      last = Some(state)
    }
    (losses.toArray, last)
  }

  /**
   * Evaluates `body` once and returns how long it took, in milliseconds.
   * Uses the monotonic `System.nanoTime` clock so wall-clock adjustments cannot skew results.
   */
  private def elapsedMillis(body: => Unit): Long = {
    val start = System.nanoTime()
    body
    (System.nanoTime() - start) / 1000000L
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy