All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.graph.models.deepwalk.DeepWalk Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.graph.models.deepwalk;

import lombok.AllArgsConstructor;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.api.NoEdgeHandling;
import org.deeplearning4j.graph.api.Vertex;
import org.deeplearning4j.graph.api.IVertexSequence;
import org.deeplearning4j.graph.iterator.GraphWalkIterator;
import org.deeplearning4j.graph.iterator.parallel.GraphWalkIteratorProvider;
import org.deeplearning4j.graph.iterator.parallel.RandomWalkGraphIteratorProvider;
import org.deeplearning4j.graph.models.GraphVectors;
import org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable;
import org.deeplearning4j.graph.models.embeddings.GraphVectorsImpl;
import org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;

/**Implementation of the DeepWalk graph vectorization model, based on the paper
 * DeepWalk: Online Learning of Social Representations by Perozzi, Al-Rfou & Skiena (2014),
 * http://arxiv.org/abs/1403.6652
* Similar to word2vec in nature, DeepWalk is an unsupervised learning algorithm that learns a vector representation * of each vertex in a graph. Vector representations are learned using walks (usually random walks) on the vertices in * the graph.
* Once learned, these vector representations can then be used for purposes such as classification, clustering, similarity * search, etc on the graph
* @author Alex Black */ public class DeepWalk extends GraphVectorsImpl { public static final int STATUS_UPDATE_FREQUENCY = 1000; private Logger log = LoggerFactory.getLogger(DeepWalk.class); private int vectorSize; private int windowSize; private double learningRate; private boolean initCalled = false; private long seed; private ExecutorService executorService; private int nThreads = Runtime.getRuntime().availableProcessors(); private transient AtomicLong walkCounter = new AtomicLong(0); public DeepWalk(){ } public int getVectorSize(){ return vectorSize; } public int getWindowSize(){ return windowSize; } public double getLearningRate(){ return learningRate; } public void setLearningRate(double learningRate){ this.learningRate = learningRate; if(lookupTable != null) lookupTable.setLearningRate(learningRate); } /** Initialize the DeepWalk model with a given graph. */ public void initialize(IGraph graph){ int nVertices = graph.numVertices(); int[] degrees = new int[nVertices]; for( int i=0; i * Specifically, graphVertexDegrees[i] represents the vertex degree of the ith vertex
* vertex degrees are used to construct a binary (Huffman) tree, which is in turn used in * the hierarchical softmax implementation * @param graphVertexDegrees degrees of each vertex */ public void initialize(int[] graphVertexDegrees){ log.info("Initializing: Creating Huffman tree and lookup table..."); GraphHuffman gh = new GraphHuffman(graphVertexDegrees.length); gh.buildTree(graphVertexDegrees); lookupTable = new InMemoryGraphLookupTable(graphVertexDegrees.length,vectorSize,gh,learningRate); initCalled = true; log.info("Initialization complete"); } /** Fit the model, in parallel. * This creates a set of GraphWalkIterators, which are then distributed one to each thread * @param graph Graph to fit * @param walkLength Length of rangom walks to generate */ public void fit( IGraph graph, int walkLength ){ if(!initCalled) initialize(graph); //First: create iterators, one for each thread GraphWalkIteratorProvider iteratorProvider = new RandomWalkGraphIteratorProvider(graph,walkLength,seed, NoEdgeHandling.SELF_LOOP_ON_DISCONNECTED); fit(iteratorProvider); } /** Fit the model, in parallel, using a given GraphWalkIteratorProvider.
* This object is used to generate multiple GraphWalkIterators, which can then be distributed to each thread * to do in parallel
* Note that {@link #fit(IGraph, int)} will be more convenient in many cases
* Note that {@link #initialize(IGraph)} or {@link #initialize(int[])} must be called first. * @param iteratorProvider GraphWalkIteratorProvider * @see #fit(IGraph, int) */ public void fit(GraphWalkIteratorProvider iteratorProvider){ if(!initCalled) throw new UnsupportedOperationException("DeepWalk not initialized (call initialize before fit)"); List> iteratorList = iteratorProvider.getGraphWalkIterators(nThreads); executorService = Executors.newFixedThreadPool(nThreads, new ThreadFactory() { @Override public Thread newThread(Runnable r) { Thread t = new Thread(r); t.setDaemon(true); return t; } }); List> list = new ArrayList<>(iteratorList.size()); //log.info("Fitting Graph with {} threads", Math.max(nThreads,iteratorList.size())); for( GraphWalkIterator iter : iteratorList ){ LearningCallable c = new LearningCallable(iter); list.add(executorService.submit(c)); } executorService.shutdown(); try{ executorService.awaitTermination(999, TimeUnit.DAYS); }catch(InterruptedException e){ throw new RuntimeException("ExecutorService interrupted",e); } //Don't need to block on futures to get a value out, but we want to re-throw any exceptions encountered for(Future f : list){ try{ f.get(); }catch(Exception e){ throw new RuntimeException(e); } } } /**Fit the DeepWalk model using a single thread using a given GraphWalkIterator. If parallel fitting is required, * {@link #fit(IGraph, int)} or {@link #fit(GraphWalkIteratorProvider)} should be used.
* Note that {@link #initialize(IGraph)} or {@link #initialize(int[])} must be called first. * * @param iterator iterator for graph walks */ public void fit(GraphWalkIterator iterator){ if(!initCalled) throw new UnsupportedOperationException("DeepWalk not initialized (call initialize before fit)"); int walkLength = iterator.walkLength(); while(iterator.hasNext()){ IVertexSequence sequence = iterator.next(); //Skipgram model: int[] walk = new int[walkLength+1]; int i=0; while(sequence.hasNext()) walk[i++] = sequence.next().vertexID(); skipGram(walk); long iter = walkCounter.incrementAndGet(); if(iter % STATUS_UPDATE_FREQUENCY == 0 ){ log.info("Processed {} random walks on graph",iter); } } } private void skipGram(int[] walk){ for(int mid = windowSize; mid < walk.length-windowSize; mid++ ){ for( int pos=mid-windowSize; pos<=mid+windowSize; pos++ ){ if(pos == mid) continue; //pair of vertices: walk[mid] -> walk[pos] lookupTable.iterate(walk[mid],walk[pos]); } } } public GraphVectorLookupTable lookupTable(){ return lookupTable; } public static class Builder { private int vectorSize = 100; private long seed = System.currentTimeMillis(); private double learningRate = 0.01; private int windowSize = 2; /** Sets the size of the vectors to be learned for each vertex in the graph */ public Builder vectorSize(int vectorSize){ this.vectorSize = vectorSize; return this; } /** Set the learning rate */ public Builder learningRate(double learningRate){ this.learningRate = learningRate; return this; } /** Sets the window size used in skipgram model */ public Builder windowSize(int windowSize){ this.windowSize = windowSize; return this; } /** Seed for random number generation (used for repeatability). * Note however that parallel/async gradient descent might result in behaviour that * is not repeatable, in spite of setting seed */ public Builder seed(long seed){ this.seed = seed; return this; } public DeepWalk build(){ DeepWalk dw = new DeepWalk<>(); dw.vectorSize = vectorSize; dw.windowSize = windowSize; dw.learningRate = learningRate; dw.seed = seed; return dw; } } @AllArgsConstructor private class LearningCallable implements Callable { private final GraphWalkIterator iterator; @Override public Void call() throws Exception { fit(iterator); return null; } } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy