com.tencent.angel.spark.ml.psf.embedding.bad.W2VPullParam Maven / Gradle / Ivy
package com.tencent.angel.spark.ml.psf.embedding.bad;
import com.tencent.angel.PartitionKey;
import com.tencent.angel.ml.matrix.psf.get.base.GetParam;
import com.tencent.angel.ml.matrix.psf.get.base.PartitionGetParam;
import com.tencent.angel.psagent.PSAgentContext;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class W2VPullParam extends GetParam {
int[] indices;
int numNodePerRow;
int dimension;
public W2VPullParam(int matrixId, int[] indices, int numNodePerRow, int dimension) {
super(matrixId);
this.indices = indices;
this.numNodePerRow = numNodePerRow;
this.dimension = dimension;
}
@Override
public List split() {
Arrays.sort(indices);
List pkeys = PSAgentContext.get().getMatrixMetaManager().getPartitions(matrixId);
List params = new ArrayList<>();
int start = 0, end = 0;
for (PartitionKey pkey : pkeys) {
int startRow = pkey.getStartRow();
int endRow = pkey.getEndRow();
int startNode = startRow * numNodePerRow;
int endNode = endRow * numNodePerRow;
if (start < indices.length && indices[start] >= startNode) {
while (end < indices.length && indices[end] < endNode)
end++;
if (end > start) {
params.add(new W2VPullParatitionParam(matrixId,
pkey,
indices,
numNodePerRow,
start,
end - start,
dimension));
}
start = end;
}
}
return params;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy