![JAR search and dependency download from the Maven repository](/logo.png)
com.tencent.angel.sona.graph.embedding.line2.LINEAdjust.scala Maven / Gradle / Ivy
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.sona.graph.embedding.line2
import java.util
import com.tencent.angel.PartitionKey
import com.tencent.angel.graph.data.NodeUtils
import com.tencent.angel.ml.matrix.psf.update.base.{PartitionUpdateParam, UpdateFunc, UpdateParam}
import com.tencent.angel.ps.storage.partition.RowBasedPartition
import com.tencent.angel.ps.storage.vector.ServerIntAnyRow
import com.tencent.angel.psagent.PSAgentContext
import com.tencent.angel.psagent.matrix.oplog.cache.RowUpdateSplitUtils
import io.netty.buffer.ByteBuf
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap
/**
 * PS-side update function that adds LINE embedding deltas, in place, into the node vectors
 * stored in row 0 of a matrix partition.
 *
 * @param param the client-side update parameter; `null` when built through the no-arg constructor
 */
class LINEAdjust(var param: LINEAdjustParam) extends UpdateFunc(param) {
  def this() = this(null)

  /**
   * Partition update: adds every input-vector delta to its node and, when the order is 2,
   * every output-vector delta as well.
   *
   * @param partParam the partition parameter; expected to be a [[PartLINEAdjustParam]]
   */
  override def partitionUpdate(partParam: PartitionUpdateParam): Unit = {
    val adjust = partParam.asInstanceOf[PartLINEAdjustParam]
    val inputDeltas = adjust.inputUpdates
    val outputDeltas = adjust.outputUpdates
    val lineOrder = adjust.order

    val storage = psContext.getMatrixStorageManager.getMatrix(adjust.getMatrixId)
    val partition = storage.getPartition(adjust.getPartKey.getPartitionId)
    val nodeRow = partition.asInstanceOf[RowBasedPartition].getRow(0).asInstanceOf[ServerIntAnyRow]

    // NOTE(review): no row write lock (row.startWrite()/row.endWrite()) is taken around these
    // in-place additions -- confirm that concurrent updates to the same row cannot race.
    if (inputDeltas != null) {
      val entries = inputDeltas.int2ObjectEntrySet().fastIterator()
      while (entries.hasNext) {
        val delta = entries.next()
        val node = nodeRow.get(delta.getIntKey).asInstanceOf[LINENode]
        inc(node.getInputFeats, delta.getValue)
      }
    }

    // Output-vector deltas are only applied for second-order LINE.
    if (lineOrder == 2 && outputDeltas != null) {
      val entries = outputDeltas.int2ObjectEntrySet().fastIterator()
      while (entries.hasNext) {
        val delta = entries.next()
        val node = nodeRow.get(delta.getIntKey).asInstanceOf[LINENode]
        inc(node.getOutputFeats, delta.getValue)
      }
    }
  }

  /**
   * Adds `src` into `dst` element-wise, in place, over the first `dst.length` elements.
   * Throws ArrayIndexOutOfBoundsException if `src` is shorter than `dst`.
   *
   * @param dst the vector that is modified
   * @param src the delta to add
   */
  def inc(dst: Array[Float], src: Array[Float]): Unit = {
    var k = 0
    while (k < dst.length) {
      dst(k) += src(k)
      k += 1
    }
  }
}
/**
 * Client-side update parameter carrying accumulated LINE embedding deltas, split per PS
 * partition into [[PartLINEAdjustParam]]s.
 *
 * @param matrixId      id of the embedding matrix
 * @param inputUpdates  node id -> delta for the node's input vector; may be null or empty
 * @param outputUpdates node id -> delta for the node's output vector; ignored when
 *                      `order == 1`; may be null or empty
 * @param order         LINE order; 1 means only input vectors are updated
 */
class LINEAdjustParam(matrixId: Int, inputUpdates: Int2ObjectOpenHashMap[Array[Float]],
                      outputUpdates: Int2ObjectOpenHashMap[Array[Float]], order: Int)
  extends UpdateParam(matrixId) {

  /**
   * Splits the input and (for order != 1) output updates by partition of row 0. A partition
   * that receives both kinds gets a single [[PartLINEAdjustParam]] holding both ranges.
   *
   * Null or empty update maps contribute nothing. If both are null or empty, the result is an
   * empty list instead of a NullPointerException.
   *
   * @return one partition parameter per partition touched by at least one update
   */
  override def split(): util.List[PartitionUpdateParam] = {
    val parts = PSAgentContext.get().getMatrixMetaManager.getPartitions(matrixId, 0)
    // Always initialised: it used to stay null when there were no output updates, which made
    // the input merge below (and the final size() call) throw.
    val partToParams = new util.HashMap[PartitionKey, PartLINEAdjustParam]()

    // First-order LINE only updates input vectors, so output updates are not split.
    if (order != 1 && outputUpdates != null && !outputUpdates.isEmpty) {
      val indicesViews = RowUpdateSplitUtils.split(outputUpdates.keySet().toIntArray, parts, false)
      val iter = indicesViews.entrySet().iterator()
      while (iter.hasNext) {
        val entry = iter.next()
        val view = entry.getValue
        partToParams.put(entry.getKey, new PartLINEAdjustParam(matrixId, entry.getKey, order,
          null, -1, -1, null,
          view.getIndices, view.getStart, view.getEnd, outputUpdates))
      }
    }

    // Merge the input splits into the output splits of the same partition (if any).
    if (inputUpdates != null && !inputUpdates.isEmpty) {
      val indicesViews = RowUpdateSplitUtils.split(inputUpdates.keySet().toIntArray, parts, false)
      val iter = indicesViews.entrySet().iterator()
      while (iter.hasNext) {
        val entry = iter.next()
        val view = entry.getValue
        val existing = partToParams.get(entry.getKey)
        if (existing == null) {
          partToParams.put(entry.getKey, new PartLINEAdjustParam(matrixId, entry.getKey, order,
            view.getIndices, view.getStart, view.getEnd, inputUpdates,
            null, -1, -1, null))
        } else {
          existing.inputUpdates = inputUpdates
          existing.inputNodeIds = view.getIndices
          existing.inputStart = view.getStart
          existing.intputEnd = view.getEnd
        }
      }
    }

    new util.ArrayList[PartitionUpdateParam](partToParams.values())
  }
}
/**
 * Per-partition LINE update parameter.
 *
 * On the client side, the node ids to send are `inputNodeIds(inputStart until intputEnd)` and
 * `outputNodeIds(outputStart until outputEnd)`, with their deltas looked up in the
 * corresponding update maps. On the PS side, deserialisation rebuilds only `order`,
 * `inputUpdates` and `outputUpdates`.
 *
 * Wire format after the superclass header: `order: Int`, then for inputs and then outputs a
 * section `count: Int` followed by `count` x (`nodeId: Int`, float vector via
 * `NodeUtils.serialize`).
 *
 * Note: `intputEnd` keeps its historical spelling because it is a public field.
 */
class PartLINEAdjustParam(matrixId: Int, part: PartitionKey, var order: Int,
                          var inputNodeIds: Array[Int], var inputStart: Int, var intputEnd: Int,
                          var inputUpdates: Int2ObjectOpenHashMap[Array[Float]],
                          var outputNodeIds: Array[Int], var outputStart: Int, var outputEnd: Int,
                          var outputUpdates: Int2ObjectOpenHashMap[Array[Float]])
  extends PartitionUpdateParam(matrixId, part) {

  def this() = this(-1, null, 1, null, -1, -1, null, null, -1, -1, null)

  override def serialize(buf: ByteBuf): Unit = {
    super.serialize(buf)
    buf.writeInt(order)
    writeUpdates(buf, inputNodeIds, inputStart, intputEnd, inputUpdates)
    writeUpdates(buf, outputNodeIds, outputStart, outputEnd, outputUpdates)
  }

  override def deserialize(buf: ByteBuf): Unit = {
    super.deserialize(buf)
    order = buf.readInt()
    // An empty section leaves the current map untouched, as before.
    readUpdates(buf).foreach(u => inputUpdates = u)
    readUpdates(buf).foreach(u => outputUpdates = u)
  }

  /**
   * Exact number of bytes `serialize` writes.
   *
   * This is the superclass length plus three ints (order and two section counts), plus
   * `4 + dataLen` per node. Each node is measured individually rather than assuming all
   * vectors share the first one's length.
   */
  override def bufferLen(): Int =
    super.bufferLen() + 3 * 4 +
      updatesLen(inputNodeIds, inputStart, intputEnd, inputUpdates) +
      updatesLen(outputNodeIds, outputStart, outputEnd, outputUpdates)

  /** Writes one section: count (0 when `nodeIds` is null), then (id, vector) pairs for `start until end`. */
  private def writeUpdates(buf: ByteBuf, nodeIds: Array[Int], start: Int, end: Int,
                           updates: Int2ObjectOpenHashMap[Array[Float]]): Unit = {
    if (nodeIds == null) {
      buf.writeInt(0)
    } else {
      buf.writeInt(end - start)
      var i = start
      while (i < end) {
        val nodeId = nodeIds(i)
        buf.writeInt(nodeId)
        NodeUtils.serialize(updates.get(nodeId), buf)
        i += 1
      }
    }
  }

  /** Reads one section written by `writeUpdates`; `None` when its count is not positive. */
  private def readUpdates(buf: ByteBuf): Option[Int2ObjectOpenHashMap[Array[Float]]] = {
    val nodeNum = buf.readInt()
    if (nodeNum <= 0) {
      None
    } else {
      val updates = new Int2ObjectOpenHashMap[Array[Float]](nodeNum)
      var i = 0
      while (i < nodeNum) {
        updates.put(buf.readInt(), NodeUtils.deserializeFloats(buf))
        i += 1
      }
      Some(updates)
    }
  }

  /** Payload bytes of one section, excluding its count int; 0 when `nodeIds` is null. */
  private def updatesLen(nodeIds: Array[Int], start: Int, end: Int,
                         updates: Int2ObjectOpenHashMap[Array[Float]]): Int = {
    var len = 0
    if (nodeIds != null) {
      var i = start
      while (i < end) {
        len += 4 + NodeUtils.dataLen(updates.get(nodeIds(i)))
        i += 1
      }
    }
    len
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy