All downloads are free. The search and download features use the official Maven repository.

gov.sandia.cognition.learning.performance.categorization.AbstractConfusionMatrix Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                AbstractConfusionMatrix.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry Learning Core
 * 
 * Copyright January 17, 2011, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 */

package gov.sandia.cognition.learning.performance.categorization;

import gov.sandia.cognition.util.AbstractCloneableSerializable;

/**
 * An abstract implementation of the {@code ConfusionMatrix} interface. Takes
 * care of a lot of the support functions in the interface so that
 * implementations can focus on the core data structures.
 *
 * @param   
 *      The type of category that the confusion matrix is under.
 * @author  Justin Basilico
 * @since   3.1
 */
public abstract class AbstractConfusionMatrix
    extends AbstractCloneableSerializable
    implements ConfusionMatrix
{

    /**
     * Creates a new {@code AbstractConfusionMatrix}.
     */
    public AbstractConfusionMatrix()
    {
        super();
    }

    @Override
    public void add(
        final CategoryType target,
        final CategoryType estimate)
    {
        this.add(target, estimate, 1.0);
    }

    @Override
    public  void addAll(
        final ConfusionMatrix other)
    {
        for (OtherType target : other.getActualCategories())
        {
            for (OtherType estimate : other.getPredictedCategories(target))
            {
                this.add(target, estimate, other.getCount(target, estimate));
            }
        }
    }
    
    @Override
    public boolean isEmpty()
    {
        return this.getTotalCount() <= 0.0;
    }

    @Override
    public double getTotalCount()
    {
        double result = 0.0;
        for (CategoryType target : this.getActualCategories())
        {
            result += this.getActualCount(target);
        }
        return result;
    }

    @Override
    public double getTotalCorrectCount()
    {
        double correct = 0.0;
        for (CategoryType category : this.getActualCategories())
        {
            correct += this.getCount(category, category);
        }
        return correct;
    }

    @Override
    public double getTotalIncorrectCount()
    {
        return this.getTotalCount() - this.getTotalCorrectCount();
    }

    @Override
    public double getActualCount(
        final CategoryType target)
    {
        double result = 0.0;
        for (CategoryType estimate : this.getPredictedCategories(target))
        {
            result += this.getCount(target, estimate);
        }
        return result;
    }

    @Override
    public double getPredictedCount(
        final CategoryType predicted)
    {
        double result = 0.0;
        for (CategoryType actual : this.getActualCategories())
        {
            result += this.getCount(actual, predicted);
        }
        return result;
    }
    
    @Override
    public double getAccuracy()
    {
        return this.getTotalCorrectCount() / this.getTotalCount();
    }

    @Override
    public double getCategoryAccuracy(
        final CategoryType category)
    {
        return this.getCount(category, category) / this.getActualCount(category);
    }

    @Override
    public double getAverageCategoryAccuracy()
    {
        double sum = 0.0;
        int categoryCount = 0;
        for (CategoryType actual : this.getActualCategories())
        {
            if (this.getActualCount(actual) > 0)
            {
                sum += this.getCategoryAccuracy(actual);
                categoryCount++;
            }
        }
        return sum / categoryCount;
    }

    @Override
    public double getErrorRate()
    {
        return 1.0 - this.getAccuracy();
    }

    @Override
    public double getCategoryErrorRate(
        final CategoryType category)
    {
        return 1.0 - this.getCategoryAccuracy(category);
    }

    @Override
    public double getAverageCategoryErrorRate()
    {
        double sum = 0.0;
        int categoryCount = 0;
        for (CategoryType actual : this.getActualCategories())
        {
            if (this.getActualCount(actual) > 0)
            {
                sum += this.getCategoryErrorRate(actual);
                categoryCount++;
            }
        }
        return sum / categoryCount;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy