All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.neighbor.LSH Maven / Gradle / Ivy

The newest version!
/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *   
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *  
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package smile.neighbor;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.LinkedHashSet;
import smile.math.IntArrayList;
import smile.math.Math;
import smile.sort.HeapSelect;
import smile.stat.distribution.GaussianDistribution;

/**
 * Locality-Sensitive Hashing. LSH is an efficient algorithm for
 * approximate nearest neighbor search in high dimensional spaces
 * by performing probabilistic dimension reduction of data. The basic idea
 * is to hash the input items so that similar items are mapped to the same
 * buckets with high probability (the number of buckets being much smaller
 * than the universe of possible input items).
 *
 * 

References

*
    *
  1. Alexandr Andoni and Piotr Indyk. Near-Optimal Hashing Algorithms for Near Neighbor Problem in High Dimensions. FOCS, 2006.
  2. *
  3. Alexandr Andoni, Mayur Datar, Nicole Immorlica, Piotr Indyk, and Vahab Mirrokni. Locality-Sensitive Hashing Scheme Based on p-Stable Distributions. 2004.
  4. *
* * @see MPLSH * * @param the type of data objects in the hash table. * * @author Haifeng Li */ public class LSH implements NearestNeighborSearch, KNNSearch, RNNSearch { /** * A bucket is a container for points that all have the same value for hash * function g (function g is a vector of k LSH functions). A bucket is specified by a vector in integers of length k. */ static class BucketEntry { /** * The bucket numbers given by the universal bucket hashing. * These numbers are used instead of the full k-vector (value of the hash * function g) describing the bucket. With a high probability all * buckets will have different pairs of numbers. */ int bucket; /** * The indices of points that all have the same value for hash function g. */ IntArrayList entry; /** * Constructor. * @param bucket the bucket number given by universal hashing. */ BucketEntry(int bucket) { this.bucket = bucket; entry = new IntArrayList(); } /** * Adds a point to bucket. * @param point the index of point. */ void add(int point) { entry.add(point); } /** * Removes a point from bucket. * @param point the index of point. * @return true if the point was in the bucket. */ boolean remove(int point) { int n = entry.size(); for (int i = 0; i < n; i++) { if (entry.get(i) == point) { entry.remove(i); return true; } } return false; } } /** * The entry of an universal hash table with collision solved by chaining. */ static class HashEntry { /** * The chain of buckets in the slot. */ List entry; /** * Constructor. */ HashEntry() { entry = new LinkedList(); } /** * Adds a point to given bucket. * @param bucket the bucket number. * @param point the index of point. */ void add(int bucket, int point) { for (BucketEntry b : entry) { if (b.bucket == bucket) { b.add(point); return; } } BucketEntry b = new BucketEntry(bucket); b.add(point); entry.add(b); } /** * Removes a point from given bucket. * @param bucket the bucket number. * @param point the index of point. * @return true if the point was in the bucket. */ boolean remove(int bucket, int point) { for (BucketEntry b : entry) { if (b.bucket == bucket) { return b.remove(point); } } return false; } } /** * The hash function for data in Euclidean spaces. */ class Hash { /** * The random vectors with entries chosen independently from a Gaussian * distribution. */ double[][] a; /** * Real numbers chosen uniformly from the range [0, w]. */ double[] b; /** * Hash table. */ HashEntry[] table; /** * Constructor. */ Hash() { a = new double[k][d]; b = new double[k]; GaussianDistribution gaussian = GaussianDistribution.getInstance(); for (int i = 0; i < k; i++) { for (int j = 0; j < d; j++) { a[i][j] = gaussian.rand(); } b[i] = Math.random(0, w); } table = new HashEntry[H]; } /** * Returns the hash value of given vector x. * @param x the vector to be hashed. * @param m the m-th hash function to be employed. * @return the hash value. */ int hash(double[] x, int m) { double g = b[m]; for (int j = 0; j < d; j++) { g += a[m][j] * x[j]; } int h = (int) Math.floor(g / w); if (h < 0) { h += 2147483647; } return h; } /** * Apply hash functions on given vector x. * @param r universal hashing random integers. * @param x the vector to be hashed. * @return the bucket of hash table for given vector x. */ int hash(int[] r, double[] x) { long g = 0; for (int i = 0; i < k; i++) { g += r[i] * hash(x, i); } int h = (int) (g % P); if (h < 0) { h += P; } return h; } /** * Insert an item into the hash table. */ void add(int index, double[] x) { int bucket = hash(r2, x); int i = hash(r1, x) % H; if (table[i] == null) { table[i] = new HashEntry(); } table[i].add(bucket, index); } /** * Returns the bucket entry for the given point. */ BucketEntry get(double[] x) { int bucket = hash(r2, x); int i = hash(r1, x) % H; HashEntry he = table[i]; if (he == null) { return null; } for (BucketEntry be : he.entry) { if (bucket == be.bucket) { return be; } } return null; } } /** * The prime number in universal bucket hashing. */ final int P = 2147483647; /** * The range of universal hashing random integers [0, 2^29). */ final int MAX_HASH_RND = 536870912; /** * The keys of data objects. */ ArrayList keys; /** * The data objects. */ ArrayList data; /** * Hash functions. */ List hash; /** * The size of hash table. */ int H; /** * The dimensionality of data. */ int d; /** * The number of hash tables. */ int L; /** * The number of random projections per hash value. */ int k; /** * The width of projection. The hash function is defined as floor((a * x + b) / w). The value * of w determines the bucket interval. */ double w; /** * The random integer used for universal bucket hashing. */ int[] r1; /** * The random integer used for universal hashing for control values. */ int[] r2; /** * Whether to exclude query object self from the neighborhood. */ boolean identicalExcluded = true; /** * Constructor. * @param keys the keys of data objects. * @param data the data objects. */ public LSH(double[][] keys, E[] data) { this(keys, data, 4.0); } /** * Constructor. * @param keys the keys of data objects. * @param data the data objects. * @param w the width of random projections. It should be sufficiently * away from 0. But we should not choose an w value that is too large, which * will increase the query time. */ public LSH(double[][] keys, E[] data, double w) { this(keys, data, w, keys.length); } /** * Constructor. * @param keys the keys of data objects. * @param data the data objects. * @param w the width of random projections. It should be sufficiently * away from 0. But we should not choose an w value that is too large, which * will increase the query time. * @param H the size of universal hash tables. */ public LSH(double[][] keys, E[] data, double w, int H) { this(keys[0].length, Math.max(50, (int) Math.pow(keys.length, 0.25)), Math.max(3, (int) Math.log10(keys.length)), w, H); if (keys.length != data.length) { throw new IllegalArgumentException("The array size of keys and data are different."); } if (H < keys.length) { throw new IllegalArgumentException("Hash table size is too small: " + H); } int n = keys.length; for (int i = 0; i < n; i++) { put(keys[i], data[i]); } } /** * Constructor. * @param d the dimensionality of data. * @param L the number of hash tables. * @param k the number of random projection hash functions, which is usually * set to log(N) where N is the dataset size. */ public LSH(int d, int L, int k) { this(d, L, k, 4.0); } /** * Constructor. * @param d the dimensionality of data. * @param L the number of hash tables. * @param k the number of random projection hash functions, which is usually * set to log(N) where N is the dataset size. * @param w the width of random projections. It should be sufficiently * away from 0. But we should not choose an w value that is too large, which * will increase the query time. */ public LSH(int d, int L, int k, double w) { this(d, L, k, w, 1017881); } /** * Constructor. * @param d the dimensionality of data. * @param L the number of hash tables. * @param k the number of random projection hash functions, which is usually * set to log(N) where N is the dataset size. * @param w the width of random projections. It should be sufficiently * away from 0. But we should not choose an w value that is too large, which * will increase the query time. * @param H the size of universal hash tables. */ public LSH(int d, int L, int k, double w, int H) { if (d < 2) { throw new IllegalArgumentException("Invalid input space dimension: " + d); } if (L < 1) { throw new IllegalArgumentException("Invalid number of hash tables: " + L); } if (k < 1) { throw new IllegalArgumentException("Invalid number of random projections per hash value: " + k); } if (w <= 0.0) { throw new IllegalArgumentException("Invalid width of random projections: " + w); } if (H < 1) { throw new IllegalArgumentException("Invalid size of hash tables: " + H); } this.d = d; this.L = L; this.k = k; this.w = w; this.H = H; keys = new ArrayList(); data = new ArrayList(); r1 = new int[k]; r2 = new int[k]; for (int i = 0; i < k; i++) { r1[i] = Math.randomInt(MAX_HASH_RND); r2[i] = Math.randomInt(MAX_HASH_RND); } hash = new ArrayList(L); for (int i = 0; i < L; i++) { hash.add(new Hash()); } } @Override public String toString() { return String.format("LSH (L=%d, k=%d, H=%d, w=%.4f)", hash.size(), k, H, w); } /** * Get whether if query object self be excluded from the neighborhood. */ public boolean isIdenticalExcluded() { return identicalExcluded; } /** * Set if exclude query object self from the neighborhood. */ public LSH setIdenticalExcluded(boolean excluded) { identicalExcluded = excluded; return this; } /** * Insert an item into the hash table. */ public void put(double[] key, E value) { int index = keys.size(); keys.add(key); data.add(value); for (Hash h : hash) { h.add(index, key); } } @Override public Neighbor nearest(double[] q) { Set candidates = obtainCandidates(q); Neighbor neighbor = new Neighbor(null, null, -1, Double.MAX_VALUE); for (int index : candidates) { double[] key = keys.get(index); if (q == key && identicalExcluded) { continue; } double distance = Math.distance(q, key); if (distance < neighbor.distance) { neighbor.index = index; neighbor.distance = distance; neighbor.key = key; neighbor.value = data.get(index); } } return neighbor; } @Override public Neighbor[] knn(double[] q, int k) { if (k < 1) { throw new IllegalArgumentException("Invalid k: " + k); } Set candidates = obtainCandidates(q); Neighbor neighbor = new Neighbor(null, null, 0, Double.MAX_VALUE); @SuppressWarnings("unchecked") Neighbor[] neighbors = (Neighbor[]) java.lang.reflect.Array.newInstance(neighbor.getClass(), k); HeapSelect> heap = new HeapSelect>(neighbors); for (int i = 0; i < k; i++) { heap.add(neighbor); } int hit = 0; for (int index : candidates) { double[] key = keys.get(index); if (q == key && identicalExcluded) { continue; } double distance = Math.distance(q, key); if (distance < heap.peek().distance) { heap.add(new Neighbor(key, data.get(index), index, distance)); hit++; } } heap.sort(); if (hit < k) { @SuppressWarnings("unchecked") Neighbor[] n2 = (Neighbor[]) java.lang.reflect.Array.newInstance(neighbor.getClass(), hit); int start = k - hit; for (int i = 0; i < hit; i++) { n2[i] = neighbors[i + start]; } neighbors = n2; } return neighbors; } @Override public void range(double[] q, double radius, List> neighbors) { if (radius <= 0.0) { throw new IllegalArgumentException("Invalid radius: " + radius); } Set candidates = obtainCandidates(q); for (int index : candidates) { double[] key = keys.get(index); if (q == key && identicalExcluded) { continue; } double distance = Math.distance(q, key); if (distance <= radius) { neighbors.add(new Neighbor(key, data.get(index), index, distance)); } } } /** * Obtaining Candidates * @return Indices of Candidates */ private Set obtainCandidates(double[] q) { Set candidates = new LinkedHashSet(); for (Hash h : hash) { BucketEntry bucket = h.get(q); if (bucket != null) { int m = bucket.entry.size(); for (int i = 0; i < m; i++) { int index = bucket.entry.get(i); candidates.add(index); } } } return candidates; } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy