All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.nearest.KNearestNeighborKDTree Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                KNearestNeighborKDTree.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright Aug 4, 2009, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.learning.algorithm.nearest;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.DivergenceFunction;
import gov.sandia.cognition.math.geometry.KDTree;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.distance.EuclideanDistanceMetric;
import gov.sandia.cognition.math.Metric;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.Summarizer;
import java.util.ArrayList;
import java.util.Collection;

/**
 * A KDTree-based implementation of the k-nearest neighbor algorithm.  This
 * algorithm has a O(n log(n)) construction time and a O(log(n)) evaluate time.
 * @param  Type of Vectorizable data upon which we determine
 * similarity.
 * @param  Output of the evaluator, like Matrix, Double, String
 * @see gov.sandia.cognition.math.geometry.KDTree
 * @author Kevin R. Dixon
 * @since 3.0
 */
@PublicationReference(
    author="Wikipedia",
    title="k-nearest neighbor algorithm",
    type=PublicationType.WebPage,
    year=2008,
    url="http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm"
)
public class KNearestNeighborKDTree
    extends AbstractKNearestNeighbor
    implements Evaluator
{

    /**
     * KDTree that holds the data to search for neighbors.
     */
    private KDTree> data;

    /** 
     * Creates a new instance of KNearestNeighborKDTree 
     */
    public KNearestNeighborKDTree()
    {
        this( DEFAULT_K, null, null, null );
    }

    /**
     * Creates a new instance of KNearestNeighborKDTree
     * @param k
     * Number of neighbors to consider, must be greater than zero
     * @param data
     * KDTree that holds the data to search for neighbors.
     * @param distanceFunction
     * Distance metric that determines how "far" two objects are apart,
     * where lower values indicate two objects are more similar
     * @param averager
     * KDTree that holds the data to search for neighbors.
     */
    public KNearestNeighborKDTree(
        int k,
        KDTree> data,
        Metric distanceFunction,
        Summarizer averager )
    {
        super( k, distanceFunction, averager );
        this.setData(data);
    }

    @Override
    public KNearestNeighborKDTree clone()
    {
        KNearestNeighborKDTree clone =
            (KNearestNeighborKDTree) super.clone();
        clone.setData( ObjectUtil.cloneSafe( this.getData() ) );
        return clone;
    }

    /**
     * Setter for distanceFunction
     * @return
     * Distance metric that determines how "far" two objects are apart,
     * where lower values indicate two objects are more similar.
     */
    @SuppressWarnings("unchecked")
    @Override
    public Metric getDivergenceFunction()
    {
        return (Metric) super.getDivergenceFunction();
    }

    @Override
    @SuppressWarnings("unchecked")
    public void setDivergenceFunction(
        DivergenceFunction divergenceFunction)
    {
        this.setDivergenceFunction( (Metric) divergenceFunction );
    }

    /**
     * Sets the Metric to use.
     * @param divergenceFunction
     * Metric that determines closeness.
     */
    public void setDivergenceFunction(
        Metric divergenceFunction)
    {
        super.setDivergenceFunction(divergenceFunction);
    }

    /**
     * Getter for data
     * @return 
     * KDTree that holds the data to search for neighbors.
     */
    public KDTree> getData()
    {
        return this.data;
    }

    /**
     * Setter for data
     * @param data
     * KDTree that holds the data to search for neighbors.
     */
    public void setData(
        KDTree> data)
    {
        this.data = data;
    }

    @Override
    protected Collection computeNeighborhood(
        InputType key)
    {
        Collection> neighbors =
            this.getData().findNearest(key, this.getK(), this.getDivergenceFunction());
        ArrayList outputs =
            new ArrayList( neighbors.size() );
        for( Pair neighbor : neighbors )
        {
            outputs.add( neighbor.getSecond() );
        }

        return outputs;
    }

    /**
     * Rebalances the internal KDTree to make the search more efficient.  This
     * is an O(n log(n)) operation with n samples.
     */
    public void rebalance()
    {
        this.setData( this.getData().reblanace() );
    }

    /**
     * This is a BatchLearner interface for creating a new KNearestNeighbor
     * from a given dataset, simply a pass-through to the constructor of
     * KNearestNeighbor
     * @param  Type of data upon which the KNearestNeighbor operates,
     * something like Vector, Double, or String
     * @param  Output of the evaluator, like Matrix, Double, String
     */
    public static class Learner
        extends KNearestNeighborKDTree
        implements SupervisedBatchLearner>
    {

        /**
         * Default constructor.
         */
        public Learner()
        {
            this( null );
        }

        /**
         * Creates a new instance of Learner.
         * @param averager
         * Creates a single object from a collection of data.
         */
        public Learner(
            Summarizer averager )
        {
            this( DEFAULT_K, EuclideanDistanceMetric.INSTANCE, averager );
        }

        /**
         * Creates a new instance of Learner
         * @param k
         * Number of neighbors to consider, must be greater than zero
         * @param divergenceFunction
         * Divergence function that determines how "far" two objects are apart,
         * where lower values indicate two objects are more similar
         * @param averager
         * Creates a single object from a collection of data
         */
        public Learner(
            int k,
            Metric divergenceFunction,
            Summarizer averager )
        {
            super( k, null, divergenceFunction, averager );
        }

        /**
         * Creates a new KNearestNeighbor from a Collection of InputType.
         * We build a balanced KDTree with the data, which is an O(n log(n))
         * operator for n data points.
         * @param data Dataset from which to create a new KNearestNeighbor
         * @return
         * KNearestNeighbor based on the given dataset with a balanced
         * KDTree.
         */
        public KNearestNeighborKDTree learn(
            Collection> data )
        {
            @SuppressWarnings("unchecked")
            KNearestNeighborKDTree clone = this.clone();
            KDTree> tree =
                KDTree.createBalanced(data);
            clone.setData( tree );
            return clone;
        }

    }
    
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy