All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.graph.NearestNeighborGraph Maven / Gradle / Ivy

There is a newer version: 4.2.0
Show newest version
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.graph;

import java.util.*;
import java.util.stream.IntStream;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.math.distance.Metric;
import smile.neighbor.RandomProjectionTree;

/**
 * The k-nearest neighbor graph builder.
 *
 * @param k k-nearest neighbor.
 * @param neighbors The indices of k-nearest neighbors.
 * @param distances The distances to k-nearest neighbors.
 * @param index The sample index of each vertex in original dataset.
 * @author Haifeng Li
 */
public record NearestNeighborGraph(int k, int[][] neighbors, double[][] distances, int[] index) {
    private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(NearestNeighborGraph.class);

    /**
     * Constructor.
     * @param k k-nearest neighbor.
     * @param neighbors The indices of k-nearest neighbors.
     * @param distances The distances to k-nearest neighbors.
     */
    public NearestNeighborGraph(int k, int[][] neighbors, double[][] distances) {
        this(k, neighbors, distances, IntStream.range(0, neighbors.length).toArray());
    }

    /**
     * Returns the number of vertices.
     * @return the number of vertices.
     */
    public int size() {
        return neighbors.length;
    }

    /**
     * Returns the nearest neighbor graph.
     * @param digraph create a directed graph if true.
     * @return the nearest neighbor graph.
     */
    public AdjacencyList graph(boolean digraph) {
        int n = neighbors.length;
        AdjacencyList graph = new AdjacencyList(n, digraph);
        IntStream.range(0, n).forEach(i -> {
            int[] neighbor = neighbors[i];
            double[] distance = distances[i];
            for (int j = 0; j < neighbor.length; j++) {
                graph.setWeight(i, neighbor[j], distance[j]);
            }
        });
        return graph;
    }

    /**
     * Creates a nearest neighbor graph with Euclidean distance.
     *
     * @param data the dataset.
     * @param k k-nearest neighbor.
     * @return k-nearest neighbor graph.
     */
    public static NearestNeighborGraph of(double[][] data, int k) {
        return of(data, MathEx::distance, k);
    }

    /**
     * Returns the largest connected component of a nearest neighbor graph.
     *
     * @param digraph create a directed graph if true.
     * @return the largest connected component.
     */
    public NearestNeighborGraph largest(boolean digraph) {
        AdjacencyList graph = graph(digraph);
        int[][] cc = graph.bfcc();
        if (cc.length == 1) {
            return this;
        } else {
            int[] index = Arrays.stream(cc)
                    .max(Comparator.comparing(a -> a.length))
                    .orElseThrow(NoSuchElementException::new);
            logger.info("{} connected components, largest one has {} samples.", cc.length, index.length);

            int n = neighbors.length;
            int[] reverseIndex = new int[n];
            for (int i = 0; i < n; i++) {
                reverseIndex[index[i]] = i;
            }

            int[][] nearest = new int[n][k];
            double[][] dist = new double[n][k];
            for (int i = 0; i < n; i++) {
                dist[i] = distances[index[i]];
                int[] ni = neighbors[index[i]];
                for (int j = 0; j < k; j++) {
                    nearest[i][j] = reverseIndex[ni[j]];
                }
            }
            return new NearestNeighborGraph(k, nearest, dist, index);
        }
    }

    /**
     * Creates a nearest neighbor graph.
     *
     * @param data the dataset.
     * @param k k-nearest neighbor.
     * @param distance the distance function.
     * @return k-nearest neighbor graph.
     */
    public static  NearestNeighborGraph of(T[] data, Distance distance, int k) {
        var heap = build(data, distance, k, (int n, int k_, int i) -> IntStream.range(0, n).toArray());
        return toGraph(heap, k);
    }

    /**
     * Creates a random neighbor graph.
     *
     * @param data the dataset.
     * @param k k-random neighbor.
     * @param distance the distance function.
     * @return k-random neighbor graph.
     */
    public static  NearestNeighborGraph random(T[] data, Distance distance, int k) {
        var heap = build(data, distance, k, NearestNeighborGraph::rejectionSample);
        extend(heap);
        return toGraph(heap, k);
    }

    private static class Neighbor implements Comparable {
        public int index;
        public double distance;

        public Neighbor(int index, double distance) {
            this.index = index;
            this.distance = distance;
        }

        @Override
        public int hashCode() {
            return index;
        }

        @Override
        public int compareTo(Neighbor o) {
            return Double.compare(o.distance, distance);
        }
    }

    /**
     * Creates a nearest neighbor graph.
     *
     * @param data the dataset.
     * @param k k-nearest neighbor.
     * @param distance the distance function.
     * @param candidates neighbor candidate generator.
     * @return a list of k-nearest neighbor heaps for each data point.
     */
    private static  List> build(T[] data, Distance distance, int k, CandidateGenerator candidates) {
        if (k < 2) {
            throw new IllegalArgumentException("k must be greater than 1: " + k);
        }

        int n = data.length;
        List> heap = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            heap.add(new PriorityQueue<>());
        }

        IntStream.range(0, n).parallel().forEach(i -> {
            T xi = data[i];
            PriorityQueue pq = heap.get(i);
            for (int j : candidates.generate(n, k, i)) {
                if (j == i) continue;
                double dist = distance.d(xi, data[j]);
                if (pq.size() < k) {
                    pq.offer(new Neighbor(j, dist));
                } else if (dist < pq.peek().distance) {
                    Neighbor neighbor = pq.poll();
                    neighbor.index = j;
                    neighbor.distance = dist;
                    pq.offer(neighbor);
                }
            }
        });

        return heap;
    }

    /** Extends nearest neighbor heap with reverse nearest neighbors. */
    private static void extend(List> heap) {
        int n = heap.size();
        List> neighbors = new ArrayList<>(n);
        List> reverseNeighbors = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            neighbors.add(new HashSet<>());
            reverseNeighbors.add(new HashSet<>());
        }

        for (int i = 0; i < n; i++) {
            Set set = neighbors.get(i);
            PriorityQueue pq = heap.get(i);
            for (var neighbor : pq) {
                set.add(neighbor.index);
                reverseNeighbors.get(neighbor.index).add(new Neighbor(i, neighbor.distance));
            }
        }

        for (int i = 0; i < n; i++) {
            Set set = neighbors.get(i);
            PriorityQueue pq = heap.get(i);
            for (var neighbor : reverseNeighbors.get(i)) {
                if (!set.contains(neighbor.index)) {
                    if (neighbor.distance < pq.peek().distance) {
                        Neighbor top = pq.poll();
                        top.index = neighbor.index;
                        top.distance = neighbor.distance;
                        pq.offer(top);
                    }
                }
            }
        }
    }

    /** Returns a near neighbor graph with heaps. */
    private static NearestNeighborGraph toGraph(List> heap, int k) {
        int n = heap.size();
        int[][] neighbors = new int[n][k];
        double[][] distances = new double[n][k];
        for (int i = 0; i < n; i++) {
            PriorityQueue pq = heap.get(i);
            int j = pq.size();
            while (!pq.isEmpty()) {
                Neighbor neighbor = pq.poll();
                if (--j < k) {
                    neighbors[i][j] = neighbor.index;
                    distances[i][j] = neighbor.distance;
                }
            }
        }

        return new NearestNeighborGraph(k, neighbors, distances);
    }

    private interface CandidateGenerator {
        int[] generate(int n, int k, int i);
    }

    /**
     * Creates an approximate nearest neighbor graph with random projection
     * forest and Euclidean distance.
     *
     * @param data the dataset.
     * @param k k-nearest neighbor.
     * @return approximate k-nearest neighbor graph.
     */
    public static NearestNeighborGraph descent(double[][] data, int k) {
        return descent(data, k, 5, k, 50, 50, 0.001);
    }

    /**
     * Creates an approximate nearest neighbor graph with random projection
     * forest and Euclidean distance.
     *
     * @param data the dataset.
     * @param k k-nearest neighbor.
     * @param numTrees the number of trees.
     * @param leafSize The maximum size of leaf node.
     * @return approximate k-nearest neighbor graph.
     */
    public static NearestNeighborGraph descent(double[][] data, int k, int numTrees, int leafSize,
                                               int maxCandidates, int maxIter, double delta) {
        int n = data.length;
        List> heapList = new ArrayList<>(data.length);
        List> neighborSetList = new ArrayList<>(data.length);
        for (int i = 0; i < data.length; i++) {
            heapList.add(new PriorityQueue<>());
            neighborSetList.add(new HashSet<>());
        }

        for (int ti = 0; ti < numTrees; ti++) {
            RandomProjectionTree tree = RandomProjectionTree.of(data, leafSize, false);
            for (int[] leaf : tree.leafSamples()) {
                for (int li = 0; li < leaf.length; li++) {
                    int i = leaf[li];
                    double[] xi = data[i];
                    for (int lj = li + 1; lj < leaf.length; lj++) {
                        int j = leaf[lj];
                        double[] xj = data[j];
                        double dist = MathEx.distance(xi, xj);

                        updateHeap(heapList.get(i), neighborSetList.get(i), k, j, dist);
                        updateHeap(heapList.get(j), neighborSetList.get(j), k, i, dist);
                    }
                }
            }
        }

        return descent(data, MathEx::distance, heapList, k, maxCandidates, maxIter, delta);
    }

    private static boolean updateHeap(PriorityQueue pq, Set set, int k, int index, double dist) {
        if (!set.contains(index)) {
            if (pq.size() < k) {
                pq.add(new Neighbor(index, dist));
                set.add(index);
                return true;
            } else {
                if (dist < pq.peek().distance) {
                    var top = pq.poll();
                    set.remove(top.index);
                    set.add(index);
                    top.distance = dist;
                    top.index = index;
                    pq.offer(top);
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Creates an approximate nearest neighbor graph with the NN-Descent algorithm.
     *
     * @param data the dataset.
     * @param k k-nearest neighbor.
     * @param distance the distance function.
     * @return approximate k-nearest neighbor graph.
     */
    public static  NearestNeighborGraph descent(T[] data, Metric distance, int k) {
        return descent(data, distance, k, 50, 10, 0.001);
    }

    /**
     * Creates an approximate nearest neighbor graph with the NN-Descent algorithm.
     *
     * @param data the dataset.
     * @param k k-nearest neighbor.
     * @param distance the distance function.
     * @param maxCandidates the maximum number of candidates in nearest neighbor search.
     * @param maxIter the maximum number of iterations.
     * @param delta Controls the early stop due to limited progress. Larger values
     *         will result in earlier aborts, providing less accurate indexes,
     *         and less accurate searching.
     * @return approximate k-nearest neighbor graph.
     */
    public static  NearestNeighborGraph descent(T[] data, Metric distance, int k, int maxCandidates, int maxIter, double delta) {
        if (k < 2) {
            throw new IllegalArgumentException("k must be greater than 1: " + k);
        }
        var heap = build(data, distance, k, NearestNeighborGraph::rejectionSample);
        extend(heap);
        return descent(data, distance, heap, k, maxCandidates, maxIter, delta);
    }

    private static  NearestNeighborGraph descent(T[] data, Metric distance, List> heapList,
                                                    int k, int maxCandidates, int maxIter, double delta) {
        int n = data.length;
        List> neighborSetList = new ArrayList<>(data.length);
        for (int i = 0; i < data.length; i++) {
            neighborSetList.add(new HashSet<>());
        }
        for (int i = 0; i < n; i++) {
            var set = neighborSetList.get(i);
            for (var neighbor : heapList.get(i)) {
                set.add(neighbor.index);
            }
        }

        for (int iter = 1; iter <= maxIter; iter++) {
            int count = 0;
            var candidates = generateCandidates(heapList, maxCandidates);
            for (int i = 0; i < n; i++) {
                for (var j : candidates[i]) {
                    double dist = distance.d(data[i], data[j]);
                    if (updateHeap(heapList.get(i), neighborSetList.get(i), k, j, dist)) {
                        ++count;
                    }

                    if (updateHeap(heapList.get(j), neighborSetList.get(j), k, i, dist)) {
                        ++count;
                    }
                }
            }

            logger.info("NearestNeighborDescent iteration {}: {}", iter, count);
            if (count <= delta * k * n) {
                break;
            }
        }

        return toGraph(heapList, k);
    }

    private static int[][] generateCandidates(List> heapList, int maxCandidates) {
        int n = heapList.size();
        List> candidates = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            candidates.add(new HashSet<>());
        }

        for (int i = 0; i < n; i++) {
            var pqi = heapList.get(i);
            for (var ni : pqi) {
                int j = ni.index;
                double dij = ni.distance;

                var pqj = heapList.get(j);
                for (var nj : pqj) {
                    int k = nj.index;
                    double djk = nj.distance;
                    candidates.get(i).add(new Neighbor(k, dij + djk));
                    candidates.get(k).add(new Neighbor(i, dij + djk));
                }
            }
        }

        int[][] result = new int[n][];
        for (int i = 0; i < n; i++) {
            List list = new ArrayList<>(candidates.get(i));
            list.sort(Comparator.comparingDouble(o -> o.distance));
            result[i] = list.stream().limit(maxCandidates).mapToInt(neighbor -> neighbor.index).toArray();
        }
        return result;
    }

    /**
     * Generate k integers from 0 to n such that no integer is selected twice.
     * @param n The upper bound of samples.
     * @param k The number of random samples.
     * @param i samples should not equal i.
     * @return random samples.
     */
    private static int[] rejectionSample(int n, int k, int i) {
        if (k > n) {
            throw new IllegalArgumentException();
        }

        int[] samples = new int[k];
        for (int j = 0; j < k; j++) {
            boolean loop = true;
            while (loop) {
                loop = false;
                samples[j] = MathEx.randomInt(n);
                if (samples[j] == i) {
                    loop = true;
                } else {
                    for (int l = 0; l < j; l++) {
                        if (samples[j] == samples[l]) {
                            loop = true;
                            break;
                        }
                    }
                }
            }
        }

        return samples;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy