gov.sandia.cognition.learning.function.distance.DivergencesEvaluator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: DivergencesEvaluator.java
* Authors: Justin Basilico
* Project: Cognitive Foundry Learning Core
*
* Copyright 2011 Cognitive Foundry. All rights reserved.
*/
package gov.sandia.cognition.learning.function.distance;
import gov.sandia.cognition.data.convert.vector.AbstractToVectorEncoder;
import gov.sandia.cognition.learning.algorithm.AbstractBatchLearnerContainer;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.math.DivergenceFunction;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
/**
* Evaluates the divergence (distance) between an input and a list of values,
* storing the resulting divergence values in a vector. This can be used as a
* feature representation built from something like a clustering algorithm or
* from a set of prototype/basis elements.
*
* @param <InputType>
* The type of input value that the class evaluates. It is the second
* parameter passed to the divergence function. Typically a type like
* a Vector.
* @param <ValueType>
* The type of value that the divergence is computed from. It is the
* first parameter passed to the divergence function. It is typically a
* type like Vector or CentroidCluster.
* @author Justin Basilico
* @since 3.3.3
*/
public class DivergencesEvaluator
extends AbstractToVectorEncoder
implements VectorOutputEvaluator,
DivergenceFunctionContainer
{
/** The divergence function to apply between the data and the input. */
protected DivergenceFunction super ValueType, ? super InputType> divergenceFunction;
/** The data to evaluate the divergence from. */
protected Collection values;
/**
* Creates a new {@code DivergencesEvaluator} with a null divergence
* function and an empty set of values.
*/
public DivergencesEvaluator()
{
this(null, new ArrayList());
}
/**
* Creates a new {@code DivergencesEvaluator} with the given divergence
* and values.
*
* @param divergenceFunction
* The divergence function to use.
* @param values
* The values to calculate the divergence from.
*/
public DivergencesEvaluator(
final DivergenceFunction super ValueType, ? super InputType> divergenceFunction,
final Collection values)
{
this(divergenceFunction, values, VectorFactory.getDefault());
}
/**
* Creates a new {@code DivergencesEvaluator} with the given divergence
* and values.
*
* @param divergenceFunction
* The divergence function to use.
* @param values
* The values to calculate the divergence from.
* @param vectorFactory
* The vector factory to use.
*/
public DivergencesEvaluator(
final DivergenceFunction super ValueType, ? super InputType> divergenceFunction,
final Collection values,
final VectorFactory> vectorFactory)
{
super(vectorFactory);
this.setDivergenceFunction(divergenceFunction);
this.setValues(values);
}
@Override
public DivergencesEvaluator clone()
{
@SuppressWarnings("unchecked")
final DivergencesEvaluator clone = (DivergencesEvaluator)
super.clone();
clone.divergenceFunction = ObjectUtil.cloneSmart(this.divergenceFunction);
clone.values = ObjectUtil.cloneSmartElementsAsArrayList(this.values);
return clone;
}
@Override
public void encode(
final InputType input,
final Vector result,
final int startIndex)
{
// Go through the values and compute the divergence to each one.
int index = startIndex;
for (final ValueType cluster : this.getValues())
{
final double distance =
this.divergenceFunction.evaluate(cluster, input);
result.setElement(index, distance);
index++;
}
}
@Override
public int getOutputDimensionality()
{
return this.getValues().size();
}
@Override
public DivergenceFunction super ValueType, ? super InputType> getDivergenceFunction()
{
return this.divergenceFunction;
}
/**
* Sets the divergence function to use from the values to the inputs.
*
* @param divergenceFunction
* The divergence function to use.
*/
public void setDivergenceFunction(
final DivergenceFunction super ValueType, ? super InputType> divergenceFunction)
{
this.divergenceFunction = divergenceFunction;
}
/**
* Gets the values that the divergence is computed from using the
* divergence function to the input.
*
* @return
* The values that the distance is computed from.
*/
public Collection getValues()
{
return this.values;
}
/**
* Sets the values that the divergence is computed from using the
* divergence function to the input.
*
* @param values
* The values that the distance is computed from.
*/
public void setValues(
final Collection values)
{
this.values = values;
}
/**
* Convenience method for creation a {@code DivergeceEvaluator}.
*
* @param
* The type of input value that the class evaluates.
* @param
* The type of value that the divergence is computed from.
* @param divergenceFunction
* The divergence function to use.
* @param values
* The values to calculate the divergence from.
* @return
* A new evaluator.
*/
public static DivergencesEvaluator
create(
final DivergenceFunction super ValueType, ? super InputType> divergenceFunction,
final Collection values)
{
return new DivergencesEvaluator(
divergenceFunction, values);
}
/**
* A learner adapter for the {@code DivergencesEvaluator}. It calls a
* base learner and then wraps learned collection of values in an evaluator
* that uses the given divergence function.
*
* @param
* The data type for learning. Passed to the wrapped learner.
* @param
* The input type for the evaluator.
* @param
* The value type that is the output of learning and is used as the
* values in the learned evaluator.
*/
public static class Learner
extends AbstractBatchLearnerContainer>>
implements BatchLearner>,
DivergenceFunctionContainer,
VectorFactoryContainer
{
/** The divergence function to apply between the data and the input. */
protected DivergenceFunction super ValueType, ? super InputType> divergenceFunction;
/** The vector factory to use. */
protected VectorFactory> vectorFactory;
/**
* Creates a new {@code DivergenceFunction.Learner} with null base
* learner and divergence functions.
*/
public Learner()
{
this(null, null);
}
/**
* Creates a new {@code DivergenceFunction.Learner} with the given
* properties.
*
* @param learner
* The base learner to use.
* @param divergenceFunction
* The divergence function to use.
*/
public Learner(
final BatchLearner> learner,
final DivergenceFunction super ValueType, ? super InputType> divergenceFunction)
{
this(learner, divergenceFunction, VectorFactory.getDefault());
}
/**
* Creates a new {@code DivergenceFunction.Learner} with the given
* properties.
*
* @param learner
* The base learner to use.
* @param divergenceFunction
* The divergence function to use.
* @param vectorFactory
* The vector factory to use.
*/
public Learner(
final BatchLearner> learner,
final DivergenceFunction super ValueType, ? super InputType> divergenceFunction,
final VectorFactory> vectorFactory)
{
super(learner);
this.setDivergenceFunction(divergenceFunction);
this.setVectorFactory(vectorFactory);
}
@Override
public Learner clone()
{
@SuppressWarnings("unchecked")
final Learner clone = (Learner)
super.clone();
clone.divergenceFunction = ObjectUtil.cloneSmart(this.divergenceFunction);
return clone;
}
@Override
public DivergencesEvaluator learn(
final DataType data)
{
return new DivergencesEvaluator(
this.getDivergenceFunction(),
this.getLearner().learn(data),
this.getVectorFactory());
}
@Override
public DivergenceFunction super ValueType, ? super InputType> getDivergenceFunction()
{
return this.divergenceFunction;
}
/**
* Sets the divergence function to use from the values to the inputs.
*
* @param divergenceFunction
* The divergence function to use.
*/
public void setDivergenceFunction(
final DivergenceFunction super ValueType, ? super InputType> divergenceFunction)
{
this.divergenceFunction = divergenceFunction;
}
@Override
public VectorFactory extends Vector> getVectorFactory()
{
return this.vectorFactory;
}
/**
* Sets the vector factory to use.
*
* @param vectorFactory
* The vector factory to use.
*/
public void setVectorFactory(
final VectorFactory> vectorFactory)
{
this.vectorFactory = vectorFactory;
}
/**
* Convenience method for creating a
* {@code DivergencesEvaluator.Learner}.
*
* @param
* The data type for learning. Passed to the wrapped learner.
* @param
* The input type for the evaluator.
* @param
* The value type that is the output of learning and is used as the
* values in the learned evaluator.
* @param learner
* The base learner to use.
* @param divergenceFunction
* The divergence function to use.
* @return
* A new learner.
*/
public static Learner
create(
final BatchLearner> learner,
final DivergenceFunction super ValueType, ? super InputType> divergenceFunction)
{
return new Learner(
learner, divergenceFunction);
}
}
}