// org.openimaj.knn.DoubleNearestNeighboursExact (Maven / Gradle / Ivy)
/*
AUTOMATICALLY GENERATED BY jTemp FROM
/Users/jsh2/Work/openimaj/target/checkout/machine-learning/nearest-neighbour/src/main/jtemp/org/openimaj/knn/#T#NearestNeighboursExact.jtemp
*/
/**
* Copyright (c) 2011, The University of Southampton and the individual contributors.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* * Neither the name of the University of Southampton nor the names of its
* contributors may be used to endorse or promote products derived from this
* software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
* ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package org.openimaj.knn;
import java.util.ArrayList;
import java.util.List;
import org.openimaj.feature.DoubleFVComparison;
import org.openimaj.feature.DoubleFVComparator;
import org.openimaj.util.pair.IntDoublePair;
import org.openimaj.util.queue.BoundedPriorityQueue;
/**
* Exact (brute-force) k-nearest-neighbour implementation.
*
* @author Jonathon Hare ([email protected])
* @author Sina Samangooei ([email protected])
*/
public class DoubleNearestNeighboursExact extends DoubleNearestNeighbours {
/**
* {@link NearestNeighboursFactory} for producing
* {@link DoubleNearestNeighboursExact}s.
*
* @author Jonathon Hare ([email protected])
*/
public static final class Factory implements NearestNeighboursFactory {
private final DoubleFVComparator distance;
/**
* Construct the factory using Euclidean distance for the
* produced DoubleNearestNeighbours instances.
*/
public Factory() {
this.distance = null;
}
/**
* Construct the factory with the given distance function
* for the produced DoubleNearestNeighbours instances.
*
* @param distance
* the distance function
*/
public Factory(DoubleFVComparator distance) {
this.distance = distance;
}
@Override
public DoubleNearestNeighboursExact create(double[][] data) {
return new DoubleNearestNeighboursExact(data, distance);
}
}
protected final double[][] pnts;
protected final DoubleFVComparator distance;
/**
* Construct the DoubleNearestNeighboursExact over the provided
* dataset and using Euclidean distance.
* @param pnts the dataset
*/
public DoubleNearestNeighboursExact(final double [][] pnts) {
this(pnts, null);
}
/**
* Construct the DoubleNearestNeighboursExact over the provided
* dataset with the given distance function.
*
* Note: If the distance function provides similarities rather
* than distances they are automatically inverted.
*
* @param pnts the dataset
* @param distance the distance function
*/
public DoubleNearestNeighboursExact(final double [][] pnts, final DoubleFVComparator distance) {
this.pnts = pnts;
this.distance = distance;
}
@Override
public void searchNN(final double [][] qus, int [] indices, double [] distances) {
final int N = qus.length;
final BoundedPriorityQueue queue =
new BoundedPriorityQueue(1, IntDoublePair.SECOND_ITEM_ASCENDING_COMPARATOR);
//prepare working data
List list = new ArrayList(2);
list.add(new IntDoublePair());
list.add(new IntDoublePair());
for (int n=0; n < N; ++n) {
List result = search(qus[n], queue, list);
final IntDoublePair p = result.get(0);
indices[n] = p.first;
distances[n] = p.second;
}
}
@Override
public void searchKNN(final double [][] qus, int K, int [][] indices, double [][] distances) {
// Fix for when the user asks for too many points.
K = Math.min(K, pnts.length);
final int N = qus.length;
final BoundedPriorityQueue queue =
new BoundedPriorityQueue(K, IntDoublePair.SECOND_ITEM_ASCENDING_COMPARATOR);
//prepare working data
List list = new ArrayList(K + 1);
for (int i = 0; i < K + 1; i++) {
list.add(new IntDoublePair());
}
// search on each query
for (int n = 0; n < N; ++n) {
List result = search(qus[n], queue, list);
for (int k = 0; k < K; ++k) {
final IntDoublePair p = result.get(k);
indices[n][k] = p.first;
distances[n][k] = p.second;
}
}
}
@Override
public void searchNN(final List qus, int [] indices, double [] distances) {
final int N = qus.size();
final BoundedPriorityQueue queue =
new BoundedPriorityQueue(1, IntDoublePair.SECOND_ITEM_ASCENDING_COMPARATOR);
//prepare working data
List list = new ArrayList(2);
list.add(new IntDoublePair());
list.add(new IntDoublePair());
for (int n=0; n < N; ++n) {
List result = search(qus.get(n), queue, list);
final IntDoublePair p = result.get(0);
indices[n] = p.first;
distances[n] = p.second;
}
}
@Override
public void searchKNN(final List qus, int K, int [][] indices, double [][] distances) {
// Fix for when the user asks for too many points.
K = Math.min(K, pnts.length);
final int N = qus.size();
final BoundedPriorityQueue queue =
new BoundedPriorityQueue(K, IntDoublePair.SECOND_ITEM_ASCENDING_COMPARATOR);
//prepare working data
List list = new ArrayList(K + 1);
for (int i = 0; i < K + 1; i++) {
list.add(new IntDoublePair());
}
// search on each query
for (int n = 0; n < N; ++n) {
List result = search(qus.get(n), queue, list);
for (int k = 0; k < K; ++k) {
final IntDoublePair p = result.get(k);
indices[n][k] = p.first;
distances[n][k] = p.second;
}
}
}
@Override
public List searchKNN(double[] query, int K) {
// Fix for when the user asks for too many points.
K = Math.min(K, pnts.length);
final BoundedPriorityQueue queue =
new BoundedPriorityQueue(K, IntDoublePair.SECOND_ITEM_ASCENDING_COMPARATOR);
//prepare working data
List list = new ArrayList(K + 1);
for (int i = 0; i < K + 1; i++) {
list.add(new IntDoublePair());
}
// search
return search(query, queue, list);
}
@Override
public IntDoublePair searchNN(final double[] query) {
final BoundedPriorityQueue queue =
new BoundedPriorityQueue(1, IntDoublePair.SECOND_ITEM_ASCENDING_COMPARATOR);
//prepare working data
List list = new ArrayList(2);
list.add(new IntDoublePair());
list.add(new IntDoublePair());
return search(query, queue, list).get(0);
}
private List search(double[] query, BoundedPriorityQueue queue, List results) {
IntDoublePair wp = null;
// reset all values in the queue to MAX, -1
for (final IntDoublePair p : results) {
p.second = Float.MAX_VALUE;
p.first = -1;
wp = queue.offerItem(p);
}
// perform the search
for (int i = 0; i < this.pnts.length; i++) {
wp.second = distanceFunc(distance, query, pnts[i]);
wp.first = i;
wp = queue.offerItem(wp);
}
return queue.toOrderedListDestructive();
}
@Override
public int numDimensions() {
return pnts[0].length;
}
@Override
public int size() {
return pnts.length;
}
/**
* Get the underlying data points.
*
* @return the data points
*/
public double[][] getPoints() {
return this.pnts;
}
/**
* Compute the distance between two vectors using the underlying distance
* comparison used by this class.
*
* @param a
* the first vector
* @param b
* the second vector
* @return the distance between the two vectors
*/
public double computeDistance(double[] a, double[] b) {
if (distance == null)
return (double) DoubleFVComparison.SUM_SQUARE.compare(a, b);
return (double) distance.compare(a, b);
}
/**
* Get the distance comparator
*
* @return the distance comparator
*/
public DoubleFVComparator distanceComparator() {
return this.distance;
}
}