All Downloads are FREE. The search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.algorithm.nearest.NearestNeighborKDTree Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                NearestNeighborKDTree.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright Aug 10, 2009, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.learning.algorithm.nearest;

import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.DivergenceFunction;
import gov.sandia.cognition.math.Metric;
import gov.sandia.cognition.math.geometry.KDTree;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

/**
 * A KDTree-based implementation of the nearest neighbor algorithm.  This
 * algorithm has a O(n log(n)) construction time and a O(log(n)) evaluate time.
 * @param  Type of Vectorizable data upon which we determine
 * similarity.
 * @param  Output of the evaluator, like Matrix, Double, String
 * @author Kevin R. Dixon
 * @since 3.0
 */
public class NearestNeighborKDTree
    extends AbstractNearestNeighbor
{

    /**
     * KDTree that holds the data to search for neighbors.
     */
    private KDTree> data;

    /**
     * Creates a new instance of {@code NearestNeighborKDTree}.
     */
    public NearestNeighborKDTree()
    {
        this(null, null);
    }

    /**
     * Creates a new instance of NearestNeighborKDTree
     *
     * @param data
     * Underlying data for the classifier
     * @param divergenceFunction Divergence function that determines how "far" two objects are apart
     */
    public NearestNeighborKDTree(
        KDTree> data,
        DivergenceFunction divergenceFunction )
    {
        super( divergenceFunction );
        this.setData(data);
    }

    @Override
    public NearestNeighborKDTree clone()
    {
        @SuppressWarnings("unchecked")
        NearestNeighborKDTree clone =
            (NearestNeighborKDTree) super.clone();
        clone.setData( ObjectUtil.cloneSafe( this.getData() ) );
        return clone;
    }

    /**
     * Setter for distanceFunction
     * @return
     * Distance metric that determines how "far" two objects are apart,
     * where lower values indicate two objects are more similar.
     */
    @SuppressWarnings("unchecked")
    @Override
    public Metric getDivergenceFunction()
    {
        return (Metric) super.getDivergenceFunction();
    }

    @Override
    @SuppressWarnings("unchecked")
    public void setDivergenceFunction(
        DivergenceFunction divergenceFunction)
    {
        this.setDivergenceFunction( (Metric) divergenceFunction );
    }

    /**
     * Sets the Metric to use.
     * @param divergenceFunction
     * Metric that determines closeness.
     */
    public void setDivergenceFunction(
        Metric divergenceFunction)
    {
        super.setDivergenceFunction(divergenceFunction);
    }

    /**
     * Getter for data
     * @return
     * KDTree that holds the data to search for neighbors.
     */
    public KDTree> getData()
    {
        return this.data;
    }

    /**
     * Setter for data
     * @param data
     * KDTree that holds the data to search for neighbors.
     */
    public void setData(
        KDTree> data)
    {
        this.data = data;
    }

    public OutputType evaluate(
        InputType input)
    {
        Collection> neighbors =
            this.getData().findNearest(input, 1, this.getDivergenceFunction());

        InputOutputPair pair = CollectionUtil.getFirst(neighbors);
        if( pair != null )
        {
            return pair.getOutput();
        }
        else
        {
            return null;
        }
        
    }

    /**
     * This is a BatchLearner interface for creating a new NearestNeighbor
     * from a given dataset, simply a pass-through to the constructor of
     * NearestNeighbor
     * @param  Type of data upon which the NearestNeighbor operates,
     * something like Vector, Double, or String
     * @param  Output of the evaluator, like Matrix, Double, String
     */
    public static class Learner
        extends NearestNeighborKDTree
        implements SupervisedBatchLearner>
    {

        /**
         * Default constructor.
         */
        public Learner()
        {
            this( null );
        }

        /**
         * Creates a new instance of Learner
         * @param divergenceFunction
         * Divergence function that determines how "far" two objects are apart,
         * where lower values indicate two objects are more similar
         */
        public Learner(
            Metric divergenceFunction )
        {
            super( null, divergenceFunction );
        }

        /**
         * Creates a new NearestNeighbor from a Collection of InputType.
         * We build a balanced KDTree with the data, which is an O(n log(n))
         * operator for n data points.
         * @param data Dataset from which to create a new NearestNeighbor
         * @return
         * NearestNeighbor based on the given dataset with a balanced
         * KDTree.
         */
        public NearestNeighborKDTree learn(
            Collection> data )
        {
            @SuppressWarnings("unchecked")
            NearestNeighborKDTree clone = this.clone();
            KDTree> tree =
                KDTree.createBalanced(data);
            clone.setData( tree );
            return clone;
        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy