info.debatty.spark.knngraphs.builder.NNDescent Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of spark-knn-graphs Show documentation
Show all versions of spark-knn-graphs Show documentation
Spark algorithms for building k-nn graphs
package info.debatty.spark.knngraphs.builder;
import info.debatty.java.graphs.Neighbor;
import info.debatty.java.graphs.NeighborList;
import info.debatty.java.graphs.Node;
import java.io.Serializable;
import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.Random;
import java.util.Set;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import scala.Tuple2;
/**
* Implementation of NN-Descent k-nn graph building algorithm.
* Based on the paper "Efficient K-Nearest Neighbor Graph Construction for Generic Similarity Measures"
* by Dong et al.
* http://www.cs.princeton.edu/cass/papers/www11.pdf
*
* NN-Descent works by iteratively exploring the neighbors of neighbors...
*
* @author Thibault Debatty
* @param The class of nodes value
*/
public class NNDescent extends DistributedGraphBuilder implements Serializable {
private int max_iterations = 10;
/**
* Set the maximum number of iterations.
* Default value is 10
* @param max_iterations
* @return
*/
public NNDescent setMaxIterations(int max_iterations) {
if (max_iterations <= 0) {
throw new InvalidParameterException("max_iterations must be positive!");
}
this.max_iterations = max_iterations;
return this;
}
/**
*
* @param nodes
* @return
*/
protected JavaPairRDD, NeighborList> _computeGraph(JavaRDD> nodes) {
// Randomize: associate each node to 10 buckets out of 20
JavaPairRDD> randomized = nodes.flatMapToPair(
new PairFlatMapFunction, Integer, Node>() {
Random rand = new Random();
public Iterable>> call(Node n) throws Exception {
ArrayList>> r = new ArrayList>>();
for (int i = 0; i < 10; i++) {
r.add(new Tuple2>(rand.nextInt(20), n));
}
return r;
}
});
// Inside bucket, associate
JavaPairRDD, NeighborList> random_nl = randomized.groupByKey().flatMapToPair(
new PairFlatMapFunction>>, Node, NeighborList>() {
Random rand = new Random();
public Iterable, NeighborList>> call(
Tuple2>> tuple) throws Exception {
// Read all tuples in bucket
ArrayList> nodes = new ArrayList>();
for (Node n : tuple._2) {
nodes.add(n);
}
ArrayList, NeighborList>> r = new ArrayList, NeighborList>>();
for (Node n : nodes) {
NeighborList nnl = new NeighborList(k);
for (int i = 0; i < k; i++) {
nnl.add(new Neighbor(
nodes.get(rand.nextInt(nodes.size())),
Double.MAX_VALUE));
}
r.add(new Tuple2, NeighborList>(n, nnl));
}
return r;
}
});
// Merge
JavaPairRDD, NeighborList> graph = random_nl.reduceByKey(
new Function2() {
public NeighborList call(NeighborList nl1, NeighborList nl2) throws Exception {
NeighborList nnl = new NeighborList(k);
nnl.addAll(nl1);
nnl.addAll(nl2);
return nnl;
}
});
for (int iteration = 0; iteration < max_iterations; iteration++) {
// Reverse
JavaPairRDD, Node> exploded_graph = graph.flatMapToPair(
new PairFlatMapFunction, NeighborList>, Node, Node>() {
public Iterable, Node>> call(Tuple2, NeighborList> tuple) throws Exception {
ArrayList, Node>> r = new ArrayList, Node>>();
for (Neighbor neighbor : tuple._2()) {
r.add(new Tuple2, Node>(tuple._1(), neighbor.node));
r.add(new Tuple2, Node>(neighbor.node, tuple._1()));
}
return r;
}
});
//
graph = exploded_graph.groupByKey().flatMapToPair(
new PairFlatMapFunction, Iterable>>, Node, NeighborList>() {
public Iterable, NeighborList>> call(Tuple2, Iterable>> tuple) throws Exception {
// Fetch all nodes
ArrayList> nodes = new ArrayList>();
nodes.add(tuple._1);
for (Node n : tuple._2) {
nodes.add(n);
}
//
ArrayList, NeighborList>> r = new ArrayList, NeighborList>>(nodes.size());
for (Node n : nodes) {
NeighborList nl = new NeighborList(k);
for (Node other : nodes) {
if (other.equals(n)) {
continue;
}
nl.add(new Neighbor(
other,
similarity.similarity(n.value, other.value)));
}
r.add(new Tuple2, NeighborList>(n, nl));
}
return r;
}
});
// Filter
graph = graph.groupByKey().mapToPair(new PairFunction, Iterable>, Node, NeighborList>() {
public Tuple2, NeighborList> call(Tuple2, Iterable> tuple) throws Exception {
NeighborList nl = new NeighborList(k);
for (NeighborList other : tuple._2()) {
nl.addAll(other);
}
return new Tuple2, NeighborList>(tuple._1, nl);
}
});
}
return graph;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy