All downloads are free. Search and download functionality uses the official Maven repository.

info.debatty.java.graphs.build.NNDescent Maven / Gradle / Ivy

Go to download

Algorithms that build k-nearest neighbors graph (k-nn graph): Brute-force, NN-Descent,...

There is a newer version: 0.41
Show newest version
package info.debatty.java.graphs.build;

import info.debatty.java.graphs.Graph;
import info.debatty.java.graphs.Neighbor;
import info.debatty.java.graphs.NeighborList;
import info.debatty.java.graphs.Node;
import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Random;

/**
 * Implementation of NN-Descent k-nn graph building algorithm.
 * Based on the paper "Efficient K-Nearest Neighbor Graph Construction for 
 * Generic Similarity Measures" by Dong et al.
 * http://www.cs.princeton.edu/cass/papers/www11.pdf
 * 
 * NN-Descent works by iteratively exploring the neighbors of neighbors...
 * It is not suitable for small datasets (less than 500 items)!
 * @author Thibault Debatty
 * @param <T> The type of nodes value
 */

public class NNDescent extends GraphBuilder {

    protected double rho = 0.5; // Standard : 1, Fast: 0.5
    protected double delta = 0.001;
    protected int max_iterations = Integer.MAX_VALUE;
    
    protected int iterations = 0;
    protected int c;
    
    
    /**
     * Get the number of edges modified at the last iteration
     * @return 
     */
    public int getC() {
        return c;
    }
    
    
    /**
     * Get the number of executed iterations
     * @return 
     */
    public int getIterations() {
        return iterations;
    }

    public double getRho() {
        return rho;
    }

    /**
     * Sampling coefficient.
     * In interval ]0, 1.0]
     * Typical value for fast computation is 0.5
     * Use 1.0 for precise computation
     * Default is 0.5
     * @param rho 
     */
    public void setRho(double rho) {
        if (rho > 1.0 || rho <= 0.0) {
            throw new InvalidParameterException("0 < rho <= 1.0");
        }
        this.rho = rho;
    }

    public double getDelta() {
        return delta;
    }

    /**
     * Early termination coefficient.
     * The algorithm stops when less than this proportion of edges are modified
     * Should be in ]0, 1.0[
     * Default is 0.001
     * @param delta 
     */
    public void setDelta(double delta) {
        if (rho >= 1.0 || rho <= 0.0) {
            throw new InvalidParameterException("0 < delta < 1.0");
        }
        this.delta = delta;
    }

    public int getMaxIterations() {
        return max_iterations;
    }

    /**
     * Set the maximum number of iterations
     * Default is no max (Integer.MAX_VALUE)
     * @param max_iterations 
     */
    public void setMaxIterations(int max_iterations) {
        if (max_iterations < 0) {
            throw new InvalidParameterException("max_iterations should be positive!");
        }
        this.max_iterations = max_iterations;
    }

    
    @Override
    protected Graph _computeGraph(List> nodes) {
        
        iterations = 0;
        
        if (nodes.size() <= (k+1)) {
            return MakeFullyLinked(nodes);
        }
        
        Graph neighborlists = new Graph(nodes.size());
        HashMap, ArrayList> old_lists, new_lists, old_lists_2, new_lists_2;
        
        old_lists = new HashMap, ArrayList>(nodes.size());
        new_lists = new HashMap, ArrayList>(nodes.size());
        
        HashMap data = new HashMap();
    
        // B[v]←− Sample(V,K)×{?∞, true?} ∀v ∈ V
        // For each node, create a random neighborlist
        for (Node v : nodes) {
            neighborlists.put(v, RandomNeighborList(nodes, v));
        }

        // loop
        while (true) {
            iterations++;
            c = 0;
            
            // for v ∈ V do
            // old[v]←− all items in B[v] with a false flag
            // new[v]←− ρK items in B[v] with a true flag
            // Mark sampled items in B[v] as false;
            for (int i = 0; i < nodes.size(); i++) {
                Node v = nodes.get(i);
                old_lists.put(v, PickFalses(neighborlists.get(v)));
                new_lists.put(v, PickTruesAndMark(neighborlists.get(v)));

            }

            // old′ ←Reverse(old)
            // new′ ←Reverse(new)
            old_lists_2 = Reverse(nodes, old_lists);
            new_lists_2 = Reverse(nodes, new_lists);

            // for v ∈ V do
            for (int i = 0; i < nodes.size(); i++) {
                Node v = nodes.get(i);
                // old[v]←− old[v] ∪ Sample(old′[v], ρK)
                // new[v]←− new[v] ∪ Sample(new′[v], ρK)
                old_lists.put(v, Union(old_lists.get(v), Sample(old_lists_2.get(v), (int) (rho * k))));
                new_lists.put(v, Union(new_lists.get(v), Sample(new_lists_2.get(v), (int) (rho * k))));

                // for u1,u2 ∈ new[v], u1 < u2 do
                for (int j = 0; j < new_lists.get(v).size(); j++) {
                    Node u1 = (Node) new_lists.get(v).get(j);

                        //int u1_i = Find(u1); // position of u1 in nodes
                    for (int l = j + 1; l < new_lists.get(u1).size(); l++) {
                        Node u2 = (Node) new_lists.get(u1).get(l);
                            //int u2_i = Find(u2);

                            // l←− σ(u1,u2)
                        // c←− c+UpdateNN(B[u1], u2, l, true)
                        // c←− c+UpdateNN(B[u2], u1, l, true)
                        double s = Similarity(u1, u2);
                        c += UpdateNL(neighborlists.get(u1), u2, s);
                        c += UpdateNL(neighborlists.get(u2), u1, s);
                    }

                    // or u1 ∈ new[v], u2 ∈ old[v] do
                    for (int l = 0; l < old_lists.get(v).size(); l++) {
                        Node u2 = (Node) old_lists.get(v).get(l);

                        if (u1.equals(u2)) {
                            continue;
                        }

                        //int u2_i = Find(u2);
                        double s = Similarity(u1, u2);
                        c += UpdateNL(neighborlists.get(u1), u2, s);
                        c += UpdateNL(neighborlists.get(u2), u1, s);
                    }
                }
            }
            
            //System.out.println("C : " + c);
            if (callback != null) {
                data.put("c", c);
                data.put("computed_similarities", computed_similarities);
                data.put("computed_similarities_ratio",
                        (double) computed_similarities / (nodes.size() * (nodes.size() - 1) / 2));
                data.put("iterations", iterations);
                
                callback.call(data);
            }

            if (c <= (delta * nodes.size() * k)) {
                break;
            }
            
            if (iterations >= max_iterations) {
                break;
            }
        }
        
        return neighborlists;
    }

