All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.graph.models.embeddings;

import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.graph.models.BinaryTree;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/** A standard in-memory implementation of a lookup table for vector representations of the vertices in a graph
 * @author Alex Black
 */
public class InMemoryGraphLookupTable implements GraphVectorLookupTable {

    // Number of vertices in the graph; the row count of vertexVectors (see resetWeights)
    protected int nVertices;
    // Length of every vertex / inner-node vector
    protected int vectorSize;
    // Binary tree over the vertices; supplies codes and inner-node paths (presumably for hierarchical softmax — TODO confirm)
    protected BinaryTree tree;
    protected INDArray vertexVectors;   //'input' vectors
    protected INDArray outWeights;      //'output' vectors. Specifically vectors for inner nodes in binary tree
    // Learning rate supplied at construction; presumably applied by iterate() — TODO confirm (body not visible here)
    protected double learningRate;

    // Precomputed sigmoid lookup table, e^x / (e^x + 1), sampled over [-MAX_EXP, MAX_EXP); filled in the constructor
    protected double[] expTable;
    // Half-width of the input range covered by expTable
    protected static double MAX_EXP = 6;

    public InMemoryGraphLookupTable(int nVertices, int vectorSize, BinaryTree tree, double learningRate ){
        this.nVertices = nVertices;
        this.vectorSize = vectorSize;
        this.tree = tree;
        this.learningRate = learningRate;
        resetWeights();

        expTable = new double[1000];
        for (int i = 0; i < expTable.length; i++) {
            double tmp =   FastMath.exp((i / (double) expTable.length * 2 - 1) * MAX_EXP);
            expTable[i]  = tmp / (tmp + 1.0);
        }
    }

    /**
     * Returns the 'input' vertex vectors, one row per vertex as allocated in resetWeights.
     *
     * @return the live internal array (not a copy)
     */
    public INDArray getVertexVectors() {
        return this.vertexVectors;
    }

    /**
     * Returns the 'output' vectors for the binary tree's inner nodes (nVertices - 1 rows as allocated
     * in resetWeights).
     *
     * @return the live internal array (not a copy)
     */
    public INDArray getOutWeights() {
        return this.outWeights;
    }

    /** Returns the length of each vector, as supplied at construction. */
    @Override
    public int vectorSize() {
        return this.vectorSize;
    }

    /**
     * Re-initialises both weight matrices: values from Nd4j.rand, shifted by -0.5 and divided by
     * vectorSize. Input vectors are drawn first, then inner-node vectors.
     */
    @Override
    public void resetWeights() {
        // One row per vertex for the 'input' vectors.
        INDArray input = Nd4j.rand(nVertices, vectorSize).subi(0.5).divi(vectorSize);
        this.vertexVectors = input;

        // A full binary tree with L leaves has L-1 inner nodes, hence nVertices-1 rows.
        INDArray inner = Nd4j.rand(nVertices - 1, vectorSize).subi(0.5).divi(vectorSize);
        this.outWeights = inner;
    }

    @Override
    public void iterate(int first, int second) {
        //Get vectors and gradients
        //vecAndGrads[0][0] is vector of vertex(first); vecAndGrads[1][0] is corresponding gradient
        INDArray[][] vecAndGrads = vectorsAndGradients(first,second);

        Level1 l1 = Nd4j.getBlasWrapper().level1();
        for( int i=0; i
     * Specifically, out[0] are vectors, out[1] are gradients for the corresponding vectors
* out[0][0] is vector for first vertex; out[1][0] is gradient for this vertex vector
* out[0][i] (i>0) is the inner node vector along path to second vertex; out[1][i] is gradient for inner node vertex
* This design is used primarily to aid in testing (numerical gradient checks) * @param first first (input) vertex index * @param second second (output) vertex index */ public INDArray[][] vectorsAndGradients(int first, int second){ //Input vertex vector gradients are composed of the inner node gradients //Get vector for first vertex, as well as code for second: INDArray vec = vertexVectors.getRow(first); int codeLength = tree.getCodeLength(second); long code = tree.getCode(second); int[] innerNodesForVertex = tree.getPathInnerNodes(second); INDArray[][] out = new INDArray[2][innerNodesForVertex.length+1]; Level1 l1 = Nd4j.getBlasWrapper().level1(); INDArray accumError = Nd4j.create(vec.shape()); for( int i=0; i




© 2015 - 2025 Weber Informatics LLC | Privacy Policy