// Source listing for gov.sandia.cognition.learning.function.categorization.BinaryVersusCategorizer,
// from the cognitive-foundry artifact (available for Maven / Gradle / Ivy):
// a single jar with all the Cognitive Foundry components.
/*
* File: BinaryVersusCategorizer.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright April 08, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.function.categorization;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.AbstractBatchLearnerContainer;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
/**
* An adapter that allows binary categorizers to be adapted for multi-category
* problems by applying a binary categorizer to each pair of categories.
*
* @param
* The type of the input to categorize.
* @param
* The type of the output categories.
* @author Justin Basilico
* @since 3.0
*/
public class BinaryVersusCategorizer
extends AbstractDiscriminantCategorizer
{
/** Maps false-true category pairs . */
protected Map, Evaluator super InputType, Boolean>>
categoryPairsToEvaluatorMap;
/**
* Creates a new {@code BinaryVersusCategorizer}.
*/
public BinaryVersusCategorizer()
{
this(new LinkedHashSet(),
new LinkedHashMap, Evaluator super InputType, Boolean>>());
}
/**
* Creates a new {@code BinaryVersusCategorizer} with the given
* categories and an empty set of evaluators.
*
* @param categories
* The possible output categories.
*/
public BinaryVersusCategorizer(
final Set categories)
{
this(categories,
new LinkedHashMap, Evaluator super InputType, Boolean>>(
(categories.size() * categories.size() / 2)));
}
/**
* Creates a new {@code BinaryVersusCategorizer}.
*
* @param categories
* The possible output categories.
* @param categoryPairsToEvaluatorMap
* The mapping of category pairs to their binary categorizer.
*/
public BinaryVersusCategorizer(
final Set categories,
final Map, Evaluator super InputType, Boolean>> categoryPairsToEvaluatorMap)
{
super(categories);
this.setCategoryPairsToEvaluatorMap(categoryPairsToEvaluatorMap);
}
@Override
public BinaryVersusCategorizer clone()
{
BinaryVersusCategorizer result = (BinaryVersusCategorizer)
super.clone();
result.categoryPairsToEvaluatorMap =
new LinkedHashMap, Evaluator super InputType, Boolean>>(
this.categoryPairsToEvaluatorMap.size());
for (Map.Entry, Evaluator super InputType, Boolean>> entry
: this.categoryPairsToEvaluatorMap.entrySet())
{
result.categoryPairsToEvaluatorMap.put(
ObjectUtil.cloneSmart(entry.getKey()),
ObjectUtil.cloneSmart(entry.getValue()));
}
return result;
}
@Override
public DefaultWeightedValueDiscriminant evaluateWithDiscriminant(
final InputType input)
{
final int categoryCount = this.categories.size();
if (categoryCount <= 0)
{
// No categories.
return null;
}
else if (categoryCount == 1)
{
// There is only one category.
return DefaultWeightedValueDiscriminant.create(
CollectionUtil.getFirst(this.categories), 1.0);
}
// We are going to count the number of votes for each category.
final DefaultDataDistribution results =
new DefaultDataDistribution(categoryCount);
// Go through all the pairs of evaluators.
for (Map.Entry, Evaluator super InputType, Boolean>> entry
: this.categoryPairsToEvaluatorMap.entrySet())
{
// Evaluate the binary categorizer for the two classes on the given
// input.
final Boolean category = entry.getValue().evaluate(input);
if (category == null)
{
// Null values do not vote.
}
else if (!category)
{
// This belongs to the false (first) category.
results.increment(entry.getKey().getFirst());
}
else
{
// This belongs to the true (second) category.
results.increment(entry.getKey().getSecond());
}
}
// The one with the most votes is the category we use.
final CategoryType bestCategory = results.getMaxValueKey();
final double bestFraction = results.getFraction(bestCategory);
return DefaultWeightedValueDiscriminant.create(
bestCategory, bestFraction);
}
/**
* Gets the mapping of false-true category pairs to the binary categorizer
* that distinguishes them.
*
* @return
* The mapping of category pairs to their binary categorizer.
*/
public Map, Evaluator super InputType, Boolean>> getCategoryPairsToEvaluatorMap()
{
return this.categoryPairsToEvaluatorMap;
}
/**
* Sets the mapping of false-true category pairs to the binary categorizer
* that distinguishes them.
*
* @param categoryPairsToEvaluatorMap
* The mapping of category pairs to their binary categorizer.
*/
public void setCategoryPairsToEvaluatorMap(
final Map, Evaluator super InputType, Boolean>> categoryPairsToEvaluatorMap)
{
this.categoryPairsToEvaluatorMap = categoryPairsToEvaluatorMap;
}
/**
* A learner for the {@code BinaryVersusCategorizer}. It learns a
* binary categorizer for each pair of categories.
*
* @param
* The input to learn from and the input to the learned categorizer.
* @param
* The type of categories to learn from.
*/
public static class Learner
extends AbstractBatchLearnerContainer>, ? extends Evaluator super InputType, Boolean>>>
implements SupervisedBatchLearner>
{
/**
* Creates a new {@code BinaryVersusCategorizer.Learner} with no
* initial binary categorizer learner.
*/
public Learner()
{
this(null);
}
/**
* Creates a new {@code BinaryVersusCategorizer.Learner} with an
* binary categorizer learner to learn category versus category.
*
* @param learner
* The binary categorizer learner to use to learn decision
* functions between categories.
*/
public Learner(
BatchLearner super Collection extends InputOutputPair extends InputType, Boolean>>, ? extends Evaluator super InputType, Boolean>> learner)
{
super(learner);
}
@Override
public BinaryVersusCategorizer learn(
final Collection extends InputOutputPair extends InputType, CategoryType>> data)
{
// Find the categories. We're going to look at pairs of categories
// so we also make a list version of the set.
final Set categories =
DatasetUtil.findUniqueOutputs(data);
final int categoryCount = categories.size();
final ArrayList categoriesList =
new ArrayList(categories);
// Create the object to hold the result.
final BinaryVersusCategorizer result =
new BinaryVersusCategorizer(categories);
for (int i = 0; i < categoryCount; i++)
{
// This is the false category.
final CategoryType falseCategory = categoriesList.get(i);
for (int j = i + 1; j < categoryCount; j++)
{
// This is the true category.
final CategoryType trueCategory = categoriesList.get(j);
final ArrayList> versusData =
new ArrayList>();
for (InputOutputPair extends InputType, CategoryType> example
: data)
{
// The category of the label.
final CategoryType category = example.getOutput();
if (falseCategory.equals(category))
{
// This is an example belonging to the false
// category.
versusData.add(new DefaultInputOutputPair(
example.getInput(), false));
}
else if (trueCategory.equals(category))
{
// This is an example belonging to the true
// category.
versusData.add(new DefaultInputOutputPair(
example.getInput(), true));
}
// else - The example did not belong to either category.
}
// Learn on the binary data.
final Evaluator super InputType, Boolean> learned =
this.getLearner().learn(versusData);
// Add the learned categorizer.
result.categoryPairsToEvaluatorMap.put( DefaultPair.create(
falseCategory, trueCategory ), learned );
}
}
// Returns the created adapter.
return result;
}
}
}