gov.sandia.cognition.learning.algorithm.nearest.NearestNeighborKDTree Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: NearestNeighborKDTree.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Aug 10, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.nearest;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.DivergenceFunction;
import gov.sandia.cognition.math.Metric;
import gov.sandia.cognition.math.geometry.KDTree;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;
/**
* A KDTree-based implementation of the nearest neighbor algorithm. This
* algorithm has a O(n log(n)) construction time and a O(log(n)) evaluate time.
* @param Type of Vectorizable data upon which we determine
* similarity.
* @param Output of the evaluator, like Matrix, Double, String
* @author Kevin R. Dixon
* @since 3.0
*/
public class NearestNeighborKDTree
extends AbstractNearestNeighbor
{
/**
* KDTree that holds the data to search for neighbors.
*/
private KDTree> data;
/**
* Creates a new instance of {@code NearestNeighborKDTree}.
*/
public NearestNeighborKDTree()
{
this(null, null);
}
/**
* Creates a new instance of NearestNeighborKDTree
*
* @param data
* Underlying data for the classifier
* @param divergenceFunction Divergence function that determines how "far" two objects are apart
*/
public NearestNeighborKDTree(
KDTree> data,
DivergenceFunction super InputType, ? super InputType> divergenceFunction )
{
super( divergenceFunction );
this.setData(data);
}
@Override
public NearestNeighborKDTree clone()
{
@SuppressWarnings("unchecked")
NearestNeighborKDTree clone =
(NearestNeighborKDTree) super.clone();
clone.setData( ObjectUtil.cloneSafe( this.getData() ) );
return clone;
}
/**
* Setter for distanceFunction
* @return
* Distance metric that determines how "far" two objects are apart,
* where lower values indicate two objects are more similar.
*/
@SuppressWarnings("unchecked")
@Override
public Metric super InputType> getDivergenceFunction()
{
return (Metric super InputType>) super.getDivergenceFunction();
}
@Override
@SuppressWarnings("unchecked")
public void setDivergenceFunction(
DivergenceFunction super InputType, ? super InputType> divergenceFunction)
{
this.setDivergenceFunction( (Metric super InputType>) divergenceFunction );
}
/**
* Sets the Metric to use.
* @param divergenceFunction
* Metric that determines closeness.
*/
public void setDivergenceFunction(
Metric super InputType> divergenceFunction)
{
super.setDivergenceFunction(divergenceFunction);
}
/**
* Getter for data
* @return
* KDTree that holds the data to search for neighbors.
*/
public KDTree> getData()
{
return this.data;
}
/**
* Setter for data
* @param data
* KDTree that holds the data to search for neighbors.
*/
public void setData(
KDTree> data)
{
this.data = data;
}
public OutputType evaluate(
InputType input)
{
Collection> neighbors =
this.getData().findNearest(input, 1, this.getDivergenceFunction());
InputOutputPair,OutputType> pair = CollectionUtil.getFirst(neighbors);
if( pair != null )
{
return pair.getOutput();
}
else
{
return null;
}
}
/**
* This is a BatchLearner interface for creating a new NearestNeighbor
* from a given dataset, simply a pass-through to the constructor of
* NearestNeighbor
* @param Type of data upon which the NearestNeighbor operates,
* something like Vector, Double, or String
* @param Output of the evaluator, like Matrix, Double, String
*/
public static class Learner
extends NearestNeighborKDTree
implements SupervisedBatchLearner>
{
/**
* Default constructor.
*/
public Learner()
{
this( null );
}
/**
* Creates a new instance of Learner
* @param divergenceFunction
* Divergence function that determines how "far" two objects are apart,
* where lower values indicate two objects are more similar
*/
public Learner(
Metric super Vectorizable> divergenceFunction )
{
super( null, divergenceFunction );
}
/**
* Creates a new NearestNeighbor from a Collection of InputType.
* We build a balanced KDTree with the data, which is an O(n log(n))
* operator for n data points.
* @param data Dataset from which to create a new NearestNeighbor
* @return
* NearestNeighbor based on the given dataset with a balanced
* KDTree.
*/
public NearestNeighborKDTree learn(
Collection extends InputOutputPair extends InputType,OutputType>> data )
{
@SuppressWarnings("unchecked")
NearestNeighborKDTree clone = this.clone();
KDTree> tree =
KDTree.createBalanced(data);
clone.setData( tree );
return clone;
}
}
}