
org.openimaj.knn.approximate.DoubleKDTreeEnsemble

/*
	AUTOMATICALLY GENERATED BY jTemp FROM
	/Users/jsh2/Work/openimaj/target/checkout/machine-learning/nearest-neighbour/src/main/jtemp/org/openimaj/knn/approximate/#T#KDTreeEnsemble.jtemp
*/
/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * 	Redistributions of source code must retain the above copyright notice,
 * 	this list of conditions and the following disclaimer.
 *
 *   *	Redistributions in binary form must reproduce the above copyright notice,
 * 	this list of conditions and the following disclaimer in the documentation
 * 	and/or other materials provided with the distribution.
 *
 *   *	Neither the name of the University of Southampton nor the names of its
 * 	contributors may be used to endorse or promote products derived from this
 * 	software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.knn.approximate;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

import cern.jet.random.Uniform;
import cern.jet.random.engine.MersenneTwister;
    
import org.openimaj.knn.DoubleNearestNeighbours;
import org.openimaj.util.array.IntArrayView;
import org.openimaj.util.pair.*;

import jal.objects.BinaryPredicate;
import jal.objects.Sorting;

/**
 * Ensemble of Best-Bin-First KDTrees for double data.
 * 
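 * <p>
 * A minimal usage sketch (the array sizes, the number of trees and the
 * nchecks value below are illustrative only):
 * <pre>
 * double[][] data = new double[10000][128];  // the vectors to index
 * double[] query = new double[128];          // the query vector
 *
 * DoubleKDTreeEnsemble ensemble = new DoubleKDTreeEnsemble(data, 8);
 *
 * IntDoublePair[] neighbours = new IntDoublePair[10];
 * ensemble.search(query, 10, neighbours, 768);
 * // neighbours[i].first is an index into data; neighbours[i].second is the
 * // distance computed by DoubleNearestNeighbours.distanceFunc
 * </pre>
 *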
 * @author Jonathon Hare ([email protected])
 * @author Sina Samangooei ([email protected])
 */
public class DoubleKDTreeEnsemble {
	private static final int leaf_max_points = 14;		// maximum number of points in a leaf before it is split
	private static final int varest_max_points = 128;	// maximum number of points sampled when estimating dimension variances
	private static final int varest_max_randsz = 5;	// number of top-variance dimensions the split dimension is drawn from
	
	Uniform rng;

    /**
 	 * A single node of the KD-tree. An internal node stores a discriminating
 	 * dimension and value together with its left and right children; a leaf
 	 * node stores the indices of the points that fell into it.
 	 */	
	public static class DoubleKDTreeNode {
		/** Marker base-class for per-node data */
		class NodeData {}
		
		/** Data held by internal nodes: the right child and the discriminator */
	    class InternalNodeData extends NodeData {
	    	DoubleKDTreeNode right;
	        double disc;		// the value to split on
	        int disc_dim;		// the dimension to split on
	    }
	    
	    /** Data held by leaf nodes: the indices of the points in the leaf */
	    class LeafNodeData extends NodeData {
	        int [] indices;
	    }
	    
	    /**
	     * left == null iff this node is a leaf.
	     */
	    DoubleKDTreeNode left;

	    NodeData node_data;
	    
	    private Uniform rng;
	    
	    boolean is_leaf() { 
	    	return left==null; 
	    }
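
	    /**
	     * Choose a split for the given points: estimate the mean and variance of
	     * each dimension from a sample of at most varest_max_points points, take
	     * the varest_max_randsz highest-variance dimensions, pick one of them at
	     * random and split at its mean value.
	     */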

	    IntDoublePair choose_split(final double [][] pnts, final IntArrayView inds) {
	    	int D = pnts[0].length;
	    	
	        // Find mean & variance of each dimension.
	    	double [] sum_x = new double[D];
	    	double [] sum_xx = new double[D];
	        
	        int count = Math.min(inds.size(), varest_max_points);
	        for (int n=0; n<count; ++n) {
	            for (int d=0; d<D; ++d) {
	            	sum_x[d]  += pnts[inds.getFast(n)][d];
	            	sum_xx[d] += pnts[inds.getFast(n)][d]*pnts[inds.getFast(n)][d];
	            }
	        }

	        // Sample variance of each dimension, paired with the dimension index.
	        DoubleIntPair [] var_dim = new DoubleIntPair[D];
	        for (int d=0; d<D; ++d) {
	        	var_dim[d] = new DoubleIntPair();
	        	var_dim[d].second = d;
	        	if (count <= 1)
	        		var_dim[d].first = 0;
	        	else
	        		var_dim[d].first = (sum_xx[d] - ((double)1/count)*sum_x[d]*sum_x[d])/(count - 1);
	        }

	        // Partially sort the dimensions by decreasing variance; only the top
	        // nrand dimensions are candidates for the split.
	        int nrand = Math.min(varest_max_randsz, D);
	        Sorting.partial_sort(var_dim, 0, nrand, var_dim.length, new BinaryPredicate() {
				@Override
				public boolean apply(Object arg0, Object arg1) {
					DoubleIntPair p1 = (DoubleIntPair)arg0;
					DoubleIntPair p2 = (DoubleIntPair)arg1;

					if (p1.first > p2.first) return true;
					if (p2.first > p1.first) return false;
					return (p1.second > p2.second);
				}});
	        
	        int randd = var_dim[rng.nextIntFromTo(0, nrand-1)].second;
	        
	        return new IntDoublePair(randd, sum_x[randd]/count);
	    }
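
	    /**
	     * Partition the point indices around the split chosen by choose_split,
	     * reordering inds in place so that points below the discriminating value
	     * come first, then recursively build the left and right children from the
	     * two partitions.
	     */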

	    void split_points(final double [][] pnts, IntArrayView inds) {
	        IntDoublePair spl = choose_split(pnts, inds);

	        ((InternalNodeData)node_data).disc_dim = spl.first;
	        ((InternalNodeData)node_data).disc = spl.second;

	        int N = inds.size();
	        int l = 0;
	        int r = N;
	        while (l!=r) {
	          if (pnts[inds.getFast(l)][((InternalNodeData)node_data).disc_dim] < ((InternalNodeData)node_data).disc) l++;
	          else {
	            r--;
	            int t = inds.getFast(l);
	            inds.setFast(l, inds.getFast(r));
	            inds.setFast(r, t);
	          }
	        }
	    
	        // If either partition is empty -> vectors identical!
	        if (l==0 || l==N) { l = N/2; } // The vectors are identical, so keep nlogn performance.

	        left = new DoubleKDTreeNode(pnts, inds.subView(0, l), rng);
	        
	        ((InternalNodeData)node_data).right = new DoubleKDTreeNode(pnts, inds.subView(l, N), rng);
	    }

		/** Construct a new node */
	    public DoubleKDTreeNode() { }

		/** 
		 * Construct a new node with the given data
		 *
		 * @param pnts the data for the node and its children
		 * @param inds a list of indices that point to the relevant
		 *			parts of the pnts array that should be used
		 * @param rng the random number generator
		 */
	    public DoubleKDTreeNode(final double [][] pnts, IntArrayView inds, Uniform rng) {
	    	this.rng = rng;
	        if (inds.size() > leaf_max_points) { // Internal node
	        	node_data = new InternalNodeData();
	            split_points(pnts, inds);
	        }
	        else {
	        	node_data = new LeafNodeData();
	        	((LeafNodeData)node_data).indices = inds.toArray();
	        }
	    }
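
	    /**
	     * Best-bin-first search from this node: descend towards the child nearer
	     * to the query, pushing the other child onto the shared priority queue
	     * keyed by a lower bound on its distance to the query. At the leaf,
	     * compute the distance to every point not yet seen and append an
	     * (index, distance) pair to nns.
	     */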

	    void search(final double [] qu, PriorityQueue<DoubleObjectPair<DoubleKDTreeNode>> pri_branch, List<IntDoublePair> nns, boolean[] seen, double [][] pnts, double mindsq)
	    {
	    	DoubleKDTreeNode cur = this;
	    	DoubleKDTreeNode other = null;

	        while (!cur.is_leaf()) { // Follow best bin first until we hit a leaf
	        	double diff = qu[((InternalNodeData)cur.node_data).disc_dim] - ((InternalNodeData)cur.node_data).disc;

	            if (diff < 0) {
	                other = ((InternalNodeData)cur.node_data).right;
	                cur = cur.left;
	            }
	            else {
	                other = cur.left;
	                cur = ((InternalNodeData)cur.node_data).right;
	            }

	            pri_branch.add(new DoubleObjectPair<DoubleKDTreeNode>(mindsq + diff*diff, other));
	        }

	        int [] cur_inds = ((LeafNodeData)cur.node_data).indices;
	        int ncur_inds = cur_inds.length;
	        
	        int i;
	        double [] dsq = new double[1];
	        for (i = 0; i < ncur_inds; ++i) {
	        	int ci = cur_inds[i];
	            if (!seen[ci]) {
	            	DoubleNearestNeighbours.distanceFunc(qu, new double[][] {pnts[ci]}, dsq);
	                
	                nns.add(new IntDoublePair(ci, dsq[0]));
	                
	                seen[ci] = true;
	            }
	        }
	    }
	}
	
	/** The tree roots */ 
	public final DoubleKDTreeNode [] trees;
	
	/** The underlying data array */
	public final double [][] pnts;
    
    /**
     * Construct a DoubleKDTreeEnsemble with the provided data,
     * using the default of 8 trees.
     * @param pnts the data array 
     */
    public DoubleKDTreeEnsemble(final double [][] pnts) {
    	this(pnts, 8, 42);
    }
    
    /**
     * Construct a DoubleKDTreeEnsemble with the provided data and
     * number of trees.
     * @param pnts the data array 
     * @param ntrees the number of KDTrees in the ensemble 
     */
    public DoubleKDTreeEnsemble(final double [][] pnts, int ntrees) {
    	this(pnts, ntrees, 42);
    }
    
    /**
     * Construct a DoubleKDTreeEnsemble with the provided data and
     * number of trees.
     * @param pnts the data array 
     * @param ntrees the number of KDTrees in the ensemble
     * @param seed the seed for the random number generator used in 
     *			tree construction 
     */
    public DoubleKDTreeEnsemble(final double [][] pnts, int ntrees, int seed) {
    	final int N = pnts.length;
    	this.pnts = pnts;
    	this.rng = new Uniform(new MersenneTwister(seed));

        // Create inds.
    	IntArrayView inds = new IntArrayView(N);
        for (int n=0; n<N; ++n) inds.setFast(n, n);

        // Create the trees.
        trees = new DoubleKDTreeNode[ntrees];
        for (int t=0; t<ntrees; ++t) {
            trees[t] = new DoubleKDTreeNode(pnts, inds, rng);
        }
    }

    /**
     * Search the ensemble for the numnn nearest neighbours of the query,
     * comparing the query against roughly nchecks points at most.
     * @param qu the query vector
     * @param numnn the number of nearest neighbours to find
     * @param ret_nns array, of length at least numnn, filled with the
     *			(index, distance) pairs of the neighbours found
     * @param nchecks the maximum number of points to compare against
     */
    public void search(final double [] qu, int numnn, IntDoublePair [] ret_nns, int nchecks) {
    	int N = pnts.length;

    	if (nchecks > N) nchecks = N;
        
        PriorityQueue<DoubleObjectPair<DoubleKDTreeNode>> pri_branch = new PriorityQueue<DoubleObjectPair<DoubleKDTreeNode>>(
        	11, 
        	new Comparator<DoubleObjectPair<DoubleKDTreeNode>>() {
        		@Override
        		public int compare(DoubleObjectPair<DoubleKDTreeNode> o1, DoubleObjectPair<DoubleKDTreeNode> o2) {
        			if (o1.first > o2.first) return 1;
        			if (o2.first > o1.first) return -1;
        			return 0;
        		}}
        );

        List<IntDoublePair> nns = new ArrayList<IntDoublePair>((3*nchecks)/2);
        boolean [] seen = new boolean[N];

        // Search each tree at least once.
        for (int t=0; t<trees.length; ++t) {
            trees[t].search(qu, pri_branch, nns, seen, pnts, 0);
        }

        // Continue best-bin-first: repeatedly explore the unexplored branch with
        // the smallest lower-bound distance until enough points have been checked
        // or no branches remain.
        while (nns.size() < nchecks && !pri_branch.isEmpty()) {
        	DoubleObjectPair<DoubleKDTreeNode> pr = pri_branch.poll();
            
            pr.second.search(qu, pri_branch, nns, seen, pnts, pr.first);
        }

        IntDoublePair [] nns_arr = nns.toArray(new IntDoublePair[nns.size()]); 
        Sorting.partial_sort(nns_arr, 0, numnn, nns_arr.length, new BinaryPredicate() {
			@Override
			public boolean apply(Object lhs, Object rhs) {
				return ((IntDoublePair)lhs).second < ((IntDoublePair)rhs).second;
			}});

        System.arraycopy(nns_arr, 0, ret_nns, 0, Math.min(numnn, nchecks));
    }
}



