All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.models.node2vec.Node2Vec Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.models.node2vec;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.learning.SequenceLearningAlgorithm;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.graph.primitives.Vertex;
import org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener;
import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.sequencevectors.transformers.impl.GraphTransformer;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Collection;
import java.util.List;

/**
 * This is implementation for Node2Vec/DeepWalk for DeepLearning4J
 *
 * PLEASE NOTE: This class is under construction and isn't suited for any use.
 *
 * @author [email protected]
 */
@Slf4j
@Deprecated
public class Node2Vec extends SequenceVectors {

    public INDArray inferVector(@NonNull Collection> vertices) {
        return null;
    }

    public static class Builder extends SequenceVectors.Builder {
        private GraphWalker walker;

        public Builder(@NonNull GraphWalker walker, @NonNull VectorsConfiguration configuration) {
            this.walker = walker;
            this.configuration = configuration;

            // FIXME: this will cause transformer initialization
            GraphTransformer transformer = new GraphTransformer.Builder<>(walker.getSourceGraph())
                            .setGraphWalker(walker).shuffleOnReset(true).build();

            this.iterator = new AbstractSequenceIterator.Builder(transformer).build();
        }


        @Override
        protected Builder useExistingWordVectors(@NonNull WordVectors vec) {
            super.useExistingWordVectors(vec);
            return this;
        }

        @Override
        public Builder iterate(@NonNull SequenceIterator iterator) {
            super.iterate(iterator);
            return this;
        }

        @Override
        public Builder sequenceLearningAlgorithm(@NonNull String algoName) {
            super.sequenceLearningAlgorithm(algoName);
            return this;
        }

        @Override
        public Builder sequenceLearningAlgorithm(@NonNull SequenceLearningAlgorithm algorithm) {
            super.sequenceLearningAlgorithm(algorithm);
            return this;
        }

        @Override
        public Builder elementsLearningAlgorithm(@NonNull String algoName) {
            super.elementsLearningAlgorithm(algoName);
            return this;
        }

        @Override
        public Builder elementsLearningAlgorithm(@NonNull ElementsLearningAlgorithm algorithm) {
            super.elementsLearningAlgorithm(algorithm);
            return this;
        }

        @Override
        public Builder iterations(int iterations) {
            super.iterations(iterations);
            return this;
        }

        @Override
        public Builder epochs(int numEpochs) {
            super.epochs(numEpochs);
            return this;
        }

        @Override
        public Builder workers(int numWorkers) {
            super.workers(numWorkers);
            return this;
        }

        @Override
        public Builder useHierarchicSoftmax(boolean reallyUse) {
            super.useHierarchicSoftmax(reallyUse);
            return this;
        }

        @Override
        public Builder useAdaGrad(boolean reallyUse) {
            super.useAdaGrad(reallyUse);
            return this;
        }

        @Override
        public Builder layerSize(int layerSize) {
            super.layerSize(layerSize);
            return this;
        }

        @Override
        public Builder learningRate(double learningRate) {
            super.learningRate(learningRate);
            return this;
        }

        @Override
        public Builder minWordFrequency(int minWordFrequency) {
            super.minWordFrequency(minWordFrequency);
            return this;
        }

        @Override
        public Builder minLearningRate(double minLearningRate) {
            super.minLearningRate(minLearningRate);
            return this;
        }

        @Override
        public Builder resetModel(boolean reallyReset) {
            super.resetModel(reallyReset);
            return this;
        }

        @Override
        public Builder vocabCache(@NonNull VocabCache vocabCache) {
            super.vocabCache(vocabCache);
            return this;
        }

        @Override
        public Builder lookupTable(@NonNull WeightLookupTable lookupTable) {
            super.lookupTable(lookupTable);
            return this;
        }

        @Override
        public Builder sampling(double sampling) {
            super.sampling(sampling);
            return this;
        }

        @Override
        public Builder negativeSample(double negative) {
            super.negativeSample(negative);
            return this;
        }

        @Override
        public Builder stopWords(@NonNull List stopList) {
            super.stopWords(stopList);
            return this;
        }

        @Override
        public Builder trainElementsRepresentation(boolean trainElements) {
            super.trainElementsRepresentation(trainElements);
            return this;
        }

        @Override
        public Builder trainSequencesRepresentation(boolean trainSequences) {
            super.trainSequencesRepresentation(trainSequences);
            return this;
        }

        @Override
        public Builder stopWords(@NonNull Collection stopList) {
            super.stopWords(stopList);
            return this;
        }

        @Override
        public Builder windowSize(int windowSize) {
            super.windowSize(windowSize);
            return this;
        }

        @Override
        public Builder seed(long randomSeed) {
            super.seed(randomSeed);
            return this;
        }

        @Override
        public Builder modelUtils(@NonNull ModelUtils modelUtils) {
            super.modelUtils(modelUtils);
            return this;
        }

        @Override
        public Builder useUnknown(boolean reallyUse) {
            super.useUnknown(reallyUse);
            return this;
        }

        @Override
        public Builder unknownElement(@NonNull V element) {
            super.unknownElement(element);
            return this;
        }

        @Override
        public Builder useVariableWindow(int... windows) {
            super.useVariableWindow(windows);
            return this;
        }

        @Override
        public Builder usePreciseWeightInit(boolean reallyUse) {
            super.usePreciseWeightInit(reallyUse);
            return this;
        }

        @Override
        protected void presetTables() {
            super.presetTables();
        }

        @Override
        public Builder setVectorsListeners(@NonNull Collection> vectorsListeners) {
            super.setVectorsListeners(vectorsListeners);
            return this;
        }

        @Override
        public Builder enableScavenger(boolean reallyEnable) {
            super.enableScavenger(reallyEnable);
            return this;
        }

        public Node2Vec build() {
            Node2Vec node2vec = new Node2Vec<>();
            node2vec.iterator = this.iterator;

            return node2vec;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy