
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.graph.models.deepwalk;
import lombok.AllArgsConstructor;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.api.IVertexSequence;
import org.deeplearning4j.graph.api.NoEdgeHandling;
import org.deeplearning4j.graph.iterator.GraphWalkIterator;
import org.deeplearning4j.graph.iterator.parallel.GraphWalkIteratorProvider;
import org.deeplearning4j.graph.iterator.parallel.RandomWalkGraphIteratorProvider;
import org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable;
import org.deeplearning4j.graph.models.embeddings.GraphVectorsImpl;
import org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.threadly.concurrent.PriorityScheduler;
import org.threadly.concurrent.future.FutureUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;
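
/**
 * Implementation of the DeepWalk graph embedding model (Perozzi, Al-Rfou and Skiena, 2014,
 * <a href="https://arxiv.org/abs/1403.6652">"DeepWalk: Online Learning of Social Representations"</a>).
 * DeepWalk learns a vector representation for each vertex in a graph by generating (random) walks over
 * the vertices and training a skip-gram model, with hierarchical softmax, on those walks, analogous to word2vec.
 *
 * Illustrative usage sketch only: the graph construction is assumed to happen elsewhere, and the type
 * parameters and values below are examples, not requirements.
 * <pre>{@code
 * IGraph<String, String> graph = ...;                  //Load or build a graph elsewhere
 * DeepWalk<String, String> deepWalk = new DeepWalk.Builder<String, String>()
 *         .vectorSize(100)
 *         .windowSize(4)
 *         .build();
 * deepWalk.initialize(graph);
 * deepWalk.fit(graph, 10);                             //Random walks of length 10
 * INDArray vec = deepWalk.getVertexVector(0);          //Accessor inherited from GraphVectorsImpl
 * }</pre>
 */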
public class DeepWalk<V, E> extends GraphVectorsImpl<V, E> {
    public static final int STATUS_UPDATE_FREQUENCY = 1000;
    private Logger log = LoggerFactory.getLogger(DeepWalk.class);

    private int vectorSize;
    private int windowSize;
    private double learningRate;
    private boolean initCalled = false;
    private long seed;
    private int nThreads = Runtime.getRuntime().availableProcessors();
    private transient AtomicLong walkCounter = new AtomicLong(0);

    public DeepWalk() {
    }

    public int getVectorSize() {
        return vectorSize;
    }

    public int getWindowSize() {
        return windowSize;
    }

    public double getLearningRate() {
        return learningRate;
    }

    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
        if (lookupTable != null)
            lookupTable.setLearningRate(learningRate);
    }

    /** Initialize the DeepWalk model with a given graph. */
    public void initialize(IGraph<V, E> graph) {
        int nVertices = graph.numVertices();
        int[] degrees = new int[nVertices];
        for (int i = 0; i < nVertices; i++)
            degrees[i] = graph.getVertexDegree(i);
        initialize(degrees);
    }

    /** Initialize the DeepWalk model with a list of vertex degrees for a graph.
     * Specifically, graphVertexDegrees[i] represents the degree of the ith vertex.
     * Vertex degrees are used to construct a binary (Huffman) tree, which is in turn used in
     * the hierarchical softmax implementation.
     * @param graphVertexDegrees degrees of each vertex
     */
    public void initialize(int[] graphVertexDegrees) {
        log.info("Initializing: Creating Huffman tree and lookup table...");
        GraphHuffman gh = new GraphHuffman(graphVertexDegrees.length);
        gh.buildTree(graphVertexDegrees);
        lookupTable = new InMemoryGraphLookupTable(graphVertexDegrees.length, vectorSize, gh, learningRate);
        initCalled = true;
        log.info("Initialization complete");
    }
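
    //Intuition for the Huffman construction above (illustrative degrees only): for degrees {5, 1, 1, 1},
    //vertex 0 receives a shorter binary code than vertices 1-3, so the vertices that random walks visit
    //most often require the fewest inner-node updates under hierarchical softmax.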

    /** Fit the model, in parallel.
     * This creates a set of GraphWalkIterators, which are then distributed, one to each thread.
     * @param graph Graph to fit
     * @param walkLength Length of random walks to generate
     */
    public void fit(IGraph<V, E> graph, int walkLength) {
        if (!initCalled)
            initialize(graph);

        //First: create iterators, one for each thread
        GraphWalkIteratorProvider<V> iteratorProvider = new RandomWalkGraphIteratorProvider<>(graph, walkLength,
                        seed, NoEdgeHandling.SELF_LOOP_ON_DISCONNECTED);

        fit(iteratorProvider);
    }

    /** Fit the model, in parallel, using a given GraphWalkIteratorProvider.
     * This object is used to generate multiple GraphWalkIterators, which can then be distributed
     * across threads to fit in parallel.<br>
     * Note that {@link #fit(IGraph, int)} will be more convenient in many cases.<br>
     * Note that {@link #initialize(IGraph)} or {@link #initialize(int[])} must be called first.
     * @param iteratorProvider GraphWalkIteratorProvider
     * @see #fit(IGraph, int)
     */
    public void fit(GraphWalkIteratorProvider<V> iteratorProvider) {
        if (!initCalled)
            throw new UnsupportedOperationException("DeepWalk not initialized (call initialize before fit)");
        List<GraphWalkIterator<V>> iteratorList = iteratorProvider.getGraphWalkIterators(nThreads);

        PriorityScheduler scheduler = new PriorityScheduler(nThreads);
        List<Future<Void>> list = new ArrayList<>(iteratorList.size());
        //log.info("Fitting Graph with {} threads", Math.max(nThreads,iteratorList.size()));
        for (GraphWalkIterator<V> iter : iteratorList) {
            LearningCallable c = new LearningCallable(iter);
            list.add(scheduler.submit(c));
        }

        scheduler.shutdown(); //Won't shut down until all tasks have completed
        try {
            FutureUtils.blockTillAllCompleteOrFirstError(list);
        } catch (InterruptedException e) {
            //Should not be possible when blocking until the scheduler terminates
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    /** Fit the DeepWalk model, using a single thread, with a given GraphWalkIterator. If parallel fitting is
     * required, {@link #fit(IGraph, int)} or {@link #fit(GraphWalkIteratorProvider)} should be used.<br>
     * Note that {@link #initialize(IGraph)} or {@link #initialize(int[])} must be called first.
     *
     * @param iterator iterator for graph walks
     */
    public void fit(GraphWalkIterator<V> iterator) {
        if (!initCalled)
            throw new UnsupportedOperationException("DeepWalk not initialized (call initialize before fit)");
        int walkLength = iterator.walkLength();

        while (iterator.hasNext()) {
            IVertexSequence<V> sequence = iterator.next();

            //Skipgram model: a walk of length N contains N+1 vertices
            int[] walk = new int[walkLength + 1];
            int i = 0;
            while (sequence.hasNext())
                walk[i++] = sequence.next().vertexID();

            skipGram(walk);

            long iter = walkCounter.incrementAndGet();
            if (iter % STATUS_UPDATE_FREQUENCY == 0) {
                log.info("Processed {} random walks on graph", iter);
            }
        }
    }
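
    //A small worked example of the skip-gram step below (values illustrative only):
    //with windowSize = 2 and walk = {3, 7, 1, 9, 4}, the only valid center position is mid = 2, so the
    //pairs trained via lookupTable.iterate are (1,3), (1,7), (1,9) and (1,4); i.e., the center vertex
    //is trained against every vertex within windowSize positions on either side of it.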
    private void skipGram(int[] walk) {
        for (int mid = windowSize; mid < walk.length - windowSize; mid++) {
            for (int pos = mid - windowSize; pos <= mid + windowSize; pos++) {
                if (pos == mid)
                    continue;

                //Pair of vertices: walk[mid] -> walk[pos]
                lookupTable.iterate(walk[mid], walk[pos]);
            }
        }
    }

    public GraphVectorLookupTable lookupTable() {
        return lookupTable;
    }
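
    /**
     * Builder for configuring and constructing a DeepWalk instance. Defaults (per the fields below):
     * vectorSize = 100, learningRate = 0.01, windowSize = 2, seed = System.currentTimeMillis().
     *
     * Illustrative sketch only; the type parameters and values here are examples:
     * <pre>{@code
     * DeepWalk<String, String> deepWalk = new DeepWalk.Builder<String, String>()
     *         .vectorSize(50)
     *         .windowSize(4)
     *         .learningRate(0.025)
     *         .seed(12345)
     *         .build();
     * }</pre>
     */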
    public static class Builder<V, E> {
        private int vectorSize = 100;
        private long seed = System.currentTimeMillis();
        private double learningRate = 0.01;
        private int windowSize = 2;

        /** Sets the size of the vectors to be learned for each vertex in the graph */
        public Builder<V, E> vectorSize(int vectorSize) {
            this.vectorSize = vectorSize;
            return this;
        }

        /** Set the learning rate */
        public Builder<V, E> learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        /** Sets the window size used in the skip-gram model */
        public Builder<V, E> windowSize(int windowSize) {
            this.windowSize = windowSize;
            return this;
        }

        /** Seed for random number generation (used for repeatability).
         * Note, however, that parallel/async gradient descent might result in behaviour that
         * is not repeatable, in spite of the seed being set.
         */
        public Builder<V, E> seed(long seed) {
            this.seed = seed;
            return this;
        }

        public DeepWalk<V, E> build() {
            DeepWalk<V, E> dw = new DeepWalk<>();
            dw.vectorSize = vectorSize;
            dw.windowSize = windowSize;
            dw.learningRate = learningRate;
            dw.seed = seed;
            return dw;
        }
    }

    @AllArgsConstructor
    private class LearningCallable implements Callable<Void> {
        private final GraphWalkIterator<V> iterator;

        @Override
        public Void call() throws Exception {
            fit(iterator);
            return null;
        }
    }
}