
org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable Maven / Gradle / Ivy
package org.deeplearning4j.graph.models.embeddings;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.graph.models.BinaryTree;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/** A standard in-memory implementation of a lookup table for vector representations of the vertices in a graph
* @author Alex Black
*/
public class InMemoryGraphLookupTable implements GraphVectorLookupTable {
/** Number of vertices; also the number of rows in {@link #vertexVectors}. */
protected int nVertices;
/** Length of every vertex vector and inner-node vector. */
protected int vectorSize;
/** Binary tree over the vertices; supplies per-vertex codes and inner-node paths. */
protected BinaryTree tree;
/** 'Input' vectors: one row per vertex. */
protected INDArray vertexVectors;
/** 'Output' vectors: one row per inner node of the binary tree. */
protected INDArray outWeights;
/** Learning rate supplied at construction. */
protected double learningRate;
/** Precomputed sigmoid lookup table, filled in by the constructor. */
protected double[] expTable;
/** Half-width of the input range sampled by {@link #expTable}. NOTE(review): mutable static; consider final. */
protected static double MAX_EXP = 6.0;
public InMemoryGraphLookupTable(int nVertices, int vectorSize, BinaryTree tree, double learningRate ){
this.nVertices = nVertices;
this.vectorSize = vectorSize;
this.tree = tree;
this.learningRate = learningRate;
resetWeights();
expTable = new double[1000];
for (int i = 0; i < expTable.length; i++) {
double tmp = FastMath.exp((i / (double) expTable.length * 2 - 1) * MAX_EXP);
expTable[i] = tmp / (tmp + 1.0);
}
}
/**
 * Returns the vertex ('input') vector matrix.
 *
 * @return the internal array itself (not a copy); one row per vertex
 */
public INDArray getVertexVectors() {
    return this.vertexVectors;
}
/**
 * Returns the inner-node ('output') weight matrix.
 *
 * @return the internal array itself (not a copy); one row per inner node of the binary tree
 */
public INDArray getOutWeights() {
    return this.outWeights;
}
/**
 * Returns the length of each vector stored in this table.
 *
 * @return the vector size given at construction
 */
@Override
public int vectorSize() {
    return this.vectorSize;
}
/**
 * Re-initializes both weight matrices with fresh random values.
 *
 * Each entry is drawn from {@code Nd4j.rand}, shifted by -0.5, and then divided by
 * {@code vectorSize} (both operations in place). The vertex matrix is generated and assigned
 * before the inner-node matrix.
 */
@Override
public void resetWeights() {
    // 'Input' vectors: one row per vertex.
    INDArray centeredVertexRows = Nd4j.rand(nVertices, vectorSize).subi(0.5);
    this.vertexVectors = centeredVertexRows.divi(vectorSize);

    // 'Output' vectors: a full binary tree with L leaves has L-1 inner nodes.
    int innerNodeCount = nVertices - 1;
    INDArray centeredInnerRows = Nd4j.rand(innerNodeCount, vectorSize).subi(0.5);
    this.outWeights = centeredInnerRows.divi(vectorSize);
}
@Override
public void iterate(int first, int second) {
//Get vectors and gradients
//vecAndGrads[0][0] is vector of vertex(first); vecAndGrads[1][0] is corresponding gradient
INDArray[][] vecAndGrads = vectorsAndGradients(first,second);
Level1 l1 = Nd4j.getBlasWrapper().level1();
for( int i=0; i
* Specifically, out[0] are vectors, out[1] are gradients for the corresponding vectors
* out[0][0] is vector for first vertex; out[1][0] is gradient for this vertex vector
* out[0][i] (i>0) is the inner node vector along path to second vertex; out[1][i] is gradient for that inner node vector
* This design is used primarily to aid in testing (numerical gradient checks)
* @param first first (input) vertex index
* @param second second (output) vertex index
*/
public INDArray[][] vectorsAndGradients(int first, int second){
//Input vertex vector gradients are composed of the inner node gradients
//Get vector for first vertex, as well as code for second:
INDArray vec = vertexVectors.getRow(first);
int codeLength = tree.getCodeLength(second);
long code = tree.getCode(second);
int[] innerNodesForVertex = tree.getPathInnerNodes(second);
INDArray[][] out = new INDArray[2][innerNodesForVertex.length+1];
Level1 l1 = Nd4j.getBlasWrapper().level1();
INDArray accumError = Nd4j.create(vec.shape());
for( int i=0; i
© 2015 - 2025 Weber Informatics LLC | Privacy Policy