// All Downloads are FREE. Search and download functionalities are using the official Maven repository.

// info.debatty.java.graphs.Graph Maven / Gradle / Ivy

// Go to download

// Algorithms that build k-nearest neighbors graph (k-nn graph): Brute-force, NN-Descent,...

// There is a newer version: 0.41
// Show newest version
/*
 * The MIT License
 *
 * Copyright 2015 Thibault Debatty.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

package info.debatty.java.graphs;

import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.Stack;

/**
 * k-nn graph, represented as a mapping node => neighborlist
 * @author Thibault Debatty
 * @param  The type of nodes value
 */
public class Graph extends HashMap, NeighborList> {
            
    public Graph(int n) {
        super(n);
    }
    
    public Graph() {
        super();
    }
    
    /**
     * Get the neighborlist of this node
     * @param node
     * @return the neighborlist of this node 
     */
    public NeighborList get(Node node) {
        return super.get(node);
    }
    
    /**
     * Remove from the graph all edges with a similarity lower than threshold
     * @param threshold 
     */
    public void prune(double threshold) {
        for (NeighborList nl : this.values()) {
            
            // We cannot remove inside the loop
            // => do it in 2 steps:
            ArrayList to_remove = new ArrayList();
            for (Neighbor n : nl) {
                if (n.similarity < threshold) {
                    to_remove.add(n);
                }
            }
            
            nl.removeAll(to_remove);
        }
    }
    
    /**
     * Split the graph in connected components (usually you will first prune the
     * graph to remove "weak" edges).
     * @return 
     */
    public ArrayList> connectedComponents() {
        ArrayList> subgraphs = new ArrayList>();
        ArrayList> nodes_to_process = new ArrayList>(this.keySet());
        
        for (int i = 0; i < nodes_to_process.size(); i++) {
            Node n = nodes_to_process.get(i);
            if (n == null) {
                continue;
            }
            Graph subgraph = new Graph();
            subgraphs.add(subgraph);
            
            addAndFollow(subgraph, n, nodes_to_process);
        }
        
        return subgraphs;
    }
    
    private void addAndFollow(Graph subgraph, Node node, ArrayList> nodes_to_process) {
        nodes_to_process.remove(node);
        
        NeighborList neighborlist = this.get(node);
        subgraph.put(node, neighborlist);
        
        if (neighborlist == null) {
            return;
        }
        
        for (Neighbor neighbor : this.get(node)) {
            if (! subgraph.containsKey(neighbor.node)) {
                addAndFollow(subgraph, neighbor.node, nodes_to_process);
            }
        }
    }
    
    /**
     * Computes the strongly connected sub-graphs (where every node is reachable 
     * from every other node) using Tarjan's algorithm, which has computation
     * cost O(n).
     * @return 
     */
    public ArrayList> stronglyConnectedComponents() {
        Stack stack = new Stack();
        Index index = new Index();
        HashMap bookkeeping = new HashMap(this.size());
        
        ArrayList> connected_components = new ArrayList>();
        
        for (Node n : this.keySet()) {
            
            if (bookkeeping.containsKey(n)) {
                // This node was already processed...
                continue;
            }
            
            ArrayList connected_component = this.strongConnect(n, stack, index, bookkeeping);
            
            if (connected_component == null) {
                continue;
            }
            
            // We found a connected component
            Graph subgraph = new Graph(connected_component.size());
            for (Node node : connected_component) {
                subgraph.put(node, this.get(node));
            }
            connected_components.add(subgraph);
            
        }
        
        return connected_components;
    }

    private ArrayList strongConnect(Node v, Stack stack, Index index, HashMap bookkeeping) {
        bookkeeping.put(v, new NodeProperty(index.Value(), index.Value()));
        index.Inc();
        stack.add(v);
        
        
        for (Neighbor neighbor : this.get(v)) {
            Node w = neighbor.node;
            
            if (! this.containsKey(w) || this.get(w) == null) {
                continue;
            }
            
            
            if (! bookkeeping.containsKey(w)) {
                strongConnect(w, stack, index, bookkeeping);
                bookkeeping.get(v).lowlink = Math.min(
                        bookkeeping.get(v).lowlink,
                        bookkeeping.get(w).lowlink);
                
            } else if(bookkeeping.get(neighbor.node).onstack) {
                bookkeeping.get(v).lowlink = Math.min(
                        bookkeeping.get(v).lowlink,
                        bookkeeping.get(w).index);
                
            }
        }
        
        if (bookkeeping.get(v).lowlink == bookkeeping.get(v).index) {
            ArrayList connected_component = new ArrayList();
            
            Node w;
            do {
                 w = stack.pop();
                bookkeeping.get(w).onstack = false;
                connected_component.add(w);
            } while (v != w);
            
            return connected_component;
        }
        
        return null;
    }
    
    private class Index {
        private int value;
        
        public int Value() {
            return this.value;
        }
        
        public void Inc() {
            this.value++;
        }
    }
    
    private class NodeProperty {

        public int index;
        public int lowlink;
        public boolean onstack;
        
        public NodeProperty(int index, int lowlink) {
            this.index = index;
            this.lowlink = lowlink;
            this.onstack = true;
        }
    };
    
    /**
     * Perform approximate k-nn search on this graph.
     * 
     * @param query
     * @param K search K neighbors
     * @param similarity_measure
     * @param max_similarities perform max max_similarities similarity computations
     * @return
     */
    public NeighborList search(
            T query, 
            int K,
            SimilarityInterface similarity_measure,
            int max_similarities) {
        
        return this.search(
                query,
                K,
                similarity_measure,
                max_similarities,
                100, // default depth value
                1.01); // default expansion value
                
    }
    
    public NeighborList search(
            Node query, 
            int K, 
            SimilarityInterface similarity_measure,
            int max_similarities,
            int search_depth,
            double expansion) {
        
        return search(query.value, K, similarity_measure, max_similarities, search_depth, expansion);
    }
    
    /**
     * Implementation of Graph Nearest Neighbor Search (GNNS) algorithm 
     * from paper "Fast Approximate Nearest-Neighbor Search with k-Nearest 
     * Neighbor Graph" by Hajebi et al.
     * 
     * The algorithm is basically a best-first search method with random 
     * starting points.
     * 
     * @param query query point
     * @param K number of neighbors to find (the K from K-nn search)
     * @param max_similarities max similarities to compute
     * @param search_depth number of greedy steps (default: 100)
     * @param similarity_measure similarity measure 
     * @param expansion (default: 1.01)
     * 
     * @return
     */
    public NeighborList search(
            T query, 
            int K, 
            SimilarityInterface similarity_measure,
            int max_similarities,
            int search_depth,
            double expansion) {
        
        if (K >= this.size()) {
            // Looking for more nodes than this graph contains...
            NeighborList nl = new NeighborList(K);
            for (Node node : this.keySet()) {
                nl.add(
                        new Neighbor(
                                node,
                                similarity_measure.similarity(
                                        query,
                                        node.value)));
            }
            return nl;
        }
        
        // Node => Similarity with query node
        HashMap, Double> visited_nodes = new HashMap, Double>();
        int computed_similarities = 0;
        double global_highest_similarity = 0;
        ArrayList> nodes = new ArrayList>(this.keySet());
        Random rand = new Random();
        
        while (computed_similarities < max_similarities) {
            
            // Select a random node from the graph
            Node current_node = nodes.get(rand.nextInt(nodes.size()));
            
            if (visited_nodes.containsKey(current_node)) {
                continue;
            }
            
            // Skip this starting point if too far / similarity too small!
            double start_similarity = similarity_measure.similarity(
                    query,
                    current_node.value);
            computed_similarities++;
            if (start_similarity < global_highest_similarity / expansion) {
                continue;
            }
            
            for (int step = 0; step < search_depth; step++) {
                NeighborList nl = this.get(current_node);
                
                // Node has no neighbor, continue...
                if (nl == null) {
                    continue;
                }
                
                Iterator Y_nl_iterator = nl.iterator();
                Node most_similar_node = null;
                double highest_similarity = -1;
                
                // From current_node, check all neighbors
                while (Y_nl_iterator.hasNext()) {
                    
                    Node other_node = Y_nl_iterator.next().node;
                    
                    if (visited_nodes.containsKey(other_node)) {
                        continue;
                    }
                    
                    // Compute similarity to query
                    double similarity = similarity_measure.similarity(
                            query,
                            other_node.value);
                    computed_similarities++;
                    visited_nodes.put(other_node, similarity);
                    
                    // Keep the most similar neighbor to the query
                    if (similarity > highest_similarity) {
                        most_similar_node = other_node;
                        highest_similarity = similarity;
                        
                        if (similarity > global_highest_similarity) {
                            global_highest_similarity = similarity;
                        }
                    }
                }
                
                current_node = most_similar_node;
            }
        }
        
        NeighborList neighborList = new NeighborList(K);
        for (Map.Entry, Double> entry : visited_nodes.entrySet()) {
            neighborList.add(new Neighbor(entry.getKey(), entry.getValue()));
        }
        return neighborList;
    }
    
    /**
     * Writes the graph as a GEXF file (to be used in Gephi, for example)
     * @param filename
     * @throws FileNotFoundException
     * @throws IOException 
     */
    public void writeGEXF(String filename) throws FileNotFoundException, IOException {
        Writer out = new OutputStreamWriter(new FileOutputStream(filename));
        out.write(this.gexf_header());
        
        // Write nodes
        out.write("\n");
        for (Node node : this.keySet()) {
            out.write("\n");
        }
        out.write("\n");
            
        // Write edges
        out.write("\n");
        int i = 0;
        for (Node source : this.keySet()) {
            for (Neighbor target : this.get(source)) {
                out.write("\n");
                i++;
            }
        }
            
        out.write("");
                    
        // End the file
        out.write("\n" +
                "");
        out.close();
    }
    
    private String gexf_header() {
        return "\n" +
            "\n" +
            "\n" +
            "info.debatty.java.graphs.Graph\n" +
            "\n" +
            "\n" +
            "\n";
    }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy