// org.openimaj.knn.pq.IncrementalIntADCNearestNeighbours (Maven / Gradle / Ivy)
/*
AUTOMATICALLY GENERATED BY jTemp FROM
/Users/jsh2/Work/openimaj/target/checkout/machine-learning/nearest-neighbour/src/main/jtemp/org/openimaj/knn/pq/Incremental#T#ADCNearestNeighbours.jtemp
*/
/**
* Copyright (c) 2011, The University of Southampton and the individual contributors.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* * Neither the name of the University of Southampton nor the names of its
* contributors may be used to endorse or promote products derived from this
* software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
* ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package org.openimaj.knn.pq;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.openimaj.citation.annotation.Reference;
import org.openimaj.citation.annotation.ReferenceType;
import org.openimaj.data.DataSource;
import org.openimaj.io.IOUtils;
import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.knn.IntNearestNeighbours;
import org.openimaj.knn.IncrementalNearestNeighbours;
import org.openimaj.util.pair.IntFloatPair;
import org.openimaj.util.queue.BoundedPriorityQueue;
/**
* Incremental nearest-neighbour search using Asymmetric Distance Computation
* (ADC) on Product Quantised vectors. In ADC, only the database points are
* quantised; the queries themselves are not. The overall distance is computed
* as the sum of the distances from each sub-vector of the query to the
* corresponding centroid of each database vector. Because the index is
* incremental, further vectors can be added to it after construction.
*
* For efficiency, the distances from each sub-vector of a query to every
* centroid of the corresponding sub-quantiser are computed only once per
* query. They are then looked up (rather than recomputed) when computing the
* distance to each database vector.
*
* @author Jonathon Hare ([email protected])
*/
@Reference(
type = ReferenceType.Article,
author = { "Jegou, Herve", "Douze, Matthijs", "Schmid, Cordelia" },
title = "Product Quantization for Nearest Neighbor Search",
year = "2011",
journal = "IEEE Trans. Pattern Anal. Mach. Intell.",
pages = { "117", "", "128" },
url = "http://dx.doi.org/10.1109/TPAMI.2010.57",
month = "January",
number = "1",
publisher = "IEEE Computer Society",
volume = "33",
customData = {
"issn", "0162-8828",
"numpages", "12",
"doi", "10.1109/TPAMI.2010.57",
"acmid", "1916695",
"address", "Washington, DC, USA",
"keywords", "High-dimensional indexing, High-dimensional indexing, image indexing, very large databases, approximate search., approximate search., image indexing, very large databases"
})
public class IncrementalIntADCNearestNeighbours
extends
IntNearestNeighbours
implements
IncrementalNearestNeighbours,
ReadWriteableBinary
{
/** The product quantiser used to encode every vector stored in this index. */
protected IntProductQuantiser pq;
/** The dimensionality of the (unquantised) data vectors. */
protected int ndims;
/**
 * The quantised database vectors (the result of {@code pq.quantise} for each
 * added vector), in insertion order. An item's position in this list is the
 * index returned when it was added.
 * NOTE(review): the element type appears to have been stripped from this copy;
 * it is presumably byte[] (computeDistances adds 128 to each element) - confirm.
 */
protected List data;
/**
 * Create an uninitialised instance for deserialisation. The quantiser,
 * dimensionality and stored data are all left unset until readBinary
 * populates them.
 */
protected IncrementalIntADCNearestNeighbours() {
	super();
}
/**
 * Construct the ADC with the given quantiser and data points. Each data point
 * is quantised with {@code pq} and stored. The dimensionality of the index is
 * taken from the first data point.
 *
 * @param pq
 *            the Product Quantiser
 * @param dataPoints
 *            the data points to index; must contain at least one point
 * @throws IllegalArgumentException
 *             if {@code dataPoints} is empty, because the dimensionality
 *             cannot then be determined (use
 *             {@link #IncrementalIntADCNearestNeighbours(IntProductQuantiser, int)}
 *             to create an empty index)
 */
public IncrementalIntADCNearestNeighbours(IntProductQuantiser pq, int[][] dataPoints) {
	// Fail with a clear message instead of an opaque index exception from dataPoints[0].
	if (dataPoints.length == 0)
		throw new IllegalArgumentException(
				"Cannot infer dimensionality from an empty set of data points; use the (pq, ndims) constructor instead");

	this.pq = pq;
	this.ndims = dataPoints[0].length;
	this.data = new ArrayList(dataPoints.length);
	for (final int[] point : dataPoints) {
		data.add(pq.quantise(point));
	}
}
/**
 * Construct the ADC with the given quantiser and data points. Each data point
 * is quantised with {@code pq} and stored. The dimensionality of the index is
 * taken from the first data point.
 * <p>
 * Points are fetched with {@code get(i)}, so a random-access list (such as an
 * {@code ArrayList}) is preferable for large inputs.
 *
 * @param pq
 *            the Product Quantiser
 * @param dataPoints
 *            the data points to index; must contain at least one point
 * @throws IllegalArgumentException
 *             if {@code dataPoints} is empty, because the dimensionality
 *             cannot then be determined (use
 *             {@link #IncrementalIntADCNearestNeighbours(IntProductQuantiser, int)}
 *             to create an empty index)
 */
public IncrementalIntADCNearestNeighbours(IntProductQuantiser pq, List dataPoints) {
	// Fail with a clear message instead of an opaque index exception from get(0).
	if (dataPoints.isEmpty())
		throw new IllegalArgumentException(
				"Cannot infer dimensionality from an empty set of data points; use the (pq, ndims) constructor instead");

	this.pq = pq;
	this.ndims = dataPoints.get(0).length;
	final int size = dataPoints.size();
	this.data = new ArrayList(size);
	for (int i = 0; i < size; i++) {
		data.add(pq.quantise(dataPoints.get(i)));
	}
}
/**
 * Construct the ADC with the given quantiser and data points. Each data point
 * is quantised with {@code pq} and stored. The dimensionality of the index is
 * taken from the first data point.
 *
 * @param pq
 *            the Product Quantiser
 * @param dataPoints
 *            the data points to index; must contain at least one point
 * @throws IllegalArgumentException
 *             if {@code dataPoints} is empty, because the dimensionality
 *             cannot then be determined (use
 *             {@link #IncrementalIntADCNearestNeighbours(IntProductQuantiser, int)}
 *             to create an empty index)
 */
public IncrementalIntADCNearestNeighbours(IntProductQuantiser pq, DataSource dataPoints) {
	final int size = dataPoints.size();
	// Fail with a clear message instead of whatever getData(0) does on an empty source.
	if (size == 0)
		throw new IllegalArgumentException(
				"Cannot infer dimensionality from an empty set of data points; use the (pq, ndims) constructor instead");

	this.pq = pq;
	this.ndims = dataPoints.getData(0).length;
	this.data = new ArrayList(size);
	for (int i = 0; i < size; i++) {
		data.add(pq.quantise(dataPoints.getData(i)));
	}
}
/**
 * Create an index that initially holds no data. Vectors can be added later
 * with {@link #add(int[])} or {@link #addAll(List)}.
 *
 * @param pq
 *            the Product Quantiser used to encode added vectors
 * @param ndims
 *            the dimensionality of the vectors that will be indexed
 */
public IncrementalIntADCNearestNeighbours(IntProductQuantiser pq, int ndims) {
	this.data = new ArrayList();
	this.ndims = ndims;
	this.pq = pq;
}
/**
 * Create an index that initially holds no data, with its backing list
 * pre-sized for an expected number of items.
 *
 * @param pq
 *            the Product Quantiser used to encode added vectors
 * @param ndims
 *            the dimensionality of the vectors that will be indexed
 * @param nitems
 *            the expected number of data items (used as the initial list
 *            capacity)
 * @throws IllegalArgumentException
 *             if {@code nitems} is negative (raised by {@code ArrayList})
 */
public IncrementalIntADCNearestNeighbours(IntProductQuantiser pq, int ndims, int nitems) {
	this.data = new ArrayList(nitems);
	this.ndims = ndims;
	this.pq = pq;
}
/**
 * {@inheritDoc}
 * <p>
 * Vectors are added one at a time, in list order, via {@link #add(int[])}.
 * The i-th returned index belongs to the i-th vector of {@code d}.
 */
@Override
public int[] addAll(List d) {
	final int count = d.size();
	final int[] assigned = new int[count];
	int i = 0;
	while (i < count) {
		assigned[i] = add(d.get(i));
		i++;
	}
	return assigned;
}
/**
 * {@inheritDoc}
 * <p>
 * Only the product-quantised code of {@code o} is stored. The returned index
 * is the code's position in the index, i.e. the number of items that were
 * already present.
 */
@Override
public int add(int[] o) {
	this.data.add(this.pq.quantise(o));
	return this.data.size() - 1;
}
/**
 * {@inheritDoc}
 * <p>
 * This is the dimensionality of the unquantised vectors, as given at
 * construction or read by readBinary.
 */
@Override
public int numDimensions() {
	return this.ndims;
}
/**
 * {@inheritDoc}
 * <p>
 * Counts the quantised vectors currently stored in the index.
 */
@Override
public int size() {
	final int count = this.data.size();
	return count;
}
// NOTE(review): this region is damaged in this copy of the file and does not
// compile as-is. The text between the '<' of readBinary's loop condition and a
// later '>' appears to have been stripped (as HTML tags would be). This fuses
// the start of readBinary onto the tail of a different method. Judging by
// what survives, the lost text likely included:
//   - the rest of readBinary's loop, presumably reading `size` codes of
//     length `dim` into `data`;
//   - writeBinary(DataOutput) - DataOutput is imported but otherwise unused,
//     and this class implements ReadWriteableBinary - plus any other members
//     that interface requires;
//   - the signature and first lines of a searchNN over an array of queries
//     (the surviving tail reads qus[n], N, indices[n] and distances[n]).
// Restore this region from the jTemp template named in the file header
// instead of reconstructing it by hand.
/**
 * Read this index from binary form: the product quantiser, then the
 * dimensionality, then the number of stored items, followed by the stored
 * codes (see the review note above - the code-reading loop is incomplete here).
 */
@Override
public void readBinary(DataInput in) throws IOException {
pq = IOUtils.read(in);
ndims = in.readInt();
int size = in.readInt();
int dim = pq.assigners.length;
data = new ArrayList(size);
for (int i=0; i queue =
new BoundedPriorityQueue(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
//prepare working data
List list = new ArrayList(2);
list.add(new IntFloatPair());
list.add(new IntFloatPair());
for (int n=0; n < N; ++n) {
List result = search(qus[n], queue, list);
final IntFloatPair p = result.get(0);
indices[n] = p.first;
distances[n] = p.second;
}
}
/**
 * {@inheritDoc}
 * <p>
 * If {@code K} exceeds the number of indexed items, it is reduced to that
 * number, and only that many leading entries of each row of
 * {@code indices} and {@code distances} are written.
 */
@Override
public void searchKNN(final int [][] qus, int K, int [][] indices, float [][] distances) {
	final int limit = Math.min(K, data.size());
	final int numQueries = qus.length;

	final BoundedPriorityQueue queue =
			new BoundedPriorityQueue(limit, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);

	// limit + 1 reusable pairs: the queue retains `limit` of them, and search()
	// recycles whichever pair offerItem hands back as its scratch pair.
	final List workspace = new ArrayList(limit + 1);
	int filled = 0;
	while (filled < limit + 1) {
		workspace.add(new IntFloatPair());
		filled++;
	}

	for (int q = 0; q < numQueries; q++) {
		final List nearest = search(qus[q], queue, workspace);
		for (int r = 0; r < limit; r++) {
			final IntFloatPair hit = nearest.get(r);
			indices[q][r] = hit.first;
			distances[q][r] = hit.second;
		}
	}
}
/**
 * {@inheritDoc}
 * <p>
 * Queries are answered one at a time. A single one-slot queue and two scratch
 * {@link IntFloatPair}s are reused across all of them.
 */
@Override
public void searchNN(final List qus, int [] indices, float [] distances) {
	final int numQueries = qus.size();
	final BoundedPriorityQueue queue =
			new BoundedPriorityQueue(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);

	// one pair for the queue plus one spare that search() recycles
	final List workspace = new ArrayList(2);
	workspace.add(new IntFloatPair());
	workspace.add(new IntFloatPair());

	int q = 0;
	while (q < numQueries) {
		final List ranked = search(qus.get(q), queue, workspace);
		final IntFloatPair best = ranked.get(0);
		indices[q] = best.first;
		distances[q] = best.second;
		q++;
	}
}
/**
 * {@inheritDoc}
 * <p>
 * If {@code K} exceeds the number of indexed items, it is reduced to that
 * number, and only that many leading entries of each row of
 * {@code indices} and {@code distances} are written.
 */
@Override
public void searchKNN(final List qus, int K, int [][] indices, float [][] distances) {
	final int limit = Math.min(K, data.size());
	final int numQueries = qus.size();
	final BoundedPriorityQueue queue =
			new BoundedPriorityQueue(limit, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);

	// limit + 1 reusable pairs: the queue retains `limit` of them, and search()
	// recycles whichever pair offerItem hands back as its scratch pair.
	final List workspace = new ArrayList(limit + 1);
	for (int made = 0; made < limit + 1; made++) {
		workspace.add(new IntFloatPair());
	}

	int q = 0;
	while (q < numQueries) {
		final List nearest = search(qus.get(q), queue, workspace);
		for (int r = 0; r < limit; r++) {
			final IntFloatPair hit = nearest.get(r);
			indices[q][r] = hit.first;
			distances[q][r] = hit.second;
		}
		q++;
	}
}
/**
 * {@inheritDoc}
 * <p>
 * {@code K} is capped at the number of indexed items before searching.
 */
@Override
public List searchKNN(int[] query, int K) {
	final int limit = Math.min(K, data.size());
	final BoundedPriorityQueue queue =
			new BoundedPriorityQueue(limit, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);

	// limit + 1 scratch pairs: the queue retains `limit` of them, and search()
	// recycles whichever pair offerItem hands back.
	final List workspace = new ArrayList(limit + 1);
	for (int remaining = limit + 1; remaining > 0; remaining--) {
		workspace.add(new IntFloatPair());
	}

	return search(query, queue, workspace);
}
/**
 * {@inheritDoc}
 * <p>
 * Allocates a fresh one-slot queue and two scratch pairs on every call.
 */
@Override
public IntFloatPair searchNN(final int[] query) {
	final BoundedPriorityQueue queue =
			new BoundedPriorityQueue(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);

	final List workspace = new ArrayList(2);
	for (int i = 0; i < 2; i++) {
		workspace.add(new IntFloatPair());
	}

	final List ranked = search(query, queue, workspace);
	return ranked.get(0);
}
/**
 * Run a single query against the whole index.
 * <p>
 * Each pair in {@code results} is reset to index -1 and distance
 * {@link Float#MAX_VALUE}, then offered to {@code queue}. The pair returned
 * by the last offer becomes the working pair for
 * {@link #computeDistances(int[], BoundedPriorityQueue, IntFloatPair)}.
 *
 * @param query
 *            the unquantised query vector
 * @param queue
 *            the bounded queue that collects the best results
 * @param results
 *            reusable pairs, one more than the queue's capacity
 * @return the queue contents as an ordered list, obtained with
 *         toOrderedListDestructive() (which, as its name suggests,
 *         presumably drains the queue so it can be reused)
 */
private List search(int[] query, BoundedPriorityQueue queue, List results) {
	IntFloatPair spare = null;

	final int count = results.size();
	for (int i = 0; i < count; i++) {
		final IntFloatPair pair = results.get(i);
		pair.first = -1;
		pair.second = Float.MAX_VALUE;
		spare = queue.offerItem(pair);
	}

	computeDistances(query, queue, spare);
	return queue.toOrderedListDestructive();
}
/**
 * Score every indexed item against a query and offer each score to a queue.
 * <p>
 * First a lookup table is built. For each sub-quantiser it holds the distance
 * from the matching slice of {@code fullQuery} to each of that sub-quantiser's
 * centroids. The distance to each database item is then the sum of the table
 * entries selected by the item's stored codes.
 *
 * @param fullQuery
 *            the unquantised query; consecutive slices of it are matched
 *            against consecutive sub-quantisers
 * @param queue
 *            the bounded queue that collects the best-scoring items
 * @param wp
 *            a scratch pair that is filled with each item's index and
 *            distance before being offered; the pair handed back by
 *            {@code offerItem} is used for the next item
 */
protected void computeDistances(int[] fullQuery, BoundedPriorityQueue queue, IntFloatPair wp) {
	final IntNearestNeighbours[] assigners = this.pq.assigners;
	final int numSub = assigners.length;

	// table[j][c] = distance from the j-th query slice to centroid c of assigner j
	final float[][] table = new float[numSub][];
	int offset = 0;
	for (int j = 0; j < numSub; j++) {
		final IntNearestNeighbours assigner = assigners[j];
		final int sliceLen = assigner.numDimensions();
		final int numCentroids = assigner.size();

		final int[][] slice = { Arrays.copyOfRange(fullQuery, offset, offset + sliceLen) };
		final int[][] rankedIdx = new int[1][numCentroids];
		final float[][] rankedDst = new float[1][numCentroids];
		assigner.searchKNN(slice, numCentroids, rankedIdx, rankedDst);

		// searchKNN reports neighbours by rank; re-index them by centroid id
		final float[] row = new float[numCentroids];
		for (int r = 0; r < numCentroids; r++) {
			row[rankedIdx[0][r]] = rankedDst[0][r];
		}
		table[j] = row;

		offset += sliceLen;
	}

	final int count = data.size();
	for (int i = 0; i < count; i++) {
		wp.first = i;
		wp.second = 0;
		for (int j = 0; j < numSub; j++) {
			// NOTE(review): the +128 assumes pq.quantise stores each centroid id
			// shifted into the signed byte range - confirm against IntProductQuantiser.
			wp.second += table[j][this.data.get(i)[j] + 128];
		}
		wp = queue.offerItem(wp);
	}
}
}