// com.tencent.angel.sona.graph.embedding.line2.LINEGetEmbedding.scala Maven / Gradle / Ivy
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.sona.graph.embedding.line2
import java.util
import com.tencent.angel.PartitionKey
import com.tencent.angel.graph.data.NodeUtils
import com.tencent.angel.ml.matrix.psf.get.base._
import com.tencent.angel.ps.storage.partition.RowBasedPartition
import com.tencent.angel.ps.storage.vector.ServerIntAnyRow
import com.tencent.angel.psagent.PSAgentContext
import com.tencent.angel.psagent.matrix.oplog.cache.RowUpdateSplitUtils
import io.netty.buffer.ByteBuf
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap
import scala.collection.JavaConversions._
/**
  * Get function that reads LINE embedding vectors for a set of source nodes and
  * target nodes (destination nodes plus negative samples).
  *
  * @param param the request parameter; `null` when created through the no-arg constructor
  */
class LINEGetEmbedding(param: LINEGetEmbeddingParam) extends GetFunc(param) {

  def this() = this(null)

  /**
    * Partition get. This function is called on PS.
    *
    * Source nodes are always answered with their input feats. For `order == 1`
    * target nodes are answered with input feats as well and are stored in the
    * source map; otherwise target nodes are answered with output feats in a
    * separate target map.
    *
    * @param partParam the partition parameter
    * @return the partition result
    */
  override def partitionGet(partParam: PartitionGetParam): PartitionGetResult = {
    val getEmbeddingParam = partParam.asInstanceOf[PartLINEGetEmbeddingParam]
    val matrix = psContext.getMatrixStorageManager.getMatrix(getEmbeddingParam.getMatrixId)
    val part = matrix.getPartition(getEmbeddingParam.getPartKey.getPartitionId)
    val row = part.asInstanceOf[RowBasedPartition].getRow(0).asInstanceOf[ServerIntAnyRow]

    val srcNodeIds = getEmbeddingParam.srcNodeIds
    val targetNodeIds = getEmbeddingParam.targetNodeIds
    val order = getEmbeddingParam.order

    // In first order mode the target features are written into the source map.
    val targetsGoToSrcMap = order == 1 && targetNodeIds != null

    // The source map must exist whenever anything is going to be written to it.
    // A partition that only received target nodes has srcNodeIds == null, so
    // allocating on srcNodeIds alone would cause a NullPointerException for order == 1.
    val srcFeats: Int2ObjectOpenHashMap[Array[Float]] =
      if (srcNodeIds != null || targetsGoToSrcMap) {
        val srcCount = if (srcNodeIds != null) srcNodeIds.length else 0
        val extraCount = if (targetsGoToSrcMap) targetNodeIds.length else 0
        new Int2ObjectOpenHashMap[Array[Float]](srcCount + extraCount)
      } else {
        null
      }

    val targetFeats: Int2ObjectOpenHashMap[Array[Float]] =
      if (targetNodeIds != null) new Int2ObjectOpenHashMap[Array[Float]](targetNodeIds.length)
      else null

    // Get feats for source nodes
    if (srcNodeIds != null) {
      srcNodeIds.foreach { nodeId =>
        srcFeats.put(nodeId, row.get(nodeId).asInstanceOf[LINENode].getInputFeats)
      }
    }

    // Get feats for target nodes (dest nodes and negative sample nodes)
    if (targetNodeIds != null) {
      if (order == 1) {
        // If order == 1, just get from input feats of LINENode
        targetNodeIds.foreach { nodeId =>
          srcFeats.put(nodeId, row.get(nodeId).asInstanceOf[LINENode].getInputFeats)
        }
      } else {
        // Otherwise get from output feats of LINENode
        targetNodeIds.foreach { nodeId =>
          targetFeats.put(nodeId, row.get(nodeId).asInstanceOf[LINENode].getOutputFeats)
        }
      }
    }

    new PartLINEGetEmbeddingResult(getEmbeddingParam.getPartKey, srcFeats, targetFeats)
  }

  /**
    * Merge the partition get results. This function is called on PSAgent.
    *
    * @param partResults the partition results
    * @return the merged result holding the union of all source maps and the union of all target maps
    */
  override def merge(partResults: util.List[PartitionGetResult]): GetResult = {
    val srcFeats = new Int2ObjectOpenHashMap[Array[Float]](param.srcNodeNum)
    val targetFeats = new Int2ObjectOpenHashMap[Array[Float]](param.targetNodeNum)

    // Iterate the Java list explicitly rather than through an implicit conversion.
    val iter = partResults.iterator()
    while (iter.hasNext) {
      val partResult = iter.next().asInstanceOf[PartLINEGetEmbeddingResult]
      if (partResult.srcFeats != null) {
        srcFeats.putAll(partResult.srcFeats)
      }
      if (partResult.targetFeats != null) {
        targetFeats.putAll(partResult.targetFeats)
      }
    }

    new LINEGetEmbeddingResult(srcFeats, targetFeats)
  }
}
/**
  * Merged result of a LINE embedding get request.
  *
  * @param srcFeats    node id to feature vector map for the source side
  * @param targetFeats node id to feature vector map for the target side
  */
class LINEGetEmbeddingResult(srcFeats: Int2ObjectOpenHashMap[Array[Float]],
                             targetFeats: Int2ObjectOpenHashMap[Array[Float]]) extends GetResult {

  /** @return the pair (source features, target features) */
  def getResult: (Int2ObjectOpenHashMap[Array[Float]], Int2ObjectOpenHashMap[Array[Float]]) =
    Tuple2(srcFeats, targetFeats)
}
/**
  * Request parameter of the LINE embedding get function.
  *
  * @param matrixId        id of the embedding matrix
  * @param srcNodes        source node ids
  * @param dstNodes        destination node ids
  * @param negativeSamples negative sample node ids, one array per sample group
  * @param order           LINE order, forwarded unchanged to every partition parameter
  * @param negative        nominal number of negative samples per group; kept for interface
  *                        compatibility, the real row lengths of `negativeSamples` are used
  */
class LINEGetEmbeddingParam(matrixId: Int, srcNodes: Array[Int], dstNodes: Array[Int],
                            negativeSamples: Array[Array[Int]], order: Int, negative: Int) extends GetParam(matrixId) {
  // Number of source ids of the last split() call; read by the merge step as a capacity hint.
  var srcNodeNum = 0
  // Number of target ids (dest + negative samples) of the last split() call.
  var targetNodeNum = 0

  /**
    * Split the request into one parameter per partition that owns at least one
    * requested source or target node.
    *
    * @return the list of partition parameters
    */
  override def split(): util.List[PartitionGetParam] = {
    srcNodeNum = srcNodes.length

    // Merge the dest nodes and negative sample nodes.
    // Size the buffer from the real row lengths: relying on `negative` would leave
    // trailing zero ids (spurious lookups of node 0) for short rows and overflow
    // the buffer for long ones.
    val negativeTotal = negativeSamples.foldLeft(0)((acc, samples) => acc + samples.length)
    val targetNodeIds = new Array[Int](dstNodes.length + negativeTotal)
    Array.copy(dstNodes, 0, targetNodeIds, 0, dstNodes.length)
    var offset: Int = dstNodes.length
    negativeSamples.foreach { samples =>
      Array.copy(samples, 0, targetNodeIds, offset, samples.length)
      offset += samples.length
    }
    targetNodeNum = targetNodeIds.length

    val parts = PSAgentContext.get().getMatrixMetaManager.getPartitions(matrixId, 0)

    // Sort and split the node ids. The source ids are cloned first so the caller's
    // array is not the one handed to the splitter; the target ids are already a private copy.
    val srcIndicesViews = RowUpdateSplitUtils.split(srcNodes.clone(), parts, false)
    val targetIndicesViews = RowUpdateSplitUtils.split(targetNodeIds, parts, false)

    val partToParams = new util.HashMap[PartitionKey, PartLINEGetEmbeddingParam](targetIndicesViews.size())

    // Start with the partitions that hold source nodes ...
    val srcIter = srcIndicesViews.entrySet().iterator()
    while (srcIter.hasNext) {
      val entry = srcIter.next()
      val view = entry.getValue
      partToParams.put(entry.getKey, new PartLINEGetEmbeddingParam(matrixId, entry.getKey,
        view.getIndices, view.getStart, view.getEnd,
        null, -1, -1, order))
    }

    // ... then attach the target nodes, creating a target-only parameter for
    // partitions that hold no source node.
    val targetIter = targetIndicesViews.entrySet().iterator()
    while (targetIter.hasNext) {
      val entry = targetIter.next()
      val view = entry.getValue
      val partParam = partToParams.get(entry.getKey)
      if (partParam == null) {
        partToParams.put(entry.getKey, new PartLINEGetEmbeddingParam(matrixId, entry.getKey,
          null, -1, -1,
          view.getIndices, view.getStart, view.getEnd,
          order))
      } else {
        partParam.targetNodeIds = view.getIndices
        partParam.targetStart = view.getStart
        partParam.targetEnd = view.getEnd
      }
    }

    val partParams = new util.ArrayList[PartitionGetParam](partToParams.size())
    partParams.addAll(partToParams.values())
    partParams
  }
}
/**
  * Per-partition parameter of the LINE embedding get function.
  *
  * The node ids of this partition are the slices `[srcStart, srcEnd)` of `srcNodeIds`
  * and `[targetStart, targetEnd)` of `targetNodeIds`. A `null` array means that no
  * node of that kind is requested from this partition.
  *
  * @param matrixId      id of the embedding matrix
  * @param part          key of the partition this parameter is sent to
  * @param srcNodeIds    array holding the source node ids, may be null
  * @param srcStart      start (inclusive) of the source slice
  * @param srcEnd        end (exclusive) of the source slice
  * @param targetNodeIds array holding the target node ids, may be null
  * @param targetStart   start (inclusive) of the target slice
  * @param targetEnd     end (exclusive) of the target slice
  * @param order         LINE order
  */
class PartLINEGetEmbeddingParam(matrixId: Int, part: PartitionKey, var srcNodeIds: Array[Int],
                                var srcStart: Int, var srcEnd: Int, var targetNodeIds: Array[Int],
                                var targetStart: Int, var targetEnd: Int, var order: Int) extends PartitionGetParam(matrixId, part) {

  def this() = this(-1, null, null, -1, -1, null, -1, -1, -1)

  // Number of source ids that serialize() writes; 0 when there is no source array.
  private def srcCount: Int = if (srcNodeIds != null) srcEnd - srcStart else 0

  // Number of target ids that serialize() writes; 0 when there is no target array.
  private def targetCount: Int = if (targetNodeIds != null) targetEnd - targetStart else 0

  /**
    * Write the parent fields, then the source id count and ids, the target id count
    * and ids, and finally the order.
    *
    * @param buf the buffer to write to
    */
  override def serialize(buf: ByteBuf): Unit = {
    super.serialize(buf)

    // Src nodes
    val srcNum = srcCount
    buf.writeInt(srcNum)
    var i = 0
    while (i < srcNum) {
      buf.writeInt(srcNodeIds(srcStart + i))
      i += 1
    }

    // Target nodes
    val targetNum = targetCount
    buf.writeInt(targetNum)
    i = 0
    while (i < targetNum) {
      buf.writeInt(targetNodeIds(targetStart + i))
      i += 1
    }

    // Order
    buf.writeInt(order)
  }

  /**
    * Read the fields in the order written by [[serialize]]. Only the transmitted slice
    * is available after deserialization, so the start/end bounds are reset to cover the
    * whole freshly read array; this keeps `serialize` and `bufferLen` consistent on a
    * deserialized instance.
    *
    * @param buf the buffer to read from
    */
  override def deserialize(buf: ByteBuf): Unit = {
    super.deserialize(buf)

    // Src nodes
    val srcNodeNum = buf.readInt()
    if (srcNodeNum > 0) {
      srcNodeIds = new Array[Int](srcNodeNum)
      var i = 0
      while (i < srcNodeNum) {
        srcNodeIds(i) = buf.readInt()
        i += 1
      }
      srcStart = 0
      srcEnd = srcNodeNum
    }

    // Target nodes
    val targetNodeNum = buf.readInt()
    if (targetNodeNum > 0) {
      targetNodeIds = new Array[Int](targetNodeNum)
      var i = 0
      while (i < targetNodeNum) {
        targetNodeIds(i) = buf.readInt()
        i += 1
      }
      targetStart = 0
      targetEnd = targetNodeNum
    }

    // Order
    order = buf.readInt()
  }

  /**
    * @return the parent length plus 4 bytes per count field, 4 bytes per written id
    *         and 4 bytes for the order
    */
  override def bufferLen(): Int = {
    super.bufferLen() + 4 + srcCount * 4 + 4 + targetCount * 4 + 4
  }
}
/**
  * Per-partition result of the LINE embedding get function.
  *
  * @param part        key of the partition that produced the result (not part of the wire format)
  * @param srcFeats    node id to feature vector map for the source side, may be null
  * @param targetFeats node id to feature vector map for the target side, may be null
  */
class PartLINEGetEmbeddingResult(var part: PartitionKey, var srcFeats: Int2ObjectOpenHashMap[Array[Float]],
                                 var targetFeats: Int2ObjectOpenHashMap[Array[Float]]) extends PartitionGetResult {

  def this() = this(null, null, null)

  /**
    * Serialize object to the Output stream: the source map followed by the target map.
    *
    * @param output the Netty ByteBuf
    */
  override def serialize(output: ByteBuf): Unit = {
    writeFeats(srcFeats, output)
    writeFeats(targetFeats, output)
  }

  /**
    * Deserialize object from the input stream. A map whose entry count is zero is
    * not touched.
    *
    * @param input the input stream
    */
  override def deserialize(input: ByteBuf): Unit = {
    val srcCount = input.readInt()
    if (srcCount > 0) {
      srcFeats = readFeats(input, srcCount)
    }

    val targetCount = input.readInt()
    if (targetCount > 0) {
      targetFeats = readFeats(input, targetCount)
    }
  }

  /**
    * Estimate serialized data size of the object, it used to ByteBuf allocation.
    * The size of the first entry of each map is taken as the size of all its entries.
    *
    * @return int serialized data size of the object
    */
  override def bufferLen(): Int = {
    // 8 bytes for the two entry counts
    8 + estimateFeatsLen(srcFeats) + estimateFeatsLen(targetFeats)
  }

  // Writes the entry count followed by (node id, vector) pairs; a null map is written as count 0.
  private def writeFeats(feats: Int2ObjectOpenHashMap[Array[Float]], output: ByteBuf): Unit = {
    if (feats == null) {
      output.writeInt(0)
    } else {
      output.writeInt(feats.size())
      val entries = feats.int2ObjectEntrySet().fastIterator()
      while (entries.hasNext) {
        val entry = entries.next()
        output.writeInt(entry.getIntKey)
        NodeUtils.serialize(entry.getValue, output)
      }
    }
  }

  // Reads `count` (node id, vector) pairs into a new map.
  private def readFeats(input: ByteBuf, count: Int): Int2ObjectOpenHashMap[Array[Float]] = {
    val feats = new Int2ObjectOpenHashMap[Array[Float]](count)
    var read = 0
    while (read < count) {
      val nodeId = input.readInt()
      val vector = NodeUtils.deserializeFloats(input)
      feats.put(nodeId, vector)
      read += 1
    }
    feats
  }

  // Size of the first entry (4 byte key + vector) multiplied by the number of entries; 0 for null or empty maps.
  private def estimateFeatsLen(feats: Int2ObjectOpenHashMap[Array[Float]]): Int = {
    if (feats == null) {
      0
    } else {
      val entries = feats.int2ObjectEntrySet().fastIterator()
      val perEntry = if (entries.hasNext) 4 + NodeUtils.dataLen(entries.next().getValue) else 0
      perEntry * feats.size()
    }
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy