All downloads are free. The search and download functionality uses the official Maven repository.

org.deeplearning4j.graph.models.loader.GraphVectorSerializer Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.graph.models.loader;

import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.graph.models.GraphVectors;
import org.deeplearning4j.graph.models.deepwalk.DeepWalk;
import org.deeplearning4j.graph.models.embeddings.GraphVectorsImpl;
import org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

/**
 * Static helpers for writing graph vertex vectors to, and reading them from, a plain-text file.
 *
 * <p>File format: one vertex per line, fields separated by a tab character. The first field is
 * the vertex index and every following field is one component of that vertex's vector, e.g.
 * {@code 0<TAB>0.12<TAB>-0.5<TAB>...}. Lines are terminated with {@code "\n"} when writing.
 *
 * <p>Files are written and read as UTF-8 so that the result does not depend on the platform
 * default charset.
 */
public class GraphVectorSerializer {
    private static final Logger log = LoggerFactory.getLogger(GraphVectorSerializer.class);
    private static final String DELIM = "\t";

    /** Not instantiable: this class only exposes static methods. */
    private GraphVectorSerializer() {}

    /**
     * Writes the vector of every vertex of {@code deepWalk} to a text file, one vertex per line.
     * An existing file at {@code path} is overwritten (not appended to).
     *
     * @param deepWalk model supplying the vertex count, vector size and per-vertex vectors
     * @param path     location of the file to write
     * @throws IOException if the file cannot be opened or written
     */
    public static void writeGraphVectors(DeepWalk deepWalk, String path) throws IOException {

        int nVertices = deepWalk.numVertices();
        int vectorSize = deepWalk.getVectorSize();

        try (BufferedWriter write = new BufferedWriter(new OutputStreamWriter(
                new FileOutputStream(new File(path), false), StandardCharsets.UTF_8))) {
            for (int i = 0; i < nVertices; i++) {
                StringBuilder sb = new StringBuilder();
                sb.append(i);
                INDArray vec = deepWalk.getVertexVector(i);
                for (int j = 0; j < vectorSize; j++) {
                    sb.append(DELIM).append(vec.getDouble(j));
                }
                // Fixed "\n" (not the platform line separator) keeps the format identical everywhere.
                sb.append("\n");
                write.write(sb.toString());
            }
        }

        log.info("Wrote {} vectors of length {} to: {}", nVertices, vectorSize, path);
    }

    /**
     * Loads vertex vectors from a text file in the format produced by
     * {@link #writeGraphVectors(DeepWalk, String)}.
     *
     * <p>The leading vertex-index field of each line is not interpreted: vectors are assigned to
     * rows in the order the lines appear in the file. Blank lines (empty or whitespace only) are
     * skipped. Every remaining line must carry the same number of vector components.
     *
     * @param file file to read
     * @return graph vectors backed by an in-memory lookup table holding the loaded vectors
     * @throws IOException           if the file cannot be read, contains no vectors, or contains
     *                               lines with differing numbers of vector components
     * @throws NumberFormatException if a vector component cannot be parsed as a double
     */
    public static GraphVectors loadTxtVectors(File file) throws IOException {

        List<double[]> vectorList = new ArrayList<>();
        int vecSize = -1; // set from the first non-blank line

        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8))) {
            String line;
            int lineNumber = 0;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                // A blank or tab-only line has no fields; without this guard it would either
                // become a bogus empty vertex or produce a negative array size below.
                if (line.trim().isEmpty()) {
                    continue;
                }

                String[] split = line.split(DELIM);
                double[] vec = new double[split.length - 1];
                // Field 0 is the vertex index; the vector components start at field 1.
                for (int i = 1; i < split.length; i++) {
                    vec[i - 1] = Double.parseDouble(split[i]);
                }

                if (vecSize < 0) {
                    vecSize = vec.length;
                } else if (vec.length != vecSize) {
                    throw new IOException("Inconsistent vector length at line " + lineNumber + " of "
                            + file + ": expected " + vecSize + " values but found " + vec.length);
                }
                vectorList.add(vec);
            }
        }

        if (vectorList.isEmpty()) {
            throw new IOException("No vectors found in file: " + file);
        }

        int nVertices = vectorList.size();

        INDArray vectors = Nd4j.create(nVertices, vecSize);
        for (int i = 0; i < nVertices; i++) {
            double[] vec = vectorList.get(i);
            for (int j = 0; j < vec.length; j++) {
                vectors.put(i, j, vec[j]);
            }
        }

        // NOTE(review): the third and fourth constructor arguments (null, 0.01) are passed through
        // unchanged from the original code; their meaning is defined by InMemoryGraphLookupTable.
        InMemoryGraphLookupTable table = new InMemoryGraphLookupTable(nVertices, vecSize, null, 0.01);
        table.setVertexVectors(vectors);

        return new GraphVectorsImpl<>(null, table);
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy