
// org.deeplearning4j.graph.iterator.WeightedRandomWalkIterator Maven / Gradle / Ivy
// The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.graph.iterator;
import org.deeplearning4j.graph.api.Edge;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.api.IVertexSequence;
import org.deeplearning4j.graph.api.NoEdgeHandling;
import org.deeplearning4j.graph.exception.NoEdgesException;
import org.deeplearning4j.graph.VertexSequence;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;
public class WeightedRandomWalkIterator implements GraphWalkIterator {
private final IGraph graph;
private final int walkLength;
private final NoEdgeHandling mode;
private final int firstVertex;
private final int lastVertex;
private int position;
private Random rng;
private int[] order;
public WeightedRandomWalkIterator(IGraph graph, int walkLength) {
this(graph, walkLength, System.currentTimeMillis(), NoEdgeHandling.EXCEPTION_ON_DISCONNECTED);
}
/**Construct a RandomWalkIterator for a given graph, with a specified walk length and random number generator seed.
* Uses {@code NoEdgeHandling.EXCEPTION_ON_DISCONNECTED} - hence exception will be thrown when generating random
* walks on graphs with vertices containing having no edges, or no outgoing edges (for directed graphs)
* @see #WeightedRandomWalkIterator(IGraph, int, long, NoEdgeHandling)
*/
public WeightedRandomWalkIterator(IGraph graph, int walkLength, long rngSeed) {
this(graph, walkLength, rngSeed, NoEdgeHandling.EXCEPTION_ON_DISCONNECTED);
}
/**
* @param graph IGraph to conduct walks on
* @param walkLength length of each walk. Walk of length 0 includes 1 vertex, walk of 1 includes 2 vertices etc
* @param rngSeed seed for randomization
* @param mode mode for handling random walks from vertices with either no edges, or no outgoing edges (for directed graphs)
*/
public WeightedRandomWalkIterator(IGraph graph, int walkLength, long rngSeed,
NoEdgeHandling mode) {
this(graph, walkLength, rngSeed, mode, 0, graph.numVertices());
}
/**Constructor used to generate random walks starting at a subset of the vertices in the graph. Order of starting
* vertices is randomized within this subset
* @param graph IGraph to conduct walks on
* @param walkLength length of each walk. Walk of length 0 includes 1 vertex, walk of 1 includes 2 vertices etc
* @param rngSeed seed for randomization
* @param mode mode for handling random walks from vertices with either no edges, or no outgoing edges (for directed graphs)
* @param firstVertex first vertex index (inclusive) to start random walks from
* @param lastVertex last vertex index (exclusive) to start random walks from
*/
public WeightedRandomWalkIterator(IGraph graph, int walkLength, long rngSeed,
NoEdgeHandling mode, int firstVertex, int lastVertex) {
this.graph = graph;
this.walkLength = walkLength;
this.rng = new Random(rngSeed);
this.mode = mode;
this.firstVertex = firstVertex;
this.lastVertex = lastVertex;
order = new int[lastVertex - firstVertex];
for (int i = 0; i < order.length; i++)
order[i] = firstVertex + i;
reset();
}
@Override
public IVertexSequence next() {
if (!hasNext())
throw new NoSuchElementException();
//Generate a weighted random walk starting at vertex order[current]
int currVertexIdx = order[position++];
int[] indices = new int[walkLength + 1];
indices[0] = currVertexIdx;
if (walkLength == 0)
return new VertexSequence<>(graph, indices);
for (int i = 1; i <= walkLength; i++) {
List extends Edge extends Number>> edgeList = graph.getEdgesOut(currVertexIdx);
//First: check if there are any outgoing edges from this vertex. If not: handle the situation
if (edgeList == null || edgeList.isEmpty()) {
switch (mode) {
case SELF_LOOP_ON_DISCONNECTED:
for (int j = i; j < walkLength; j++)
indices[j] = currVertexIdx;
return new VertexSequence<>(graph, indices);
case EXCEPTION_ON_DISCONNECTED:
throw new NoEdgesException("Cannot conduct random walk: vertex " + currVertexIdx
+ " has no outgoing edges. "
+ " Set NoEdgeHandling mode to NoEdgeHandlingMode.SELF_LOOP_ON_DISCONNECTED to self loop instead of "
+ "throwing an exception in this situation.");
default:
throw new RuntimeException("Unknown/not implemented NoEdgeHandling mode: " + mode);
}
}
//To do a weighted random walk: we need to know total weight of all outgoing edges
double totalWeight = 0.0;
for (Edge extends Number> edge : edgeList) {
totalWeight += edge.getValue().doubleValue();
}
double d = rng.nextDouble();
double threshold = d * totalWeight;
double sumWeight = 0.0;
for (Edge extends Number> edge : edgeList) {
sumWeight += edge.getValue().doubleValue();
if (sumWeight >= threshold) {
if (edge.isDirected()) {
currVertexIdx = edge.getTo();
} else {
if (edge.getFrom() == currVertexIdx) {
currVertexIdx = edge.getTo();
} else {
currVertexIdx = edge.getFrom(); //Undirected edge: might be next--currVertexIdx instead of currVertexIdx--next
}
}
indices[i] = currVertexIdx;
break;
}
}
}
return new VertexSequence<>(graph, indices);
}
@Override
public boolean hasNext() {
return position < order.length;
}
@Override
public void reset() {
position = 0;
//https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
for (int i = order.length - 1; i > 0; i--) {
int j = rng.nextInt(i + 1);
int temp = order[j];
order[j] = order[i];
order[i] = temp;
}
}
@Override
public int walkLength() {
return walkLength;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy