
org.deeplearning4j.graph.models.embeddings.GraphVectorsImpl Maven / Gradle / Ivy
package org.deeplearning4j.graph.models.embeddings;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.api.Vertex;
import org.deeplearning4j.graph.models.GraphVectors;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Collection;
import java.util.Comparator;
import java.util.PriorityQueue;
/** Base implementation for GraphVectors. Used in DeepWalk, and also when loading
* graph vectors from file.
*/
@AllArgsConstructor @NoArgsConstructor
public class GraphVectorsImpl implements GraphVectors {
protected IGraph graph;
protected GraphVectorLookupTable lookupTable;
@Override
public IGraph getGraph() {
return graph;
}
@Override
public int numVertices() {
return lookupTable.getNumVertices();
}
@Override
public int getVectorSize(){
return lookupTable.vectorSize();
}
@Override
public INDArray getVertexVector(Vertex vertex) {
return lookupTable.getVector(vertex.vertexID());
}
@Override
public INDArray getVertexVector(int vertexIdx) {
return lookupTable.getVector(vertexIdx);
}
@Override
public int[] verticesNearest(int vertexIdx, int top) {
INDArray vec = lookupTable.getVector(vertexIdx).dup();
double norm2 = vec.norm2Number().doubleValue();
PriorityQueue> pq = new PriorityQueue>(lookupTable.getNumVertices(),new PairComparator());
Level1 l1 = Nd4j.getBlasWrapper().level1();
for( int i=0; i(cosineSim,i));
}
int[] out = new int[top];
for( int i=0; i> {
@Override
public int compare(Pair o1, Pair o2) {
return -Double.compare(o1.getFirst(),o2.getFirst());
}
}
/**Returns the cosine similarity of the vector representations of two vertices in the graph
* @return Cosine similarity of two vertices
*/
@Override
public double similarity(Vertex vertex1, Vertex vertex2) {
return similarity(vertex1.vertexID(),vertex2.vertexID());
}
/**Returns the cosine similarity of the vector representations of two vertices in the graph,
* given the indices of these verticies
* @return Cosine similarity of two vertices
*/
@Override
public double similarity(int vertexIdx1, int vertexIdx2) {
if(vertexIdx1 == vertexIdx2) return 1.0;
INDArray vector = Transforms.unitVec(getVertexVector(vertexIdx1));
INDArray vector2 = Transforms.unitVec(getVertexVector(vertexIdx2));
return Nd4j.getBlasWrapper().dot(vector, vector2);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy