All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.math.geometry.KDTree Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                KDTree.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright Jul 28, 2009, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.math.geometry;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.math.Metric;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.math.matrix.VectorizableIndexComparator;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import java.util.AbstractCollection;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.PriorityQueue;

/**
 * Implementation of a kd-tree. Every node in a KDTree has a k-dimensional
 * Vectorizable point and an associated value as a generic "DataType." At
 * each depth in the KDTree, the KDTree partitions the Vectorizables into two
 * sets along a particular dimension according to the Vectorizable stored in the
 * node value. The dimension used to partition at a particular depth is (depth %
 * k). Vectorizables with values less than or equal to the node-value-dimension
 * are placed into the left subtree, and those greater than the
 * node-value-dimension are stored into the right subtree. This makes
 * average-case nearest-neighbor lookup into a balanced KDTree with "N" points
 * of O(log(N)), rather than the typical "N" time for linear search.
 * Construction of a balanced KDTree with N points takes average-case O(N
 * log(N)).
 *
 * @param <VectorType> Type of Vectorizable, the first values
 * @param <DataType> Type of data in the Pair, the second values
 * @param <PairType> Type of Pair to use in the KDTree.
 * @author Kevin R. Dixon
 * @since 3.0
 */
@PublicationReferences(
    references =
    {
        @PublicationReference(
            author = "Andrew W. Moore",
            title = "An intoductory tutorial on kd-trees",
            type = PublicationType.TechnicalReport,
            publication
            = "University of Cambridge Computer Laboratory Technical Report No. 209",
            year = 1991,
            url
            = "http://www.autonlab.org/autonweb/14665.html?branch=1&language=2"
        ),
        @PublicationReference(
            author = "Wikipedia",
            title = "kd-tree",
            type = PublicationType.WebPage,
            year = 2009,
            url = "http://en.wikipedia.org/wiki/Kd-tree"
        )
    }
)
public class KDTree>
    extends AbstractCollection
    implements CloneableSerializable
{

    /**
     * Number of elements stored in the subtree rooted at this node, counting
     * this node's own value. It is zero only for an empty tree made by the
     * default constructor.
     */
    protected int num;

    /**
     * VectorType,DataType value for this node of the subtree. This is null
     * while the tree is empty; add() uses that null as its "empty" test.
     */
    protected PairType value;

    /**
     * Parent of this node of the subtree. The public constructors pass null
     * for the root.
     */
    protected KDTree parent;

    /**
     * Left child of this subtree, or null if there is none. add() sends
     * points that compare less than or equal to this node's value here.
     */
    protected KDTree leftChild;

    /**
     * Right child of this subtree, or null if there is none. add() sends
     * points that compare greater than this node's value here.
     */
    protected KDTree rightChild;

    /**
     * Comparator of this node to determine less than, greater than, or
     * equality. Its embedded index is the axis this node splits on; children
     * are given the next index, modulo the dimensionality.
     */
    protected PairFirstVectorizableIndexComparator comparator;

    /**
     * Creates an empty KDTree: no value, no comparator, no parent and a size
     * of zero. The first call to add() turns this node into the root.
     */
    public KDTree()
    {
        // The node constructor assumes one stored value, so the count is
        // reset to zero afterwards.
        this(null, null, null);
        this.num = 0;
    }

    /**
     * Builds a balanced KDTree holding every point of the given collection.
     * The root splits on vector index 0 and has no parent. The dimensionality
     * used to cycle the split axis is read from the first point.
     *
     * NOTE(review): the first point is dereferenced while the arguments are
     * being evaluated, so an empty collection presumably fails here before
     * the "No points!" check of the recursive constructor is reached.
     * Confirm against CollectionUtil.getFirst().
     *
     * @param points Points to load into the KDTree.
     */
    public KDTree(
        Collection points)
    {
        this(
            CollectionUtil.asArrayList(points),
            new PairFirstVectorizableIndexComparator(0),
            CollectionUtil.getFirst(points).getFirst().convertToVector().getDimensionality(),
            null);
    }

    /**
     * Creates a single-node subtree (a leaf) holding the given value. Used
     * when the tree grows one point at a time.
     *
     * @param value Value stored at the head of the new subtree.
     * @param comparator Comparator carrying the split axis of the new node.
     * @param parent Parent node of this subtree.
     */
    protected KDTree(
        PairType value,
        PairFirstVectorizableIndexComparator comparator,
        KDTree parent)
    {
        // Links first: a freshly created node never has children.
        this.parent = parent;
        this.leftChild = null;
        this.rightChild = null;

        // Then the payload, which counts as exactly one element.
        this.comparator = comparator;
        this.value = value;
        this.num = 1;
    }

    /**
     * Recursively builds a balanced subtree from the given points. The node
     * keeps the element at the median position along its own axis; elements
     * before that position go to the left subtree and elements after it go
     * to the right subtree. The original author describes this as an
     * O(n log n) construction because CollectionUtil.findKthLargest() is a
     * linear-time selection.
     *
     * @param points Points to load into the subtree; must not be empty.
     * @param comparator Comparator carrying the split axis of this node.
     * @param dimensionality Dimensionality of the Vectorizables, used to wrap
     * the split axis around for the children.
     * @param parent Parent node of this subtree.
     * @throws IllegalArgumentException if points holds no elements.
     */
    protected KDTree(
        ArrayList points,
        PairFirstVectorizableIndexComparator comparator,
        int dimensionality,
        KDTree parent)
    {
        this.parent = parent;
        this.comparator = comparator;
        this.num = points.size();

        if (this.num <= 0)
        {
            throw new IllegalArgumentException("No points!");
        }

        if (this.num == 1)
        {
            // A lone point is a leaf; there is nothing to partition.
            this.value = points.get(0);
            return;
        }

        // Select around the middle position along this node's axis.
        // Assumes findKthLargest returns an index ordering partitioned around
        // that position -- TODO confirm against CollectionUtil.
        final int medianIndex = this.num / 2;
        final int[] indices = CollectionUtil.findKthLargest(
            medianIndex, points, comparator);
        this.value = points.get(indices[medianIndex]);

        // Both children split on the next axis, wrapping after the last one.
        final int childAxis
            = (this.comparator.comparator.getIndex() + 1) % dimensionality;
        final PairFirstVectorizableIndexComparator childComparator
            = new PairFirstVectorizableIndexComparator(childAxis);

        // Everything before the median position becomes the left subtree.
        if (medianIndex > 0)
        {
            final ArrayList leftPoints = new ArrayList(medianIndex);
            for (int i = 0; i < medianIndex; i++)
            {
                leftPoints.add(points.get(indices[i]));
            }
            this.leftChild = new KDTree(
                leftPoints, childComparator, dimensionality, this);
        }

        // Everything after the median position becomes the right subtree.
        final int rightNum = this.num - medianIndex - 1;
        if (rightNum > 0)
        {
            final ArrayList rightPoints = new ArrayList(rightNum);
            for (int i = medianIndex + 1; i < this.num; i++)
            {
                rightPoints.add(points.get(indices[i]));
            }
            this.rightChild = new KDTree(
                rightPoints, childComparator, dimensionality, this);
        }
    }

    /**
     * Creates a copy of this subtree. Both child subtrees and the node value
     * are cloned through ObjectUtil; the comparator is shared with the
     * original. The copy keeps the original's parent reference, so a cloned
     * root stays a root.
     *
     * Each cloned child is re-parented to the new copy. Without that step
     * the cloned children would still name the original node as their
     * parent, leaving the copy's parent links pointing into the source tree.
     *
     * @return Copy of this subtree, or null if the superclass refuses to
     * clone.
     */
    @Override
    @SuppressWarnings("unchecked")
    public KDTree clone()
    {
        KDTree clone;
        try
        {
            clone = (KDTree) super.clone();
            clone.leftChild = ObjectUtil.cloneSafe(this.leftChild);
            clone.rightChild = ObjectUtil.cloneSafe(this.rightChild);
            clone.value = ObjectUtil.cloneSmart(this.value);
            clone.parent = this.parent;
            clone.comparator = this.comparator;

            // The child clones copied their parent field from the originals,
            // so point them at the new node instead.
            if (clone.leftChild != null)
            {
                clone.leftChild.parent = clone;
            }
            if (clone.rightChild != null)
            {
                clone.rightChild.parent = clone;
            }
        }
        catch (CloneNotSupportedException ex)
        {
            // Kept from the original: report failure as a null result.
            clone = null;
        }

        return clone;

    }

    /**
     * Creates a balanced KDTree based on the given collection of Pairs by
     * delegating to the Collection constructor. The original author
     * describes this as an O(n log n) operation for "n" points because a
     * linear-time kth selection algorithm, CollectionUtil.findKthLargest(),
     * is used to pick each median.
     *
     * @param <VectorType> Type of Vectorizable, the first values.
     * @param <DataType> Type of data in the Pair, the second values.
     * @param <PairType> Type of Pair to use in the KDTree.
     * @param points Points to load into the tree.
     * @return Balanced KDTree that contains all the given points.
     */
    public static >
        KDTree createBalanced(
            Collection points)
    {
        return new KDTree(points);
    }

    /**
     * Builds a balanced KDTree holding the same points as this tree. Does
     * not modify this KDTree.
     *
     * The method name is misspelled but is part of the published interface,
     * so it is kept as the implementation; rebalance() is the correctly
     * spelled entry point and simply forwards here.
     *
     * An empty tree yields a new empty tree. Previously it failed, because
     * the Collection constructor reads the first point before it checks
     * whether there are any points.
     *
     * @return Balanced representation of this KDTree.
     */
    public KDTree reblanace()
    {
        if (this.isEmpty())
        {
            return new KDTree();
        }
        return createBalanced(this);
    }

    /**
     * Correctly spelled alias of reblanace(). Forwards to that method so a
     * subclass override of the original name is still honored.
     *
     * @return Balanced representation of this KDTree.
     */
    public KDTree rebalance()
    {
        return this.reblanace();
    }

    /**
     * Inserts a point into the tree without rebalancing. An empty tree
     * adopts the point as its root, splitting on index 0. Otherwise the
     * point descends left when it compares less than or equal to the node
     * value on the node's axis and right when it compares greater. It
     * becomes a new leaf at the first missing child.
     *
     * Null points are rejected before any state changes. Previously a null
     * added to an empty tree set the size to 1 while storing no value. The
     * element count is now raised only after the insertion below this node
     * has succeeded, so a failed insertion no longer leaves the counts
     * inflated.
     *
     * @param point Point to insert; must not be null.
     * @return Always true.
     * @throws NullPointerException if point is null.
     */
    @Override
    public boolean add(
        PairType point)
    {

        if (point == null)
        {
            throw new NullPointerException(
                "Cannot add a null point to a KDTree");
        }

        if (this.value == null)
        {
            // Empty tree: this node becomes the root and splits on axis 0.
            this.num = 1;
            this.value = point;
            this.comparator = new PairFirstVectorizableIndexComparator(0);
            return true;
        }

        final boolean goesLeft
            = this.comparator.compare(point, this.value) <= 0;
        if (goesLeft)
        {
            if (this.leftChild == null)
            {
                this.leftChild = this.createChildNode(point);
            }
            else
            {
                this.leftChild.add(point);
            }
        }
        else
        {
            if (this.rightChild == null)
            {
                this.rightChild = this.createChildNode(point);
            }
            else
            {
                this.rightChild.add(point);
            }
        }

        // Count the point only once it is actually stored below this node.
        this.num++;
        return true;

    }

    /**
     * Creates a leaf child of this node holding the given point. The child
     * splits on the axis after this node's axis, wrapping around at the
     * point's dimensionality.
     *
     * @param point Point to store in the new leaf.
     * @return New leaf whose parent is this node.
     */
    private KDTree createChildNode(
        PairType point)
    {
        final int dimension
            = point.getFirst().convertToVector().getDimensionality();
        final int childAxis
            = (this.comparator.comparator.getIndex() + 1) % dimension;
        final PairFirstVectorizableIndexComparator childComparator
            = new PairFirstVectorizableIndexComparator(childAxis);
        return new KDTree(point, childComparator, this);
    }

    /**
     * Reports how many points are stored in the subtree rooted at this node.
     * The count is maintained by the constructors and by add().
     *
     * @return Number of elements in this subtree.
     */
    @Override
    public int size()
    {
        return num;
    }

    /**
     * Returns an iterator that walks this subtree "inorder" (symmetric
     * traversal): the left subtree first, then this node's value, then the
     * right subtree.
     *
     * @return Inorder iterator of the KDTree.
     */
    @PublicationReference(
        author = "Wikipedia",
        title = "Tree traversal",
        type = PublicationType.WebPage,
        year = 2009,
        url = "http://en.wikipedia.org/wiki/Tree_traversal#Traversal"
    )
    @Override
    public Iterator iterator()
    {
        final InOrderKDTreeIterator inorder = new InOrderKDTreeIterator(this);
        return inorder;
    }

    /**
     * Renders the whole tree, one node per line, labelling the root with
     * "Head->" and each descendant with its path of "L"/"R" turns.
     *
     * @return String representation of the KDTree.
     */
    @Override
    public String toString()
    {
        final String rootLabel = "Head->";
        return this.toString(rootLabel);
    }

    /**
     * Recursively renders this subtree, one node per line. This node is
     * printed first as "prefix (first -> second)", followed by the left
     * subtree with "L" appended to the prefix and then the right subtree
     * with "R" appended.
     *
     * A node that holds no value, such as the root of an empty tree, is
     * rendered as "prefix (empty)" instead of throwing a
     * NullPointerException.
     *
     * @param prefix Prefix to tack onto the recursion values.
     * @return String representation of the KDTree.
     */
    protected String toString(
        String prefix)
    {

        final StringBuilder text = new StringBuilder();
        text.append(prefix);

        if (this.value == null)
        {
            // Nothing to dereference; an empty tree has no children either.
            text.append(" (empty)\n");
            return text.toString();
        }

        text.append(" (").append(this.value.getFirst())
            .append(" -> ").append(this.value.getSecond()).append(")\n");

        if (this.leftChild != null)
        {
            text.append(this.leftChild.toString(prefix + "L"));
        }

        if (this.rightChild != null)
        {
            text.append(this.rightChild.toString(prefix + "R"));
        }

        return text.toString();

    }

    /**
     * Finds the "k" nearest neighbors to the given "key" stored in the
     * KDTree.
     *
     * @param key Vector to find the nearest neighbors of.
     * @param k Number of neighbors to find. This value becomes the initial
     * capacity of a PriorityQueue whenever it is smaller than the tree size.
     * @param metric Metric to use to evaluate the nearness of other points.
     * @return Collection of nearest points to the "key" query. If "k" is
     * greater than or equal to the number of points in the KDTree, then the
     * KDTree itself is returned.
     */
    public Collection findNearest(
        VectorType key,
        int k,
        Metric metric)
    {

        // Asking for at least as many neighbors as we hold: every point
        // qualifies, so no search is needed.
        if (k >= this.size())
        {
            return this;
        }

        final Neighborhood nearest = new Neighborhood(k);
        this.findNearest(key, k, nearest, metric);
        return nearest;

    }

    /**
     * Recursive step of the k-nearest-neighbor search. The subtree on the
     * key's side of this node's split is searched first. This node and the
     * opposite subtree are then considered while the neighborhood still has
     * room. Once it is full, they are considered only when the split plane
     * is closer to the key than the furthest neighbor found so far.
     *
     * The pruning compares a single-coordinate difference with metric
     * distances, so it presumably expects a metric for which that difference
     * never exceeds the distance (e.g. Euclidean) -- confirm for other
     * metrics.
     *
     * Every recursive call forwards "k". The original passed the subtree
     * size "num" in one branch; that was harmless only because the capacity
     * actually used is the one stored in the neighborhood.
     *
     * @param key Vector to find the nearest neighbors of.
     * @param k Number of neighbors to find.
     * @param neighborhood PriorityQueue to store the current nearest neighbors.
     * @param metric Metric to use to evaluate the nearness of other points.
     */
    protected void findNearest(
        VectorType key,
        int k,
        Neighborhood neighborhood,
        Metric metric)
    {

        KDTree closer = null;
        KDTree further = null;

        // If we've got children, then see which child is on the key's side.
        if ((this.leftChild != null) || (this.rightChild != null))
        {
            int comparison = this.comparator.comparator.compare(
                key, this.value.getFirst());

            if (comparison <= 0)
            {
                closer = this.leftChild;
                further = this.rightChild;
            }
            else
            {
                closer = this.rightChild;
                further = this.leftChild;
            }

            // Recurse into the closer subtree if it exists.
            if (closer != null)
            {
                closer.findNearest(key, k, neighborhood, metric);
            }

        }

        if (!neighborhood.isFull())
        {
            // There is room, so this node is a neighbor for now.
            double distance = metric.evaluate(this.value.getFirst(), key);
            neighborhood.add(this.value, distance);

            // The other side is always visited on this path; whether its
            // nodes are kept is decided inside the recursive calls.
            if (further != null)
            {
                further.findNearest(key, k, neighborhood, metric);
            }

        }
        else
        {
            // The queue is full. This node or the further subtree can only
            // improve on the furthest neighbor if the split plane is nearer
            // than that neighbor.
            double minimumDistance = this.computeMinimumDifference(key);
            if (minimumDistance < neighborhood.getFurthestNeighborDistance())
            {
                double distance = metric.evaluate(this.value.getFirst(), key);
                neighborhood.offer(this.value, distance);
                if (further != null)
                {
                    further.findNearest(key, k, neighborhood, metric);
                }

            }
        }

    }

    /**
     * Finds the neighbors within a given distance to the given "key" stored in
     * the KDTree.
     *
     * An empty tree yields an empty result. Previously the result container
     * was sized with the tree size, and a size of zero is rejected by the
     * underlying PriorityQueue with an IllegalArgumentException.
     *
     * @param key Vector to find the nearest neighbors of.
     * @param radius Radius of desired neighborhood.
     * @param metric Metric to use to evaluate the nearness of other points.
     * @return Collection of points within a given distance to the "key" query.
     */
    public Collection findNearestWithinRadius(
        VectorType key,
        double radius,
        Metric metric)
    {
        // Room for every point in the tree, but never below the minimum
        // capacity of 1 that PriorityQueue accepts.
        Neighborhood neighborhood
            = new Neighborhood(Math.max(1, this.size()));

        // The recursive search dereferences the node value, so skip it for
        // an empty tree.
        if (this.value != null)
        {
            this.findNearestWithinRadius(key, radius, neighborhood, metric);
        }
        return neighborhood;
    }

    /**
     * Recursive step of the radius search. The subtree on the key's side of
     * this node's split is always searched. This node is added when its
     * distance to the key is within the radius. The opposite subtree is
     * searched whenever the split plane lies within the radius of the key.
     *
     * The original searched the opposite subtree only when this node itself
     * was within the radius. That loses points: a node can be far from the
     * key along another axis while the split plane, and points just beyond
     * it, are well inside the radius.
     *
     * The pruning compares a single-coordinate difference with the radius,
     * so it presumably expects a metric for which that difference never
     * exceeds the distance (e.g. Euclidean) -- confirm for other metrics.
     *
     * @param key Vector to find the nearest neighbors of.
     * @param radius Radius of desired neighborhood.
     * @param neighborhood PriorityQueue to store the neighbors.
     * @param metric Metric to use to evaluate the nearness of other points.
     */
    protected void findNearestWithinRadius(
        VectorType key,
        double radius,
        Neighborhood neighborhood,
        Metric metric)
    {
        KDTree closer = null;
        KDTree further = null;

        // If we've got children, then see which child is on the key's side.
        if ((this.leftChild != null) || (this.rightChild != null))
        {
            int comparison = this.comparator.comparator.compare(
                key, this.value.getFirst());

            if (comparison <= 0)
            {
                closer = this.leftChild;
                further = this.rightChild;
            }
            else
            {
                closer = this.rightChild;
                further = this.leftChild;
            }

            // Recurse into the closer subtree if it exists.
            if (closer != null)
            {
                closer.findNearestWithinRadius(key, radius, neighborhood, metric);
            }

        }

        // Add this node if its distance to the key is within the radius.
        double distance = metric.evaluate(this.value.getFirst(), key);
        if (distance <= radius)
        {
            neighborhood.add(this.value, distance);
        }

        // The other side can hold points within the radius only if the
        // split plane itself is within the radius of the key. This does not
        // depend on whether this node qualified.
        if ((further != null)
            && (this.computeMinimumDifference(key) <= radius))
        {
            further.findNearestWithinRadius(key, radius, neighborhood,
                metric);
        }
    }

    /**
     * Measures how far the key is from this node's splitting plane: the
     * absolute difference between the key and this node's "first" value
     * along the single index held by the embedded comparator. The searches
     * compare this value with a distance to decide whether the far side of
     * the split still needs to be visited.
     *
     * @param key Vector to compare against.
     * @return Minimum absolute difference for the given index between the key
     * and the first value stored in this subtree.
     */
    protected double computeMinimumDifference(
        VectorType key)
    {
        final int axis = this.comparator.comparator.getIndex();
        final double difference = key.convertToVector().getElement(axis)
            - this.value.getFirst().convertToVector().getElement(axis);
        return Math.abs(difference);
    }

    /**
     * Comparator for Pairs that have a Vectorizable as their first element.
     * The comparison is delegated to a VectorizableIndexComparator applied to
     * the first element of each Pair; the second element plays no part.
     */
    protected static class PairFirstVectorizableIndexComparator
        extends AbstractCloneableSerializable
        implements Comparator>
    {

        /**
         * Embedded comparator for the Vectorizable argument. The KDTree
         * reads this field directly, both to obtain the split index through
         * getIndex() and to compare a bare key against a node's first value.
         */
        public VectorizableIndexComparator comparator;

        /**
         * Creates a new instance of PairFirstVectorizableIndexComparator
         *
         * @param index Index of the Vectorizable to compare against.
         */
        public PairFirstVectorizableIndexComparator(
            int index)
        {
            this.comparator = new VectorizableIndexComparator(index);
        }

        /**
         * Compares two Pairs by their first elements using the embedded
         * comparator.
         *
         * @param o1 First Pair.
         * @param o2 Second Pair.
         * @return Result of the embedded comparator on the two first
         * elements.
         */
        public int compare(
            Pair o1,
            Pair o2)
        {
            return this.comparator.compare(o1.getFirst(), o2.getFirst());
        }

    }

    /**
     * Iterates through the KDTree using "inorder", also known as "symmetric
     * traversal", of the tree. That is, the recursion proceeds as traverse the
     * left subtree, visit the node, traverse the right subtree.
     *
     * The iterator is itself a tree: the constructor creates one nested
     * iterator per child subtree, recursively, before any element is
     * returned.
     *
     * @param <VectorType> Type of Vectorizable, the first values
     * @param <DataType> Type of data in the Pair, the second values
     * @param <PairType> Type of Pair to use in the KDTree.
     */
    @PublicationReference(
        author = "Wikipedia",
        title = "Tree traversal",
        type = PublicationType.WebPage,
        year = 2009,
        url = "http://en.wikipedia.org/wiki/Tree_traversal#Traversal"
    )
    protected static class InOrderKDTreeIterator>
        implements Iterator
    {

        /**
         * Value of the node. Set to null once it has been returned, which is
         * how the iterator remembers that the node was visited.
         */
        public PairType nodeValue;

        /**
         * Iterator for the left subtree, or null when there is no left
         * subtree or it has been exhausted.
         */
        public InOrderKDTreeIterator leftIterator;

        /**
         * Iterator for the right subtree, or null when there is no right
         * subtree or it has been found exhausted.
         */
        public InOrderKDTreeIterator rightIterator;

        /**
         * Creates a new instance of InOrderKDTreeIterator
         *
         * @param node Node from which to iterate.
         */
        public InOrderKDTreeIterator(
            KDTree node)
        {
            // Child iterators are built eagerly, one per existing child.
            if (node.leftChild != null)
            {
                this.leftIterator
                    = new InOrderKDTreeIterator(
                        node.leftChild);
            }
            if (node.rightChild != null)
            {
                this.rightIterator
                    = new InOrderKDTreeIterator(
                        node.rightChild);
            }

            // A null value (empty tree) makes hasNext() report false at once.
            this.nodeValue = node.value;

        }

        /**
         * Reports whether anything is left to return: the node's own value
         * or an element in either child iterator.
         *
         * @return True if next() can return another element.
         */
        public boolean hasNext()
        {
            return (this.nodeValue != null) || ((this.rightIterator != null)
                && (this.rightIterator.hasNext())) || ((this.leftIterator
                != null) && (this.leftIterator.hasNext()));
        }

        /**
         * Returns the next element in inorder: first everything from the
         * left iterator, then the node's own value, then everything from the
         * right iterator.
         *
         * NOTE(review): exhaustion is signalled with an
         * IllegalArgumentException rather than the NoSuchElementException
         * named by the Iterator contract.
         *
         * @return Next element of the traversal.
         * @throws IllegalArgumentException if no elements remain.
         */
        public PairType next()
        {

            PairType retval = null;

            if ((this.leftIterator != null)
                && this.leftIterator.hasNext())
            {
                retval = this.leftIterator.next();
            }
            else
            {
                // The left side is finished; drop it so it is not asked
                // again.
                this.leftIterator = null;

                if (this.nodeValue != null)
                {
                    // Visit this node exactly once.
                    retval = this.nodeValue;
                    this.nodeValue = null;
                }
                else if (this.rightIterator != null)
                {
                    if (this.rightIterator.hasNext())
                    {
                        retval = this.rightIterator.next();
                    }
                    else
                    {
                        this.rightIterator = null;
                    }
                }
            }

            // Nothing was produced by any of the three sources.
            if (retval == null)
            {
                throw new IllegalArgumentException(
                    "Should not have called null since we have no values to iterate!");
            }

            return retval;

        }

        /**
         * Removal through the iterator is not supported.
         *
         * @throws UnsupportedOperationException always.
         */
        public void remove()
        {
            throw new UnsupportedOperationException("Not supported yet.");
        }

    }

    /**
     * A Collection of nearby pairs with a fixed maximum size. Entries are
     * kept in a PriorityQueue whose ordering is reversed on distance, so the
     * head of the queue is always the furthest neighbor currently held and
     * is the first to be evicted.
     *
     * @param <VectorType> Type of Vectorizable.
     * @param <DataType> Type of output data.
     * @param <PairType> Type of Pair.
     */
    protected static class Neighborhood>
        extends AbstractCollection
    {

        /**
         * Maximum number of Neighbors in the Neighborhood.
         */
        private int k;

        /**
         * PriorityQueue to store the neighbors. Its head is the furthest
         * neighbor, because Neighbor.compareTo() reverses the distance
         * comparison.
         */
        PriorityQueue> priorityQueue;

        /**
         * Creates a new Neighborhood.
         *
         * @param k Maximum number of Neighbors in the Neighborhood. Also
         * used as the initial capacity of the PriorityQueue, which rejects
         * values below 1.
         */
        public Neighborhood(
            int k)
        {
            this.priorityQueue
                = new PriorityQueue>(k);
            this.k = k;
        }

        /**
         * Returns true if the Neighborhood is full.
         *
         * @return True if the Neighborhood is full, false if not full.
         */
        public boolean isFull()
        {
            return this.size() >= this.k;
        }

        /**
         * Returns the distance of the furthest Neighbor, i.e. the one at the
         * head of the priority queue. The Neighborhood must not be empty:
         * peek() yields null for an empty queue and the field access would
         * fail.
         *
         * @return Distance of the furthest Neighbor.
         */
        public double getFurthestNeighborDistance()
        {
            return this.priorityQueue.peek().distance;
        }

        /**
         * Adds the neighbor to the priority queue unconditionally. If the
         * Neighborhood is already full, the furthest neighbors are evicted
         * first to make room, without checking whether the newcomer is
         * closer than them. This is an overload taking a distance, not the
         * single-argument add() of Collection.
         *
         * @param value Value to add.
         * @param distance Distance to associate with the neighbor to the queue.
         */
        public void add(
            PairType value,
            double distance)
        {

            // Make room by dropping from the head (furthest first).
            while (this.isFull())
            {
                this.priorityQueue.remove();
            }

            this.priorityQueue.add(
                new Neighbor(value, distance));

        }

        /**
         * Offers the neighbor if there is space or it's closer than the
         * furthest neighbor.
         *
         * @param value Value of the neighbor.
         * @param distance Distance to the key value.
         * @return True if added, false if not added.
         */
        public boolean offer(
            PairType value,
            double distance)
        {

            // If we're full, then see if we're closer than the furthest
            // neighbor. If so, evict that neighbor to open one slot.
            if (this.isFull())
            {
                if (distance < this.getFurthestNeighborDistance())
                {
                    this.priorityQueue.remove();
                }
            }

            // If we find there is space, then add the new neighbor.
            if (!this.isFull())
            {
                this.add(value, distance);
                return true;
            }
            else
            {
                return false;
            }

        }

        /**
         * Returns an iterator over the stored pairs. It walks the priority
         * queue's own iterator, which the JDK does not guarantee to be in
         * any particular order, so the pairs are not necessarily sorted by
         * distance.
         *
         * @return Iterator over the pairs in this Neighborhood.
         */
        @Override
        public Iterator iterator()
        {
            return new NeighborhoodIterator();
        }

        /**
         * Returns the number of neighbors currently held.
         *
         * @return Size of the underlying priority queue.
         */
        @Override
        public int size()
        {
            return this.priorityQueue.size();
        }

        /**
         * Holds neighbor information used during the evaluate method and is put
         * into a priority queue.
         *
         * @param <VectorType> Type of Vectorizable, the first values
         * @param <DataType> Type of data in the Pair, the second values
         * @param <PairType> Type of Pair to use in the KDTree.
         */
        protected class Neighbor>
            extends AbstractCloneableSerializable
            implements Comparable>
        {
            // Note: This class does not follow the get/set pattern in order to
            // make it as fast as possible, because it is used within the evaluate
            // method. It is an internal helper (declared protected), so no one
            // else should use it.

            /**
             * Pair to store.
             */
            PairType pair;

            /**
             * Distance associated with this value.
             */
            private double distance;

            /**
             * Creates a new neighbor.
             *
             * @param value The value associated with the neighbor.
             * @param distance Distance associated with this value.
             */
            public Neighbor(
                final PairType value,
                final double distance)
            {
                this.pair = value;
                this.distance = distance;
            }

            /**
             * Orders neighbors by descending distance.
             *
             * @param other Neighbor to compare against.
             * @return Negative if this neighbor is further than the other,
             * positive if it is closer, zero if the distances compare equal.
             */
            public int compareTo(
                final Neighbor other)
            {
                // We reverse the comparison so that the item at the head of the
                // priority queue is the furthest neighbor
                return -Double.compare(this.distance, other.distance);
            }

            /**
             * Two neighbors are equal when the vectors of their pairs'
             * first elements are equal; distance and the second element are
             * ignored.
             *
             * @param obj Object to compare against.
             * @return True if obj is a Neighbor with an equal vector.
             */
            @Override
            @SuppressWarnings("unchecked")
            public boolean equals(
                Object obj)
            {

                if (obj == null)
                {
                    return false;
                }
                else if (obj instanceof Neighbor)
                {
                    return ((Neighbor) obj).pair.getFirst().convertToVector().equals(
                        this.pair.getFirst().convertToVector());
                }
                else
                {
                    return false;
                }

            }

            /**
             * Hash code of the pair's first element.
             *
             * NOTE(review): equals() compares the converted vectors while
             * this hashes the Vectorizable itself; the two agree only if
             * the Vectorizable's hashCode is consistent with its vector's
             * equality -- confirm.
             *
             * @return Hash code of the first element of the pair.
             */
            @Override
            public int hashCode()
            {
                return this.pair.getFirst().hashCode();
            }

        }

        /**
         * Iterator for the Neighborhood. Unwraps each Neighbor delivered by
         * the priority queue's iterator and returns only its pair.
         */
        protected class NeighborhoodIterator
            implements Iterator
        {

            /**
             * PriorityQueue iterator.
             */
            Iterator> priorityQueueIterator;

            /**
             * Default constructor. Captures an iterator over the enclosing
             * Neighborhood's priority queue.
             */
            public NeighborhoodIterator()
            {
                this.priorityQueueIterator = priorityQueue.iterator();
            }

            /**
             * Reports whether the priority queue iterator has more entries.
             *
             * @return True if another pair is available.
             */
            public boolean hasNext()
            {
                return this.priorityQueueIterator.hasNext();
            }

            /**
             * Returns the pair of the next Neighbor in the queue's iteration
             * order.
             *
             * @return Pair stored in the next Neighbor.
             */
            public PairType next()
            {
                Neighbor next
                    = this.priorityQueueIterator.next();
                return next.pair;
            }

            /**
             * Removal through the iterator is not supported.
             *
             * @throws UnsupportedOperationException always.
             */
            public void remove()
            {
                throw new UnsupportedOperationException("Not supported yet.");
            }

        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy