All downloads are free. The search and download features use the official Maven repository.

info.debatty.spark.knngraphs.builder.NNDescent Maven / Gradle / Ivy

There is a newer version: 0.15
Show newest version
package info.debatty.spark.knngraphs.builder;


import info.debatty.java.graphs.Neighbor;
import info.debatty.java.graphs.NeighborList;
import info.debatty.java.graphs.Node;
import java.io.Serializable;
import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.Random;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import scala.Tuple2;

/**
 * Implementation of NN-Descent k-nn graph building algorithm.
 * Based on the paper "Efficient K-Nearest Neighbor Graph Construction for Generic Similarity Measures"
 * by Dong et al.
 * http://www.cs.princeton.edu/cass/papers/www11.pdf
 * 
 * NN-Descent works by iteratively exploring the neighbors of neighbors...
 * 
 * @author Thibault Debatty
 * @param  The class of nodes value
 */
public class NNDescent extends DistributedGraphBuilder implements Serializable {
    
    private int max_iterations = 10;
    
    /**
     * Set the maximum number of iterations.
     * Default value is 10
     * @param max_iterations
     * @return 
     */
    public NNDescent setMaxIterations(int max_iterations) {
        if (max_iterations <= 0) {
            throw new InvalidParameterException("max_iterations must be positive!");
        }
        
        this.max_iterations = max_iterations;
        return this;
    }

    /**
     *
     * @param nodes
     * @return
     */
    protected JavaPairRDD, NeighborList> _computeGraph(JavaRDD> nodes) {
        
        // Randomize: associate each node to 10 buckets out of 20
        JavaPairRDD> randomized = nodes.flatMapToPair(
                new PairFlatMapFunction, Integer, Node>() {
            
            Random rand = new Random();
            
            public Iterable>> call(Node n) throws Exception {
                ArrayList>> r = new ArrayList>>();
                for (int i = 0; i < 10; i++) {
                    r.add(new Tuple2>(rand.nextInt(20), n));
                }
                
                return r;
            }
        });
        
        // Inside bucket, associate
        JavaPairRDD, NeighborList> random_nl = randomized.groupByKey().flatMapToPair(
                new PairFlatMapFunction>>, Node, NeighborList>() {
                    
            Random rand = new Random();
            public Iterable, NeighborList>> call(
                    Tuple2>> tuple) throws Exception {

                // Read all tuples in bucket
                ArrayList> nodes = new ArrayList>();
                for (Node n : tuple._2) {
                    nodes.add(n);
                }

                ArrayList, NeighborList>> r = new ArrayList, NeighborList>>();
                for (Node n : nodes) {
                    NeighborList nnl = new NeighborList(k);
                    for (int i = 0; i < k; i++) {
                        nnl.add(new Neighbor(
                                nodes.get(rand.nextInt(nodes.size())),
                                Double.MAX_VALUE));
                    }

                    r.add(new Tuple2, NeighborList>(n, nnl));
                }

                return r;
            }
        });
        
        // Merge
        JavaPairRDD, NeighborList> graph = random_nl.reduceByKey(
                
                new Function2() {
            
            public NeighborList call(NeighborList nl1, NeighborList nl2) throws Exception {
                NeighborList nnl = new NeighborList(k);
                nnl.addAll(nl1);
                nnl.addAll(nl2);
                return nnl;
            }
        });
        
        for (int iteration = 0; iteration < max_iterations; iteration++) {
            
            // Reverse
            JavaPairRDD, Node> exploded_graph = graph.flatMapToPair(
                    new PairFlatMapFunction, NeighborList>, Node, Node>() {
                
                public Iterable, Node>> call(Tuple2, NeighborList> tuple) throws Exception {
                    
                    ArrayList, Node>> r = new ArrayList, Node>>();
                    
                    for (Neighbor neighbor : tuple._2()) {

                        r.add(new Tuple2, Node>(tuple._1(), neighbor.node));
                        r.add(new Tuple2, Node>(neighbor.node, tuple._1()));
                    }
                    return r;
                    
                }
            });
            
            
            // 
            graph = exploded_graph.groupByKey().flatMapToPair(
                    new PairFlatMapFunction, Iterable>>, Node, NeighborList>() {

                public Iterable, NeighborList>> call(Tuple2, Iterable>> tuple) throws Exception {
                    
                    // Fetch all nodes
                    ArrayList> nodes = new ArrayList>();
                    nodes.add(tuple._1);
                    
                    for (Node n : tuple._2) {
                        nodes.add(n);
                    }
                    
                    // 
                    ArrayList, NeighborList>> r = new ArrayList, NeighborList>>(nodes.size());
                    
                    for (Node n : nodes) {
                        NeighborList nl = new NeighborList(k);
                        
                        for (Node other : nodes) {
                            if (other.equals(n)) {
                                continue;
                            }
                            
                            nl.add(new Neighbor(
                                    other,
                                    similarity.similarity(n.value, other.value)));
                        }
                        
                        r.add(new Tuple2, NeighborList>(n, nl));
                    }
                    
                    return r;
                    
                }
            });
            
            // Filter
            graph = graph.groupByKey().mapToPair(new PairFunction, Iterable>, Node, NeighborList>() {

                public Tuple2, NeighborList> call(Tuple2, Iterable> tuple) throws Exception {
                    NeighborList nl = new NeighborList(k);
                    
                    for (NeighborList other : tuple._2()) {
                        nl.addAll(other);
                    }
                    
                    return new Tuple2, NeighborList>(tuple._1, nl);
                }
            });
        }
        
        return graph;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy