gov.sandia.cognition.learning.function.categorization.WinnerTakeAllCategorizer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: WinnerTakeAllCategorizer.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright April 07, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.function.categorization;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.AbstractBatchLearnerContainer;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Set;
/**
* Adapts an evaluator that outputs a vector to be used as a categorizer.
* @param The type of the input the categorizer can use.
* @param The type of category output by the categorizer.
* @author Justin Basilico
* @since 3.0
*/
public class WinnerTakeAllCategorizer
extends AbstractDiscriminantCategorizer
{
/** The evaluator that outputs a vector to return. */
protected Evaluator super InputType, ? extends Vectorizable> evaluator;
/**
* Creates a new {@code WinnerTakesAllCategorizer}.
*/
public WinnerTakeAllCategorizer()
{
this(null, new LinkedHashSet());
}
/**
* Creates a new {@code WinnerTakesAllCategorizer}.
*
* @param evaluator
* The evaluator that this class adapts.
* @param categories
* The set of categories.
*/
public WinnerTakeAllCategorizer(
final Evaluator super InputType, ? extends Vectorizable> evaluator,
final Set categories)
{
super(categories);
this.setEvaluator(evaluator);
}
/**
* Evaluates the categorizer and returns the category along with a weight.
*
* @param input
* The input to evaluate.
* @return
* The output plus its weight.
*/
@Override
public DefaultWeightedValueDiscriminant evaluateWithDiscriminant(
final InputType input)
{
final Vector output = this.evaluator.evaluate(input).convertToVector();
return this.findBestCategory(output);
}
/**
* Finds the best category (and its output value) from the given vector
* of outputs from a vector evaluator. Ties are broken by taking the first
* category with the highest value.
*
* @param output
* The vector of outputs, one corresponding to each category.
* @return
* The category with the highest output value. Ties are broken by
* taking the first category with the highest value.
*/
public DefaultWeightedValueDiscriminant findBestCategory(
final Vector output)
{
output.assertDimensionalityEquals(this.categories.size());
// Loop througn and find the best category.
CategoryType best = null;
double bestValue = Double.NEGATIVE_INFINITY;
int index = 0;
for (CategoryType category : this.categories)
{
final double value = output.getElement(index);
if (best == null || value > bestValue)
{
best = category;
bestValue = value;
}
index++;
}
// Return the best category.
return new DefaultWeightedValueDiscriminant(best, bestValue);
}
/**
* Gets the wrapped evaluator.
*
* @return
* The wrapped evaluator.
*/
public Evaluator super InputType, ? extends Vectorizable> getEvaluator()
{
return this.evaluator;
}
/**
* Sets the wrapped evaluator.
*
* @param evaluator
* The wrapped evaluator.
*/
public void setEvaluator(
final Evaluator super InputType, ? extends Vectorizable> evaluator)
{
this.evaluator = evaluator;
}
@Override
public void setCategories(
final Set categories)
{
// Just making this publicly visiable.
super.setCategories(categories);
}
/**
* A learner for the adapter. Makes it so that learning algorithms that
* produce evaluators whose outputs are vectors to be used for
* categorization.
*
* @param
* The type of the input data for the learner.
* @param
* The type of the output categories.
*/
public static class Learner
extends AbstractBatchLearnerContainer>, ? extends Evaluator super InputType, ? extends Vectorizable>>>
implements SupervisedBatchLearner>,
VectorFactoryContainer
{
/** The vector factory used. */
protected VectorFactory> vectorFactory;
/**
* Creates a new learner adapter.
*/
public Learner()
{
this(null);
}
/**
* Creates a new learner adapter with the given internal learner.
*
* @param learner
* The learner to adapt.
*/
public Learner(
final BatchLearner super Collection extends InputOutputPair extends InputType, Vector>>, Evaluator super InputType, ? extends Vectorizable>> learner)
{
super(learner);
this.setVectorFactory(VectorFactory.getDefault());
}
public WinnerTakeAllCategorizer learn(
final Collection extends InputOutputPair extends InputType, CategoryType>> data)
{
// First figure out all of the categories.
final LinkedHashMap categoryIndices =
new LinkedHashMap();
for (InputOutputPair, CategoryType> example : data)
{
final CategoryType category = example.getOutput();
if (!categoryIndices.containsKey(category))
{
final int index = categoryIndices.size();
categoryIndices.put(category, index);
}
}
// Now convert the dataset to have the output be vectors.
final int categoryCount = categoryIndices.size();
final ArrayList> vectorData =
new ArrayList>(data.size());
for (InputOutputPair extends InputType, CategoryType> example
: data)
{
final CategoryType category = example.getOutput();
final int index = categoryIndices.get(category);
// TODO: Make use of the boolean encoding code from the data package at some point.
final Vector output = this.getVectorFactory().createVector(
categoryCount, -1.0);
output.setElement(index, +1.0);
vectorData.add(new DefaultInputOutputPair(
example.getInput(), output));
}
// Do the learning.
final Evaluator super InputType, ? extends Vectorizable> learned =
this.getLearner().learn(vectorData);
final LinkedHashSet categories =
new LinkedHashSet(categoryIndices.keySet());
return new WinnerTakeAllCategorizer(
learned, categories);
}
/**
* Gets the vector factory.
*
* @return
* The vector factory.
*/
public VectorFactory> getVectorFactory()
{
return this.vectorFactory;
}
/**
* Sets the vector factory.
*
* @param vectorFactory
* The vector factory.
*/
public void setVectorFactory(
final VectorFactory> vectorFactory)
{
this.vectorFactory = vectorFactory;
}
}
}