All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.graph.models.embeddings.GraphVectorsImpl Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.graph.models.embeddings;

import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.api.Vertex;
import org.deeplearning4j.graph.models.GraphVectors;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.common.primitives.Pair;

import java.util.Comparator;
import java.util.PriorityQueue;

@AllArgsConstructor
@NoArgsConstructor
public class GraphVectorsImpl implements GraphVectors {

    protected IGraph graph;
    protected GraphVectorLookupTable lookupTable;


    @Override
    public IGraph getGraph() {
        return graph;
    }

    @Override
    public int numVertices() {
        return lookupTable.getNumVertices();
    }

    @Override
    public int getVectorSize() {
        return lookupTable.vectorSize();
    }

    @Override
    public INDArray getVertexVector(Vertex vertex) {
        return lookupTable.getVector(vertex.vertexID());
    }

    @Override
    public INDArray getVertexVector(int vertexIdx) {
        return lookupTable.getVector(vertexIdx);
    }

    @Override
    public int[] verticesNearest(int vertexIdx, int top) {

        INDArray vec = lookupTable.getVector(vertexIdx).dup();
        double norm2 = vec.norm2Number().doubleValue();


        PriorityQueue> pq =
                        new PriorityQueue<>(lookupTable.getNumVertices(), new PairComparator());

        Level1 l1 = Nd4j.getBlasWrapper().level1();
        for (int i = 0; i < numVertices(); i++) {
            if (i == vertexIdx)
                continue;

            INDArray other = lookupTable.getVector(i);
            double cosineSim = l1.dot(vec.length(), 1.0, vec, other) / (norm2 * other.norm2Number().doubleValue());

            pq.add(new Pair<>(cosineSim, i));
        }

        int[] out = new int[top];
        for (int i = 0; i < top; i++) {
            out[i] = pq.remove().getSecond();
        }

        return out;
    }

    private static class PairComparator implements Comparator> {
        @Override
        public int compare(Pair o1, Pair o2) {
            return -Double.compare(o1.getFirst(), o2.getFirst());
        }
    }

    /**Returns the cosine similarity of the vector representations of two vertices in the graph
     * @return Cosine similarity of two vertices
     */
    @Override
    public double similarity(Vertex vertex1, Vertex vertex2) {
        return similarity(vertex1.vertexID(), vertex2.vertexID());
    }

    /**Returns the cosine similarity of the vector representations of two vertices in the graph,
     * given the indices of these verticies
     * @return Cosine similarity of two vertices
     */
    @Override
    public double similarity(int vertexIdx1, int vertexIdx2) {
        if (vertexIdx1 == vertexIdx2)
            return 1.0;

        INDArray vector = Transforms.unitVec(getVertexVector(vertexIdx1));
        INDArray vector2 = Transforms.unitVec(getVertexVector(vertexIdx2));
        return Nd4j.getBlasWrapper().dot(vector, vector2);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy