gov.sandia.cognition.learning.performance.categorization.DefaultBinaryConfusionMatrix Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: DefaultBinaryConfusionMatrix.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright July 18, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.performance.categorization;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.learning.performance.AbstractSupervisedPerformanceEvaluator;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.Summarizer;
import java.util.Collection;
import java.util.Collections;
import java.util.Set;
/**
* A default implementation of the {@code BinaryConfusionMatrix}. It stores the
* four entries in the confusion matrix: true positives, false positives,
* true negatives, and false negatives.
*
* @author Justin Basilico
* @author Kevin R. Dixon
* @since 3.1
*/
public class DefaultBinaryConfusionMatrix
extends AbstractBinaryConfusionMatrix
{
/**
 * Number of true negatives. The (false, false) entry: the target is false
 * and the estimate is false. Held as a double because entries are
 * incremented by arbitrary amounts rather than only by whole counts.
 */
protected double trueNegativesCount;
/**
 * Number of false positives. The (false, true) entry: the target is false
 * but the estimate is true.
 */
protected double falsePositivesCount;
/**
 * Number of false negatives. The (true, false) entry: the target is true
 * but the estimate is false.
 */
protected double falseNegativesCount;
/**
 * Number of true positives. The (true, true) entry: the target is true and
 * the estimate is true.
 */
protected double truePositivesCount;
/**
 * Creates a new, empty {@code DefaultBinaryConfusionMatrix} in which all
 * four entries (true positives, false positives, true negatives, and false
 * negatives) are zero.
 */
public DefaultBinaryConfusionMatrix()
{
    super();

    // Assign the fields directly rather than going through the public
    // setters. The setters are overridable, and an override invoked from
    // here would run before the subclass constructor has initialized its
    // own state. Zero trivially satisfies the setters' non-negativity
    // requirement, so no validation is lost.
    this.truePositivesCount = 0.0;
    this.falsePositivesCount = 0.0;
    this.trueNegativesCount = 0.0;
    this.falseNegativesCount = 0.0;
}
/**
 * Returns the object produced by {@code super.clone()}, narrowed to
 * {@code DefaultBinaryConfusionMatrix}.
 *
 * @return
 *      The clone of this confusion matrix.
 */
@Override
public DefaultBinaryConfusionMatrix clone()
{
    final DefaultBinaryConfusionMatrix copy =
        (DefaultBinaryConfusionMatrix) super.clone();
    return copy;
}
/**
 * Adds the given value to the entry selected by the (target, estimate)
 * pair.
 *
 * @param   target
 *      The actual (target) category. Unboxed, so {@code null} results in a
 *      {@code NullPointerException}.
 * @param   estimate
 *      The predicted (estimated) category. Unboxed, so {@code null} results
 *      in a {@code NullPointerException}.
 * @param   value
 *      The amount to add to the selected entry. It is not validated here.
 */
@Override
public void add(
    final Boolean target,
    final Boolean estimate,
    final double value)
{
    // Unbox once, target first, so a null argument fails up front.
    final boolean actualPositive = target;
    final boolean predictedPositive = estimate;

    if (actualPositive && predictedPositive)
    {
        this.truePositivesCount += value;
    }
    else if (actualPositive)
    {
        this.falseNegativesCount += value;
    }
    else if (predictedPositive)
    {
        this.falsePositivesCount += value;
    }
    else
    {
        this.trueNegativesCount += value;
    }
}
/**
 * Resets all four entries of the matrix to zero by passing {@code 0.0} to
 * each of the count setters.
 */
@Override
public void clear()
{
    final double empty = 0.0;
    this.setTruePositivesCount(empty);
    this.setFalseNegativesCount(empty);
    this.setTrueNegativesCount(empty);
    this.setFalsePositivesCount(empty);
}
/**
 * Gets the number of true positives, which is the (true, true) entry.
 *
 * @return
 *      The current true positives count.
 */
@Override
public double getTruePositivesCount()
{
    return truePositivesCount;
}
/**
 * Sets the number of true positives in the matrix, which is the
 * (true, true) entry.
 *
 * @param   truePositivesCount
 *      The number of true positives. Cannot be negative; the check is
 *      delegated to {@code ArgumentChecker.assertIsNonNegative} before the
 *      value is stored.
 */
public void setTruePositivesCount(final double truePositivesCount)
{
    ArgumentChecker.assertIsNonNegative("truePositivesCount", truePositivesCount);
    this.truePositivesCount = truePositivesCount;
}
/**
 * Gets the number of false positives, which is the (false, true) entry.
 *
 * @return
 *      The current false positives count.
 */
@Override
public double getFalsePositivesCount()
{
    return falsePositivesCount;
}
/**
 * Sets the number of false positives in the matrix, which is the
 * (false, true) entry.
 *
 * @param   falsePositivesCount
 *      The number of false positives. Cannot be negative; the check is
 *      delegated to {@code ArgumentChecker.assertIsNonNegative} before the
 *      value is stored.
 */
public void setFalsePositivesCount(final double falsePositivesCount)
{
    ArgumentChecker.assertIsNonNegative("falsePositivesCount", falsePositivesCount);
    this.falsePositivesCount = falsePositivesCount;
}
/**
 * Gets the number of true negatives, which is the (false, false) entry.
 *
 * @return
 *      The current true negatives count.
 */
@Override
public double getTrueNegativesCount()
{
    return trueNegativesCount;
}
/**
 * Sets the number of true negatives in the matrix, which is the
 * (false, false) entry.
 *
 * @param   trueNegativesCount
 *      The number of true negatives. Cannot be negative; the check is
 *      delegated to {@code ArgumentChecker.assertIsNonNegative} before the
 *      value is stored.
 */
public void setTrueNegativesCount(final double trueNegativesCount)
{
    ArgumentChecker.assertIsNonNegative("trueNegativesCount", trueNegativesCount);
    this.trueNegativesCount = trueNegativesCount;
}
/**
 * Gets the number of false negatives, which is the (true, false) entry.
 *
 * @return
 *      The current false negatives count.
 */
@Override
public double getFalseNegativesCount()
{
    return falseNegativesCount;
}
/**
 * Sets the number of false negatives in the matrix, which is the
 * (true, false) entry.
 *
 * @param   falseNegativesCount
 *      The number of false negatives. Cannot be negative; the check is
 *      delegated to {@code ArgumentChecker.assertIsNonNegative} before the
 *      value is stored.
 */
public void setFalseNegativesCount(final double falseNegativesCount)
{
    ArgumentChecker.assertIsNonNegative("falseNegativesCount", falseNegativesCount);
    this.falseNegativesCount = falseNegativesCount;
}
/**
 * Gets the rate of negative targets correctly classified. Computed as:
 * TrueNegatives / TotalNegatives
 *     = TrueNegatives / (TrueNegatives + FalsePositives)
 *
 * If the total number of negatives is not positive, the rate is reported
 * as 1.0 instead of dividing by zero.
 *
 * @return
 *      True negative rate, which is in [0,1].
 */
public double getTrueNegativesRate()
{
    final double trueNegatives = this.getTrueNegativesCount();
    final double totalNegatives =
        trueNegatives + this.getFalsePositivesCount();

    // No negative targets recorded: nothing could have been misclassified.
    if (totalNegatives <= 0.0)
    {
        return 1.0;
    }

    return trueNegatives / totalNegatives;
}
/**
 * Gets the rate of target trues that were correctly classified as true.
 * Computed as:
 * TruePositives / TotalPositives
 *     = TruePositives / (TruePositives + FalseNegatives)
 *
 * If the total number of positives is not positive, the rate is reported
 * as 1.0 instead of dividing by zero.
 *
 * @return
 *      True positives rate, which is in [0,1].
 */
public double getTruePositivesRate()
{
    final double truePositives = this.getTruePositivesCount();
    final double totalPositives =
        truePositives + this.getFalseNegativesCount();

    // No positive targets recorded: nothing could have been misclassified.
    if (totalPositives <= 0.0)
    {
        return 1.0;
    }

    return truePositives / totalPositives;
}
/**
 * Gets the rate of false targets incorrectly classified as true. Computed
 * as:
 * FalsePositives / TotalNegatives
 *     = FalsePositives / (TrueNegatives + FalsePositives)
 *
 * The value is obtained as the complement of
 * {@link #getTrueNegativesRate()}.
 *
 * @return
 *      Rate of false positives, which is in [0,1]. Equal to
 *      1 - trueNegativeRate.
 */
public double getFalsePositivesRate()
{
    final double trueNegativesRate = this.getTrueNegativesRate();
    return 1.0 - trueNegativesRate;
}
/**
 * Gets the rate of true targets incorrectly classified as false. Computed
 * as:
 * FalseNegatives / TotalPositives
 *     = FalseNegatives / (TruePositives + FalseNegatives)
 *
 * The value is obtained as the complement of
 * {@link #getTruePositivesRate()}.
 *
 * @return
 *      Rate of false negatives, which is in [0,1]. Equal to
 *      1 - truePositiveRate.
 */
public double getFalseNegativesRate()
{
    final double truePositivesRate = this.getTruePositivesRate();
    return 1.0 - truePositivesRate;
}
/**
 * Renders the four entries on one line, in the order true negatives,
 * false positives, false negatives, true positives, separated by ", ".
 *
 * @return
 *      A human-readable listing of the four counts.
 */
@Override
public String toString()
{
    final StringBuilder text = new StringBuilder();
    text.append("True Negatives: ").append(this.getTrueNegativesCount());
    text.append(", ").append("False Positives: ").append(this.getFalsePositivesCount());
    text.append(", ").append("False Negatives: ").append(this.getFalseNegativesCount());
    text.append(", ").append("True Positives: ").append(this.getTruePositivesCount());
    return text.toString();
}
/**
 * Takes a general confusion matrix and creates a binary form of it using
 * the given true category. All other categories are considered false.
 * This wraps the category in a singleton set and delegates to
 * {@code binarizeOnTrueSet}.
 *
 * @param   <CategoryType>
 *      The true category type.
 * @param   matrix
 *      The general confusion matrix to binarize.
 * @param   trueCategory
 *      The value of the true category for the binary confusion matrix.
 * @return
 *      A new binary confusion matrix.
 */
public static <CategoryType> DefaultBinaryConfusionMatrix binarizeOnTrueCategory(
    final ConfusionMatrix<CategoryType> matrix,
    final CategoryType trueCategory)
{
    return binarizeOnTrueSet(matrix, Collections.singleton(trueCategory));
}
/**
* Takes a general confusion matrix and creates a binary form of it using
* the given set of true categories. All other categories are considered
* false.
*
* @param
* The true category type.
* @param matrix
* The general confusion matrix to binarize.
* @param trueSet
* The set of categories in the true set for the binary confusion
* matrix.
* @return
* A new binary confusion matrix.
*/
public static DefaultBinaryConfusionMatrix binarizeOnTrueSet(
final ConfusionMatrix matrix,
final Set super CategoryType> trueSet)
{
final DefaultBinaryConfusionMatrix result =
new DefaultBinaryConfusionMatrix();
for (CategoryType actual : matrix.getActualCategories())
{
final boolean actualBinary = trueSet.contains(actual);
for (CategoryType predicted : matrix.getPredictedCategories(actual))
{
final boolean predictedBinary = trueSet.contains(predicted);
result.add(actualBinary, predictedBinary,
matrix.getCount(actual, predicted));
}
}
return result;
}
/**
* Creates a new {@code DefaultConfusionMatrix} from the given
* actual-predicted pairs.
*
* @param pairs
* The actual-category pairs.
* @return
* A new confusion matrix populated from the given actual-category
* pairs.
*/
public static DefaultBinaryConfusionMatrix createFromActualPredictedPairs(
final Collection extends Pair extends Boolean, ? extends Boolean>> pairs)
{
final DefaultBinaryConfusionMatrix result =
new DefaultBinaryConfusionMatrix();
for (Pair extends Boolean, ? extends Boolean> pair
: pairs)
{
result.add(pair.getFirst(), pair.getSecond());
}
return result;
}
/**
* Creates a new {@code DefaultBinaryConfusionMatrix} from the given
* target-estimate pairs.
*
* @param input
* The target-estimate pairs.
* @return
* A new confusion matrix populated from the given target-estimate
* pairs.
*/
public static DefaultBinaryConfusionMatrix create(
final Iterable extends TargetEstimatePair extends Boolean, ? extends Boolean>> input)
{
return DefaultBinaryConfusionMatrix.create(input, true);
}
/**
* Creates a new {@code DefaultBinaryConfusionMatrix} from the given
* target-estimate pairs.
*
* @param input
* The target-estimate pairs.
* @param weightIfAvailable
* True uses the weight of each item (or 1.0 if there is none); false
* means 1.0 is used regardless of weight.
* @return
* A new confusion matrix populated from the given target-estimate
* pairs.
*/
public static DefaultBinaryConfusionMatrix create(
final Iterable extends TargetEstimatePair extends Boolean, ? extends Boolean>> input,
final boolean weightIfAvailable )
{
double falsePositives = 0;
double falseNegatives = 0;
double truePositives = 0;
double trueNegatives = 0;
for (TargetEstimatePair extends Boolean, ? extends Boolean> pair
: input)
{
double increment;
if (weightIfAvailable)
{
increment = DatasetUtil.getWeight(pair);
}
else
{
increment = 1.0;
}
boolean target = pair.getTarget();
boolean estimate = pair.getEstimate();
if (target == true)
{
if (estimate == true)
{
truePositives += increment;
}
else
{
falseNegatives += increment;
}
}
else
{
if (estimate == true)
{
falsePositives += increment;
}
else
{
trueNegatives += increment;
}
}
}
final DefaultBinaryConfusionMatrix result =
new DefaultBinaryConfusionMatrix();
result.setFalsePositivesCount(falsePositives);
result.setFalseNegativesCount(falseNegatives);
result.setTruePositivesCount(truePositives);
result.setTrueNegativesCount(trueNegatives);
return result;
}
/**
* An implementation of the {@code SupervisedPerformanceEvaluator} interface
* for creating a {@code DefaultBinaryConfusionMatrix}.
*
* @param
* Input type for the evaluator to compute the confusion over.
*/
public static class PerformanceEvaluator
extends AbstractSupervisedPerformanceEvaluator
{
/**
* Creates a new {@code PerformanceEvaluator}.
*/
public PerformanceEvaluator()
{
super();
}
public DefaultBinaryConfusionMatrix evaluatePerformance(
final Collection extends TargetEstimatePair extends Boolean, ? extends Boolean>> data )
{
return DefaultBinaryConfusionMatrix.create(data);
}
}
/**
* A confusion matrix summarizer that summarizes actual-predicted pairs.
*/
public static class ActualPredictedPairSummarizer
extends AbstractCloneableSerializable
implements Summarizer, DefaultBinaryConfusionMatrix>
{
/**
* Creates a new {@code CombineSummarizer}.
*/
public ActualPredictedPairSummarizer()
{
super();
}
@Override
public DefaultBinaryConfusionMatrix summarize(
final Collection extends Pair extends Boolean, ? extends Boolean>> data)
{
return createFromActualPredictedPairs(data);
}
}
/**
* A confusion matrix summarizer that adds together confusion matrices.
*/
public static class CombineSummarizer
extends AbstractCloneableSerializable
implements Summarizer, DefaultBinaryConfusionMatrix>
{
/**
* Creates a new {@code CombineSummarizer}.
*/
public CombineSummarizer()
{
super();
}
@Override
public DefaultBinaryConfusionMatrix summarize(
final Collection extends ConfusionMatrix> data)
{
final DefaultBinaryConfusionMatrix result =
new DefaultBinaryConfusionMatrix();
for (ConfusionMatrix item : data)
{
result.addAll(item);
}
return result;
}
}
}