com.tencent.angel.sona.graph.embedding.line2.LINEModel.scala
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.sona.graph.embedding.line2
import java.text.SimpleDateFormat
import java.util.Date
import com.tencent.angel.ml.math2.utils.RowType
import com.tencent.angel.ml.matrix.MatrixContext
import com.tencent.angel.model.output.format.SnapshotFormat
import com.tencent.angel.model.{MatrixLoadContext, MatrixSaveContext, ModelLoadContext, ModelSaveContext}
import com.tencent.angel.sona.context.PSContext
import com.tencent.angel.sona.models.PSMatrix
import it.unimi.dsi.fastutil.ints.{Int2IntOpenHashMap, Int2ObjectOpenHashMap}
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import com.tencent.angel.sona.graph.embedding.NEModel.NEDataSet
import com.tencent.angel.sona.graph.embedding.line.LINEModel.{LINEDataSet, buildDataBatches}
import com.tencent.angel.sona.graph.embedding.{FastSigmoid, Param}
import com.tencent.angel.sona.psf.embedding.NEModelRandomize.RandomizeUpdateParam
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import scala.util.Random
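/**
 * LINE (Large-scale Information Network Embedding, Tang et al., WWW 2015)
 * trained with Spark on Angel. All node embeddings are held in a single PS
 * matrix of LINENode values; every mini-batch pulls the rows it touches,
 * computes negative-sampling gradients locally, and pushes averaged updates
 * back (see sgdForBatch below).
 *
 * @param numNode        number of nodes in the graph (ids are assumed dense in [0, numNode))
 * @param dimension      embedding dimension
 * @param numPart        number of PS partitions of the embedding matrix
 * @param numNodesPerRow nodes stored per matrix row (-1 = default; not referenced in this file)
 * @param order          proximity order of LINE, 1 or 2
 * @param seed           seed for negative sampling and initialization
 */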
class LINEModel(numNode: Int,
                dimension: Int,
                numPart: Int,
                numNodesPerRow: Int = -1,
                order: Int = 1,
                seed: Int = Random.nextInt) extends Serializable {

  val matrixName = "embedding"

  // Create one ps matrix to hold the input vectors and the output vectors for all nodes
  val mc: MatrixContext = new MatrixContext(matrixName, 1, numNode)
  mc.setMaxRowNumInBlock(1)
  mc.setMaxColNumInBlock(numNode / numPart)
  mc.setRowType(RowType.T_ANY_INTKEY_DENSE)
  mc.setValueType(classOf[LINENode])
  mc.setInitFunc(new LINEInitFunc(order, dimension))

  val psMatrix: PSMatrix = PSMatrix.matrix(mc)
  val matrixId: Int = psMatrix.id

  // Explicitly (re)randomize the embeddings
  def randomInitialize(seed: Int): Unit = {
    val beforeRandomize = System.currentTimeMillis()
    psMatrix.psfUpdate(new LINEModelRandomize(
      new RandomizeUpdateParam(matrixId, dimension / numPart, dimension, order, seed))).get()
    logTime(s"Model successfully randomized, cost ${(System.currentTimeMillis() - beforeRandomize) / 1000.0}s")
  }

  private val rand = new Random(seed)

  def this(param: Param) {
    this(param.maxIndex, param.embeddingDim, param.numPSPart, param.nodesNumPerRow, param.order, param.seed)
  }
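
  /**
   * Mini-batch SGD training loop: one pass over the batch iterator per epoch,
   * with periodic PS checkpoints (binary snapshot) and model saves (text).
   */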
  def train(trainSet: RDD[(Int, Int)], params: Param, path: String): this.type = {
    // Get mini-batch data set
    val trainBatches = buildDataBatches(trainSet, params.batchSize)
    val numEpoch = params.numEpoch
    val learningRate = params.learningRate
    val checkpointInterval = params.checkpointInterval
    val saveModelInterval = params.saveModelInterval
    val negative = params.negSample

    // Before training, checkpoint the model
    var startTs = System.currentTimeMillis()
    psMatrix.checkpoint(0)
    logTime(s"Write checkpoint use time=${System.currentTimeMillis() - startTs}")

    for (epoch <- 1 to numEpoch) {
      val alpha = learningRate
      val data = trainBatches.next()
      val numPartitions = data.getNumPartitions
      val middle = data.mapPartitionsWithIndex((partitionId, iterator) =>
        sgdForPartition(partitionId, iterator, numPartitions, negative, alpha),
        preservesPartitioning = true
      ).collect()
      // Each element of middle is (loss sum, pair count, stage times) for one partition
      val loss = middle.map(_._1).sum / middle.map(_._2).sum.toFloat
      val array = new Array[Long](6)
      middle.foreach(f => f._3.zipWithIndex.foreach(t => array(t._2) += t._1))
      logTime(s"epoch=$epoch " +
        f"loss=$loss%2.4f " +
        s"sampleTime=${array(0)} getEmbeddingTime=${array(1)} " +
        s"dotTime=${array(2)} gradientTime=${array(3)} calUpdateTime=${array(4)} pushTime=${array(5)} " +
        s"total=${middle.map(_._2).sum.toFloat} lossSum=${middle.map(_._1).sum} ")

      if (epoch % checkpointInterval == 0 && epoch < numEpoch) {
        logTime(s"Epoch=$epoch, checkpoint the model")
        startTs = System.currentTimeMillis()
        psMatrix.checkpoint(epoch)
        logTime(s"checkpoint use time=${System.currentTimeMillis() - startTs}")
      }

      if (epoch % saveModelInterval == 0 && epoch < numEpoch) {
        logTime(s"Epoch=$epoch, save the model")
        startTs = System.currentTimeMillis()
        save(path, epoch)
        logTime(s"save use time=${System.currentTimeMillis() - startTs}")
      }
    }
    this
  }
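
  /**
   * Per-partition worker loop, executed on the Spark executors: attaches to
   * the PS client, then runs sgdForBatch on every mini-batch in the partition.
   */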
  def sgdForPartition(partitionId: Int,
                      iterator: Iterator[NEDataSet],
                      numPartitions: Int,
                      negative: Int,
                      alpha: Float): Iterator[(Float, Long, Array[Long])] = {
    // Make sure the PS client is initialized on this executor before issuing psf calls
    PSContext.instance()
    iterator.zipWithIndex.map { case (batch, index) =>
      sgdForBatch(partitionId, batch.asInstanceOf[LINEDataSet], negative, index, alpha, rand.nextInt())
    }
  }
  def sgdForBatch(partitionId: Int,
                  batch: LINEDataSet,
                  negative: Int,
                  batchId: Int,
                  alpha: Float,
                  seed: Int): (Float, Long, Array[Long]) = {
    var start = System.currentTimeMillis()

    // Draw negative samples for every src node
    val srcNodes = batch.src
    val destNodes = batch.dst
    val negativeSamples = negativeSample(srcNodes, negative, seed)
    val sampleTime = System.currentTimeMillis() - start

    // Get node embeddings from PS
    start = System.currentTimeMillis()
    val getResult = getEmbedding(srcNodes, destNodes, negativeSamples, negative)
    val srcFeats: Int2ObjectOpenHashMap[Array[Float]] = getResult._1
    val targetFeats: Int2ObjectOpenHashMap[Array[Float]] = getResult._2
    val getEmbeddingTime = System.currentTimeMillis() - start

    // Get dot values
    start = System.currentTimeMillis()
    val dots = dot(srcNodes, destNodes, negativeSamples, srcFeats, targetFeats, negative)
    val dotTime = System.currentTimeMillis() - start

    // Gradient
    start = System.currentTimeMillis()
    val loss = doGrad(dots, negative, alpha)
    val gradientTime = System.currentTimeMillis() - start

    // Compute the averaged updates locally
    start = System.currentTimeMillis()
    val (inputUpdates, outputUpdates) = adjust(srcNodes, destNodes, negativeSamples, srcFeats, targetFeats, negative, dots)
    val calUpdateTime = System.currentTimeMillis() - start

    // Push the updates back to PS; wait for completion so pushTime is meaningful
    start = System.currentTimeMillis()
    psMatrix.psfUpdate(new LINEAdjust(new LINEAdjustParam(matrixId, inputUpdates, outputUpdates, order))).get()
    val pushTime = System.currentTimeMillis() - start

    if (batchId % 100 == 0) {
      logTime(s"batchId=$batchId loss=$loss sampleTime=$sampleTime getEmbeddingTime=$getEmbeddingTime " +
        s"dotTime=$dotTime gradientTime=$gradientTime calUpdateTime=$calUpdateTime " +
        s"pushTime=$pushTime")
    }
    (loss, dots.length.toLong, Array(sampleTime, getEmbeddingTime, dotTime, gradientTime, calUpdateTime, pushTime))
  }
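
  /**
   * Pulls the embeddings touched by this batch from the PS matrix: input
   * vectors for the src nodes and, depending on order, input or output
   * (context) vectors for the dst and negative-sample nodes.
   */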
  def getEmbedding(srcNodes: Array[Int], destNodes: Array[Int], negativeSamples: Array[Array[Int]], negative: Int) = {
    psMatrix.psfGet(new LINEGetEmbedding(new LINEGetEmbeddingParam(matrixId, srcNodes, destNodes,
      negativeSamples, order, negative))).asInstanceOf[LINEGetEmbeddingResult].getResult
  }
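
  /**
   * Computes the dot product for every (src, dst) and (src, negative) pair.
   * The result is laid out in blocks of (1 + negative) values per src node,
   * positive pair first; e.g. with negative = 2:
   *   [ s0·d0, s0·n00, s0·n01, s1·d1, s1·n10, s1·n11, ... ]
   * doGrad and adjust rely on this layout.
   */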
  def dot(srcNodes: Array[Int], destNodes: Array[Int], negativeSamples: Array[Array[Int]],
          srcFeats: Int2ObjectOpenHashMap[Array[Float]], targetFeats: Int2ObjectOpenHashMap[Array[Float]],
          negative: Int): Array[Float] = {
    val dots: Array[Float] = new Array[Float]((1 + negative) * srcNodes.length)
    // For first-order LINE the context vectors are the input vectors themselves;
    // for second-order LINE they come from the separate output (context) table
    val contextFeats = if (order == 1) srcFeats else targetFeats
    var dotIndex = 0
    for (i <- srcNodes.indices) {
      val srcVec = srcFeats.get(srcNodes(i))
      // Get dot value for (src, dst)
      dots(dotIndex) = arraysDot(srcVec, contextFeats.get(destNodes(i)))
      dotIndex += 1
      // Get dot value for (src, negative sample)
      for (j <- 0 until negative) {
        dots(dotIndex) = arraysDot(srcVec, contextFeats.get(negativeSamples(i)(j)))
        dotIndex += 1
      }
    }
    dots
  }
  /** Dot product of two dense float arrays. */
  def arraysDot(x: Array[Float], y: Array[Float]): Float = {
    var dotValue = 0.0f
    x.indices.foreach(i => dotValue += x(i) * y(i))
    dotValue
  }

  /** y += a * x */
  def axpy(y: Array[Float], x: Array[Float], a: Float): Unit = {
    x.indices.foreach(i => y(i) += a * x(i))
  }

  /** x /= f, element-wise. */
  def div(x: Array[Float], f: Float): Unit = {
    x.indices.foreach(i => x(i) = x(i) / f)
  }
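
  /**
   * Turns the gradient coefficients produced by doGrad into per-node update
   * vectors. A node touched by several pairs in the batch accumulates all of
   * its gradients in one array (see merge), and each array is divided by its
   * merge count afterwards, so what is pushed to the PS is the averaged update.
   */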
  def adjust(srcNodes: Array[Int], destNodes: Array[Int], negativeSamples: Array[Array[Int]],
             srcFeats: Int2ObjectOpenHashMap[Array[Float]], targetFeats: Int2ObjectOpenHashMap[Array[Float]],
             negative: Int, dots: Array[Float]) = {
    // For first-order LINE both endpoints of a pair live in the input table, so
    // context-side updates are merged into the same maps as src-side updates;
    // for second-order LINE the output (context) vectors get their own maps
    val contextFeats = if (order == 1) srcFeats else targetFeats
    val inputUpdateCounter = new Int2IntOpenHashMap(srcFeats.size())
    val inputUpdates = new Int2ObjectOpenHashMap[Array[Float]](srcFeats.size())
    val (contextUpdateCounter, contextUpdates) =
      if (order == 1) (inputUpdateCounter, inputUpdates)
      else (new Int2IntOpenHashMap(targetFeats.size()), new Int2ObjectOpenHashMap[Array[Float]](targetFeats.size()))

    var dotIndex = 0
    for (i <- srcNodes.indices) {
      // Src node grad accumulator
      val neule = new Array[Float](dimension)
      val srcEmbedding = srcFeats.get(srcNodes(i))

      // Accumulate dst node embedding to neule; use src node embedding to
      // update dst node embedding
      var g = dots(dotIndex)
      axpy(neule, contextFeats.get(destNodes(i)), g)
      merge(contextUpdateCounter, contextUpdates, destNodes(i), g, srcEmbedding)
      dotIndex += 1

      // Same pattern for every negative sample node
      for (j <- 0 until negative) {
        g = dots(dotIndex)
        axpy(neule, contextFeats.get(negativeSamples(i)(j)), g)
        merge(contextUpdateCounter, contextUpdates, negativeSamples(i)(j), g, srcEmbedding)
        dotIndex += 1
      }

      // Use accumulation to update src node embedding, grad = 1
      merge(inputUpdateCounter, inputUpdates, srcNodes(i), 1.0f, neule)
    }

    // Average every update vector by the number of gradients merged into it
    divAll(inputUpdateCounter, inputUpdates)
    if (order != 1) divAll(contextUpdateCounter, contextUpdates)
    (inputUpdates, if (order == 1) null else contextUpdates)
  }

  private def divAll(counter: Int2IntOpenHashMap, updates: Int2ObjectOpenHashMap[Array[Float]]): Unit = {
    val iter = counter.int2IntEntrySet().fastIterator()
    while (iter.hasNext) {
      val entry = iter.next()
      div(updates.get(entry.getIntKey), entry.getIntValue.toFloat)
    }
  }
  def merge(counter: Int2IntOpenHashMap, updates: Int2ObjectOpenHashMap[Array[Float]],
            nodeId: Int, g: Float, update: Array[Float]): Unit = {
    var grads: Array[Float] = updates.get(nodeId)
    if (grads == null) {
      grads = new Array[Float](dimension)
      updates.put(nodeId, grads)
      counter.put(nodeId, 0)
    }
    axpy(grads, update, g)
    counter.addTo(nodeId, 1)
  }
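
  /**
   * Checkpointing writes a binary snapshot of the embedding matrix (for
   * failure recovery) via SnapshotFormat; save below writes the final
   * embeddings as text via TextLINEModelOutputFormat.
   */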
  def checkpoint(checkpointId: Int): Unit = {
    val saveContext = new ModelSaveContext()
    saveContext.addMatrix(new MatrixSaveContext(matrixName, classOf[SnapshotFormat].getTypeName))
    PSContext.instance().checkpoint(checkpointId, saveContext)
  }

  def save(modelPathRoot: String, epoch: Int): Unit = {
    save(new Path(modelPathRoot, s"CP_$epoch").toString)
  }

  def save(modelPath: String): Unit = {
    logTime(s"saving model to $modelPath")
    val ss = SparkSession.builder().getOrCreate()
    deleteIfExists(modelPath, ss)
    val saveContext = new ModelSaveContext(modelPath)
    saveContext.addMatrix(new MatrixSaveContext(matrixName, classOf[TextLINEModelOutputFormat].getTypeName))
    PSContext.instance().save(saveContext)
  }

  def destroy(): Unit = {
    psMatrix.destroy()
  }

  def load(modelPath: String): Unit = {
    val startTime = System.currentTimeMillis()
    logTime(s"load model from $modelPath")
    val loadContext = new ModelLoadContext(modelPath)
    loadContext.addMatrix(new MatrixLoadContext(matrixName))
    PSContext.getOrCreate(SparkContext.getOrCreate()).load(loadContext)
    logTime(s"model load time=${System.currentTimeMillis() - startTime} ms")
  }
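
  /**
   * Draws sampleNum negative nodes per src node, uniformly over all node ids.
   * Only the src node itself is rejected, so a draw may occasionally coincide
   * with the true dst node, which uniform negative-sampling schemes typically
   * tolerate when numNode is large.
   */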
  def negativeSample(nodeIds: Array[Int], sampleNum: Int, seed: Int): Array[Array[Int]] = {
    val rand = new Random(seed)
    val sampleNodes = new Array[Array[Int]](nodeIds.length)
    var nodeIndex: Int = 0
    for (nodeId <- nodeIds) {
      var sampleIndex: Int = 0
      sampleNodes(nodeIndex) = new Array[Int](sampleNum)
      while (sampleIndex < sampleNum) {
        // Rejection sampling: redraw whenever the sample hits the src node itself
        val target = rand.nextInt(numNode)
        if (target != nodeId) {
          sampleNodes(nodeIndex)(sampleIndex) = target
          sampleIndex += 1
        }
      }
      nodeIndex += 1
    }
    sampleNodes
  }
  private def getAvailableExecutorNum(ss: SparkSession): Int = {
    math.max(ss.sparkContext.statusTracker.getExecutorInfos.length - 1, 1)
  }

  private def deleteIfExists(modelPath: String, ss: SparkSession): Unit = {
    val path = new Path(modelPath)
    val fs = path.getFileSystem(ss.sparkContext.hadoopConfiguration)
    if (fs.exists(path)) {
      fs.delete(path, true)
    }
  }

  def logTime(msg: String): Unit = {
    val time = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date)
    println(s"[$time] $msg")
  }
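
  /**
   * Converts the dot products into gradient coefficients in place and returns
   * the batch loss. For sigmoid σ, maximizing log σ(x) for positive pairs and
   * log(1 - σ(x)) for negative samples gives gradients (1 - σ(x)) and -σ(x)
   * respectively, scaled below by the learning rate alpha.
   */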
  def doGrad(dots: Array[Float], negative: Int, alpha: Float): Float = {
    var loss = 0.0f
    for (i <- dots.indices) {
      val prob = FastSigmoid.sigmoid(dots(i))
      if (i % (negative + 1) == 0) {
        // Positive pair: every block of (1 + negative) dots starts with one
        dots(i) = alpha * (1 - prob)
        loss -= FastSigmoid.log(prob)
      } else {
        // Negative sample
        dots(i) = -alpha * prob
        loss -= FastSigmoid.log(1 - prob)
      }
    }
    loss
  }
}
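
// ---------------------------------------------------------------------------
// Minimal usage sketch (illustrative, not part of the library): assumes an
// edge RDD whose node ids are already remapped to [0, param.maxIndex), a
// started PSContext, and that the Param fields read above are settable.
//
//   val param = new Param()            // configure embeddingDim, order,
//                                      // negSample, numEpoch, batchSize, ...
//   val edges: RDD[(Int, Int)] = ???   // load + remap edges (user-supplied)
//   val model = new LINEModel(param)
//   model.train(edges, param, "hdfs://.../line-model")
//   model.save("hdfs://.../line-model/final")
//   model.destroy()
// ---------------------------------------------------------------------------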