    protected ArrayList Union(ArrayList l1, ArrayList l2) {
        ArrayList r = new ArrayList();
        for (Node n : l1) {
            if (!r.contains(n)) {
                r.add(n);
            }
        }
        
        for (Node n : l2) {
            if (!r.contains(n)) {
                r.add(n);
            }
        }

        return r;
    }

    protected NeighborList RandomNeighborList(List> nodes, Node for_node) {
        //System.out.println("Random NL for node " + for_node);
        NeighborList nl = new NeighborList(k);
        Random r = new Random();

        while (nl.size() < k) {
            Node node = nodes.get(r.nextInt(nodes.size()));
            if (! node.equals(for_node)) {
                double s = Similarity(node, for_node);
                nl.add(new Neighbor(node, s));
            }
        }
        
        return nl;
    }

    protected ArrayList PickFalses(NeighborList neighborList) {
        ArrayList falses = new ArrayList();
        for (Neighbor n : neighborList) {
            if (!n.is_new) {
                falses.add(n.node);
            }
        }

        return falses;
    }
    
    /**
     * pick new neighbors with a probability of rho, and mark them as false
     *
     * @param neighborList
     * @return
     */
    protected ArrayList PickTruesAndMark(NeighborList neighborList) {
        ArrayList r = new ArrayList();
        for (Neighbor n : neighborList) {
            if (n.is_new && Math.random() < rho) {
                n.is_new = false;
                r.add(n.node);
            }
        }

        return r;
    }


    protected HashMap, ArrayList> Reverse(List> nodes, HashMap, ArrayList> lists) {

        HashMap, ArrayList> R = new HashMap, ArrayList>(nodes.size());
        
        // Create all arraylists
        for (Node n : nodes) {
            R.put(n, new ArrayList());
        }

        // For each node and corresponding arraylist
        for (Node node : nodes) {
            ArrayList list = lists.get(node);
            for (Node other_node : list) {
                R.get(other_node).add(node);
            }
        }

        return R;
    }

    /**
     * Reverse NN array R[v] is the list of elements (u) for which v is a
     * neighbor (v is in B[u])
     *
     * @param nodes
     * @param count
     * @return 
     */

    protected ArrayList Sample(ArrayList nodes, int count) {
        Random r = new Random();
        while (nodes.size() > count) {
            nodes.remove(r.nextInt(nodes.size()));
        }

        return nodes;

    }

    protected int UpdateNL(NeighborList nl, Node n, double similarity) {
        Neighbor neighbor = new Neighbor(n, similarity);
        return nl.add(neighbor) ? 1 : 0;
    }


    protected double Similarity(Node n1, Node n2) {
        computed_similarities++;
        return similarity.similarity((T) n1.value, (T) n2.value);
        
    }

    protected Graph MakeFullyLinked(List> nodes) {
        Graph neighborlists = new Graph(nodes.size());
        for (Node node : nodes) {
            NeighborList neighborlist = new NeighborList(k);
            for (Node other_node : nodes) {
                if (node.equals(other_node)) {
                    continue;
                }
                
                neighborlist.add(new Neighbor(
                        other_node,
                        Similarity(node, other_node)
                ));
            }
            neighborlists.put(node, neighborlist);
        }
        
        return neighborlists;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy