// org.deeplearning4j.models.node2vec.Node2Vec (Maven / Gradle / Ivy)
package org.deeplearning4j.models.node2vec;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.learning.SequenceLearningAlgorithm;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.graph.primitives.Vertex;
import org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener;
import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.sequencevectors.transformers.impl.GraphTransformer;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
import java.util.List;
/**
* This is implementation for Node2Vec/DeepWalk for DeepLearning4J
*
* PLEASE NOTE: This class is under construction and isn't suited for any use.
*
* @author [email protected]
*/
@Slf4j
@Deprecated
public class Node2Vec extends SequenceVectors {
public INDArray inferVector(@NonNull Collection> vertices) {
return null;
}
public static class Builder extends SequenceVectors.Builder {
private GraphWalker walker;
public Builder(@NonNull GraphWalker walker, @NonNull VectorsConfiguration configuration) {
this.walker = walker;
this.configuration = configuration;
// FIXME: this will cause transformer initialization
GraphTransformer transformer = new GraphTransformer.Builder<>(walker.getSourceGraph())
.setGraphWalker(walker).shuffleOnReset(true).build();
this.iterator = new AbstractSequenceIterator.Builder(transformer).build();
}
@Override
protected Builder useExistingWordVectors(@NonNull WordVectors vec) {
super.useExistingWordVectors(vec);
return this;
}
@Override
public Builder iterate(@NonNull SequenceIterator iterator) {
super.iterate(iterator);
return this;
}
@Override
public Builder sequenceLearningAlgorithm(@NonNull String algoName) {
super.sequenceLearningAlgorithm(algoName);
return this;
}
@Override
public Builder sequenceLearningAlgorithm(@NonNull SequenceLearningAlgorithm algorithm) {
super.sequenceLearningAlgorithm(algorithm);
return this;
}
@Override
public Builder elementsLearningAlgorithm(@NonNull String algoName) {
super.elementsLearningAlgorithm(algoName);
return this;
}
@Override
public Builder elementsLearningAlgorithm(@NonNull ElementsLearningAlgorithm algorithm) {
super.elementsLearningAlgorithm(algorithm);
return this;
}
@Override
public Builder iterations(int iterations) {
super.iterations(iterations);
return this;
}
@Override
public Builder epochs(int numEpochs) {
super.epochs(numEpochs);
return this;
}
@Override
public Builder workers(int numWorkers) {
super.workers(numWorkers);
return this;
}
@Override
public Builder useHierarchicSoftmax(boolean reallyUse) {
super.useHierarchicSoftmax(reallyUse);
return this;
}
@Override
public Builder useAdaGrad(boolean reallyUse) {
super.useAdaGrad(reallyUse);
return this;
}
@Override
public Builder layerSize(int layerSize) {
super.layerSize(layerSize);
return this;
}
@Override
public Builder learningRate(double learningRate) {
super.learningRate(learningRate);
return this;
}
@Override
public Builder minWordFrequency(int minWordFrequency) {
super.minWordFrequency(minWordFrequency);
return this;
}
@Override
public Builder minLearningRate(double minLearningRate) {
super.minLearningRate(minLearningRate);
return this;
}
@Override
public Builder resetModel(boolean reallyReset) {
super.resetModel(reallyReset);
return this;
}
@Override
public Builder vocabCache(@NonNull VocabCache vocabCache) {
super.vocabCache(vocabCache);
return this;
}
@Override
public Builder lookupTable(@NonNull WeightLookupTable lookupTable) {
super.lookupTable(lookupTable);
return this;
}
@Override
public Builder sampling(double sampling) {
super.sampling(sampling);
return this;
}
@Override
public Builder negativeSample(double negative) {
super.negativeSample(negative);
return this;
}
@Override
public Builder stopWords(@NonNull List stopList) {
super.stopWords(stopList);
return this;
}
@Override
public Builder trainElementsRepresentation(boolean trainElements) {
super.trainElementsRepresentation(trainElements);
return this;
}
@Override
public Builder trainSequencesRepresentation(boolean trainSequences) {
super.trainSequencesRepresentation(trainSequences);
return this;
}
@Override
public Builder stopWords(@NonNull Collection stopList) {
super.stopWords(stopList);
return this;
}
@Override
public Builder windowSize(int windowSize) {
super.windowSize(windowSize);
return this;
}
@Override
public Builder seed(long randomSeed) {
super.seed(randomSeed);
return this;
}
@Override
public Builder modelUtils(@NonNull ModelUtils modelUtils) {
super.modelUtils(modelUtils);
return this;
}
@Override
public Builder useUnknown(boolean reallyUse) {
super.useUnknown(reallyUse);
return this;
}
@Override
public Builder unknownElement(@NonNull V element) {
super.unknownElement(element);
return this;
}
@Override
public Builder useVariableWindow(int... windows) {
super.useVariableWindow(windows);
return this;
}
@Override
public Builder usePreciseWeightInit(boolean reallyUse) {
super.usePreciseWeightInit(reallyUse);
return this;
}
@Override
protected void presetTables() {
super.presetTables();
}
@Override
public Builder setVectorsListeners(@NonNull Collection> vectorsListeners) {
super.setVectorsListeners(vectorsListeners);
return this;
}
@Override
public Builder enableScavenger(boolean reallyEnable) {
super.enableScavenger(reallyEnable);
return this;
}
public Node2Vec build() {
Node2Vec node2vec = new Node2Vec<>();
node2vec.iterator = this.iterator;
return node2vec;
}
}
}