// gov.sandia.cognition.learning.performance.categorization.DefaultConfusionMatrix
// Artifact: cognitive-foundry (Maven / Gradle / Ivy) - a single jar with all the
// Cognitive Foundry components.
/*
* File: DefaultConfusionMatrix.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry Learning Core
*
* Copyright January 11, 2011, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
*/
package gov.sandia.cognition.learning.performance.categorization;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.math.MutableDouble;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.Summarizer;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
/**
* A default implementation of the {@code ConfusionMatrix} interface. It is
* backed by a two-level map storing the category object counts, making a
* sparse representation.
*
* @param
* The type of the category object over the confusion matrix.
* @author Justin Basilico
* @since 3.1
*/
public class DefaultConfusionMatrix
extends AbstractConfusionMatrix
{
/** The backing map of confusion matrix entries. The first key is the
* actual category and the second is the predicted category. */
protected Map> confusions;
/**
* Creates a new, empty {@code DefaultConfusionMatrix}.
*/
public DefaultConfusionMatrix()
{
super();
this.confusions = new LinkedHashMap>();
}
/**
* Creates a copy of a given confusion matrix.
*
* @param other
* The other confusion matrix to copy.
*/
public DefaultConfusionMatrix(
final ConfusionMatrix extends CategoryType> other)
{
this();
this.addAll(other);
}
@Override
public DefaultConfusionMatrix clone()
{
@SuppressWarnings("unchecked")
final DefaultConfusionMatrix clone =
(DefaultConfusionMatrix) super.clone();
if (this.confusions != null)
{
clone.confusions = new LinkedHashMap>(
this.confusions.size());
for (Map.Entry> outerEntry
: this.confusions.entrySet())
{
final LinkedHashMap categoryMap =
new LinkedHashMap(
outerEntry.getValue().size());
clone.confusions.put(outerEntry.getKey(), categoryMap);
for (Map.Entry innerEntry
: outerEntry.getValue().entrySet())
{
categoryMap.put(innerEntry.getKey(),
innerEntry.getValue().clone());
}
}
}
return clone;
}
@Override
public void add(
final CategoryType target,
final CategoryType estimate,
final double value)
{
Map subMap = confusions.get(target);
if (subMap == null)
{
subMap = new HashMap();
this.confusions.put(target, subMap);
}
MutableDouble entry = subMap.get(estimate);
if (entry == null)
{
entry = new MutableDouble(value);
subMap.put(estimate, entry);
}
else
{
entry.value += value;
}
}
@Override
public double getCount(
final CategoryType target,
final CategoryType estimate)
{
Map subMap = confusions.get(target);
if (subMap == null)
{
return 0.0;
}
else
{
MutableDouble entry = subMap.get(estimate);
if (entry == null)
{
return 0.0;
}
else
{
return entry.getValue();
}
}
}
@Override
public double getActualCount(
final CategoryType target)
{
Map subMap = confusions.get(target);
if (subMap == null)
{
return 0.0;
}
double result = 0.0;
for (MutableDouble value : subMap.values())
{
result += value.getValue();
}
return result;
}
@Override
public void clear()
{
this.confusions.clear();
}
@Override
public Set getCategories()
{
final LinkedHashSet result =
new LinkedHashSet();
result.addAll(this.getActualCategories());
result.addAll(this.getPredictedCategories());
return result;
}
@Override
public Set getActualCategories()
{
return this.confusions.keySet();
}
@Override
public Set getPredictedCategories()
{
final LinkedHashSet estimates = new LinkedHashSet(
this.confusions.size());
for (Map estimateCounts
: this.confusions.values())
{
estimates.addAll(estimateCounts.keySet());
}
return estimates;
}
@Override
public Set getPredictedCategories(
final CategoryType target)
{
Map subMap = confusions.get(target);
if (subMap == null)
{
return Collections.emptySet();
}
else
{
return subMap.keySet();
}
}
@Override
public String toString()
{
return this.confusions.toString();
}
/**
* Creates a new {@code DefaultConfusionMatrix} from the given
* actual-predicted pairs.
*
* @param
* The category type.
* @param pairs
* The actual-category pairs.
* @return
* A new confusion matrix populated from the given actual-category
* pairs.
*/
public static DefaultConfusionMatrix createUnweighted(
final Collection extends TargetEstimatePair extends CategoryType, ? extends CategoryType>> pairs)
{
final DefaultConfusionMatrix result =
new DefaultConfusionMatrix();
for (TargetEstimatePair extends CategoryType, ? extends CategoryType> item
: pairs)
{
result.add(item.getTarget(), item.getEstimate());
}
return result;
}
/**
* Creates a new {@code DefaultConfusionMatrix} from the given
* actual-predicted pairs.
*
* @param
* The category type.
* @param pairs
* The actual-category pairs.
* @return
* A new confusion matrix populated from the given actual-category
* pairs.
*/
public static DefaultConfusionMatrix createFromActualPredictedPairs(
final Collection extends Pair extends CategoryType, ? extends CategoryType>> pairs)
{
final DefaultConfusionMatrix result =
new DefaultConfusionMatrix();
for (Pair extends CategoryType, ? extends CategoryType> pair
: pairs)
{
result.add(pair.getFirst(), pair.getSecond());
}
return result;
}
/**
* A confusion matrix summarizer that summarizes actual-predicted pairs.
*
* @param
* The type of category of the summarizer.
*/
public static class ActualPredictedPairSummarizer
extends AbstractCloneableSerializable
implements Summarizer, DefaultConfusionMatrix>
{
/**
* Creates a new {@code CombineSummarizer}.
*/
public ActualPredictedPairSummarizer()
{
super();
}
@Override
public DefaultConfusionMatrix summarize(
final Collection extends Pair extends CategoryType, ? extends CategoryType>> data)
{
return createFromActualPredictedPairs(data);
}
}
/**
* A confusion matrix summarizer that adds together confusion matrices.
*
* @param
* The type of category of the summarizer.
*/
public static class CombineSummarizer
extends AbstractCloneableSerializable
implements Summarizer, DefaultConfusionMatrix>
{
/**
* Creates a new {@code CombineSummarizer}.
*/
public CombineSummarizer()
{
super();
}
@Override
public DefaultConfusionMatrix summarize(
final Collection extends ConfusionMatrix> data)
{
final DefaultConfusionMatrix result =
new DefaultConfusionMatrix();
for (ConfusionMatrix item : data)
{
result.addAll(item);
}
return result;
}
}
/**
* A factory for default confusion matrices.
*
* @param
* The type of category that the confusion is computed over.
*/
public static class Factory
extends AbstractCloneableSerializable
implements gov.sandia.cognition.factory.Factory>
{
/**
* Creates a new {@code Factory}.
*/
public Factory()
{
super();
}
@Override
public DefaultConfusionMatrix create()
{
return new DefaultConfusionMatrix();
}
}
}