/*
AUTOMATICALLY GENERATED BY jTemp FROM
/Users/jsh2/Work/openimaj/target/checkout/machine-learning/nearest-neighbour/src/main/jtemp/org/openimaj/knn/approximate/#T#KDTreeEnsemble.jtemp
*/
/**
* Copyright (c) 2011, The University of Southampton and the individual contributors.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* * Neither the name of the University of Southampton nor the names of its
* contributors may be used to endorse or promote products derived from this
* software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
* ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package org.openimaj.knn.approximate;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import cern.jet.random.Uniform;
import cern.jet.random.engine.MersenneTwister;
import org.openimaj.knn.DoubleNearestNeighbours;
import org.openimaj.util.array.IntArrayView;
import org.openimaj.util.pair.*;
import jal.objects.BinaryPredicate;
import jal.objects.Sorting;
/**
* Ensemble of Best-Bin-First KDTrees for double data.
*
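 * <p>
 * A minimal usage sketch (the data values and parameters below are purely
 * illustrative):
 * </p>
 * <pre>{@code
 * double[][] data = {
 *     {0.0, 0.0}, {1.0, 1.0}, {2.0, 2.0}, {10.0, 10.0}
 * };
 * // Build an ensemble of 2 trees over the data.
 * DoubleKDTreeEnsemble ensemble = new DoubleKDTreeEnsemble(data, 2);
 *
 * // Find the (approximate) nearest neighbour of a query point,
 * // allowing up to data.length point checks.
 * IntDoublePair[] nn = new IntDoublePair[1];
 * ensemble.search(new double[] {1.2, 0.9}, 1, nn, data.length);
 * // nn[0].first is the index of the neighbour in data; nn[0].second is
 * // the distance value computed by DoubleNearestNeighbours.distanceFunc.
 * }</pre>
 *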
* @author Jonathon Hare ([email protected])
* @author Sina Samangooei ([email protected])
*/
public class DoubleKDTreeEnsemble {
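/*
 * Tuning constants: the maximum number of points held in a leaf node, the
 * maximum number of points sampled when estimating per-dimension variance,
 * and the number of highest-variance dimensions from which the split
 * dimension is chosen at random.
 */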
private static final int leaf_max_points = 14;
private static final int varest_max_points = 128;
private static final int varest_max_randsz = 5;
Uniform rng;
/**
* An internal node of the KDTree
*/
public static class DoubleKDTreeNode {
class NodeData {}
class InternalNodeData extends NodeData {
DoubleKDTreeNode right;
double disc;
int disc_dim;
}
class LeafNodeData extends NodeData {
int [] indices;
}
/**
* left == null iff this node is a leaf.
*/
DoubleKDTreeNode left;
NodeData node_data;
private Uniform rng;
boolean is_leaf() {
return left==null;
}
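/**
 * Choose the split for this node: estimate the per-dimension mean and
 * variance from at most varest_max_points of the points, then pick the
 * split dimension at random from the varest_max_randsz highest-variance
 * dimensions. The returned pair holds the chosen dimension and the mean
 * of the points along it (the discriminator value).
 */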
IntDoublePair choose_split(final double [][] pnts, final IntArrayView inds) {
int D = pnts[0].length;
// Find mean & variance of each dimension.
double [] sum_x = new double[D];
double [] sum_xx = new double[D];
int count = Math.min(inds.size(), varest_max_points);
for (int n=0; n<count; ++n) {
	int i = inds.getFast(n);
	for (int d=0; d<D; ++d) {
		sum_x[d] += pnts[i][d];
		sum_xx[d] += pnts[i][d]*pnts[i][d];
	}
}
// Estimate the variance of each dimension
// (first = variance, second = dimension index).
DoubleIntPair [] var_dim = new DoubleIntPair[D];
for (int d=0; d<D; ++d) {
	var_dim[d] = new DoubleIntPair();
	var_dim[d].second = d;
	if (count <= 1)
		var_dim[d].first = 0;
	else
		var_dim[d].first = (sum_xx[d] - ((double)1/count)*sum_x[d]*sum_x[d]) / (count - 1);
}
// Partially sort so the highest-variance dimensions come first; a partial
// sort makes a big difference to the build time.
int nrand = Math.min(varest_max_randsz, D);
Sorting.partial_sort(var_dim, 0, nrand, var_dim.length, new BinaryPredicate() {
	@Override
	public boolean apply(Object arg0, Object arg1) {
		DoubleIntPair p1 = (DoubleIntPair) arg0;
		DoubleIntPair p2 = (DoubleIntPair) arg1;
		if (p1.first > p2.first) return true;
		if (p2.first > p1.first) return false;
		return (p1.second > p2.second);
}});
int randd = var_dim[rng.nextIntFromTo(0, nrand-1)].second;
return new IntDoublePair(randd, sum_x[randd]/count);
}
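/**
 * Partition the point indices around the chosen discriminator (points
 * whose value in the split dimension is below the discriminator go left,
 * the rest go right) and recursively build the two children.
 */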
void split_points(final double [][] pnts, IntArrayView inds) {
IntDoublePair spl = choose_split(pnts, inds);
((InternalNodeData)node_data).disc_dim = spl.first;
((InternalNodeData)node_data).disc = spl.second;
int N = inds.size();
int l = 0;
int r = N;
while (l!=r) {
if (pnts[inds.getFast(l)][((InternalNodeData)node_data).disc_dim] < ((InternalNodeData)node_data).disc) l++;
else {
r--;
int t = inds.getFast(l);
inds.setFast(l, inds.getFast(r));
inds.setFast(r, t);
}
}
// If either partition is empty -> vectors identical!
if (l==0 || l==N) { l = N/2; } // The vectors are identical, so keep nlogn performance.
left = new DoubleKDTreeNode(pnts, inds.subView(0, l), rng);
((InternalNodeData)node_data).right = new DoubleKDTreeNode(pnts, inds.subView(l, N), rng);
}
/** Construct a new node */
public DoubleKDTreeNode() { }
/**
* Construct a new node with the given data
*
* @param pnts the data for the node and its children
* @param inds a list of indices that point to the relevant
* parts of the pnts array that should be used
* @param rng the random number generator
*/
public DoubleKDTreeNode(final double [][] pnts, IntArrayView inds, Uniform rng) {
this.rng = rng;
if (inds.size() > leaf_max_points) { // Internal node
node_data = new InternalNodeData();
split_points(pnts, inds);
}
else {
node_data = new LeafNodeData();
((LeafNodeData)node_data).indices = inds.toArray();
}
}
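/**
 * Descend from this node to a leaf, following the child on the query's
 * side of each discriminator and pushing the unexplored sibling onto the
 * shared priority queue, keyed by a lower bound on its distance from the
 * query. At the leaf, compute distances to any points not yet seen and
 * add them to the candidate list.
 */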
void search(final double [] qu, PriorityQueue<DoubleObjectPair<DoubleKDTreeNode>> pri_branch, List<IntDoublePair> nns, boolean[] seen, double [][] pnts, double mindsq)
{
DoubleKDTreeNode cur = this;
DoubleKDTreeNode other = null;
while (!cur.is_leaf()) { // Follow best bin first until we hit a leaf
double diff = qu[((InternalNodeData)cur.node_data).disc_dim] - ((InternalNodeData)cur.node_data).disc;
if (diff < 0) {
other = ((InternalNodeData)cur.node_data).right;
cur = cur.left;
}
else {
other = cur.left;
cur = ((InternalNodeData)cur.node_data).right;
}
pri_branch.add(new DoubleObjectPair<DoubleKDTreeNode>(mindsq + diff*diff, other));
}
int [] cur_inds = ((LeafNodeData)cur.node_data).indices;
int ncur_inds = cur_inds.length;
int i;
double [] dsq = new double[1];
for (i = 0; i < ncur_inds; ++i) {
int ci = cur_inds[i];
if (!seen[ci]) {
DoubleNearestNeighbours.distanceFunc(qu, new double[][] {pnts[ci]}, dsq);
nns.add(new IntDoublePair(ci, dsq[0]));
seen[ci] = true;
}
}
}
}
/** The tree roots */
public final DoubleKDTreeNode [] trees;
/** The underlying data array */
public final double [][] pnts;
/**
* Construct a DoubleKDTreeEnsemble with the provided data,
* using the default of 8 trees.
* @param pnts the data array
*/
public DoubleKDTreeEnsemble(final double [][] pnts) {
this(pnts, 8, 42);
}
/**
* Construct a DoubleKDTreeEnsemble with the provided data and
* number of trees.
* @param pnts the data array
* @param ntrees the number of KDTrees in the ensemble
*/
public DoubleKDTreeEnsemble(final double [][] pnts, int ntrees) {
this(pnts, ntrees, 42);
}
/**
* Construct a DoubleKDTreeEnsemble with the provided data and
* number of trees.
* @param pnts the data array
* @param ntrees the number of KDTrees in the ensemble
* @param seed the seed for the random number generator used in
* tree construction
*/
public DoubleKDTreeEnsemble(final double [][] pnts, int ntrees, int seed) {
final int N = pnts.length;
this.pnts = pnts;
this.rng = new Uniform(new MersenneTwister(seed));
// Create inds.
IntArrayView inds = new IntArrayView(N);
for (int n=0; n<N; ++n) inds.setFast(n, n);
// Create trees.
trees = new DoubleKDTreeNode[ntrees];
for (int t=0; t<ntrees; ++t) {
	trees[t] = new DoubleKDTreeNode(pnts, inds, rng);
}
}
/**
 * Search the KDTree ensemble for the nearest neighbours of a query.
 * @param qu the query vector
 * @param numnn the number of nearest neighbours to return
 * @param ret_nns the array in which the nearest neighbours are returned
 * @param nchecks the number of point comparisons to perform (controls the
 * accuracy/speed trade-off)
 */
public void search(final double [] qu, int numnn, IntDoublePair [] ret_nns, int nchecks) {
int N = pnts.length;
if (nchecks < numnn) nchecks = numnn;
if (nchecks > N) nchecks = N;
PriorityQueue<DoubleObjectPair<DoubleKDTreeNode>> pri_branch = new PriorityQueue<DoubleObjectPair<DoubleKDTreeNode>>(
11,
new Comparator<DoubleObjectPair<DoubleKDTreeNode>>() {
@Override
public int compare(DoubleObjectPair<DoubleKDTreeNode> o1, DoubleObjectPair<DoubleKDTreeNode> o2) {
if (o1.first > o2.first) return 1;
if (o2.first > o1.first) return -1;
return 0;
}}
);
List<IntDoublePair> nns = new ArrayList<IntDoublePair>((3*nchecks)/2);
boolean [] seen = new boolean[N];
// Search each tree at least once.
for (int t=0; t<trees.length; ++t) {
	trees[t].search(qu, pri_branch, nns, seen, pnts, 0);
}
// Continue searching the most promising unexplored branches until
// enough points have been checked.
while (nns.size() < nchecks) {
	if (pri_branch.isEmpty()) break;
	DoubleObjectPair<DoubleKDTreeNode> pr = pri_branch.poll();
pr.second.search(qu, pri_branch, nns, seen, pnts, pr.first);
}
IntDoublePair [] nns_arr = nns.toArray(new IntDoublePair[nns.size()]);
Sorting.partial_sort(nns_arr, 0, numnn, nns_arr.length, new BinaryPredicate() {
@Override
public boolean apply(Object lhs, Object rhs) {
return ((IntDoublePair)lhs).second < ((IntDoublePair)rhs).second;
}});
System.arraycopy(nns_arr, 0, ret_nns, 0, Math.min(numnn, nchecks));
}
}