All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.performance.categorization.DefaultConfusionMatrix Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                DefaultConfusionMatrix.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry Learning Core
 * 
 * Copyright January 11, 2011, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 */

package gov.sandia.cognition.learning.performance.categorization;

import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.math.MutableDouble;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.Summarizer;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/**
 * A default implementation of the {@code ConfusionMatrix} interface. It is
 * backed by a two-level map storing the category object counts, making a
 * sparse representation.
 *
 * @param 
 *      The type of the category object over the confusion matrix.
 * @author  Justin Basilico
 * @since   3.1
 */
public class DefaultConfusionMatrix
    extends AbstractConfusionMatrix
{

    /** The backing map of confusion matrix entries. The first key is the
     *  actual category and the second is the predicted category. */
    protected Map> confusions;

    /**
     * Creates a new, empty {@code DefaultConfusionMatrix}.
     */
    public DefaultConfusionMatrix()
    {
        super();

        this.confusions = new LinkedHashMap>();
    }

    /**
     * Creates a copy of a given confusion matrix.
     *
     * @param   other
     *      The other confusion matrix to copy.
     */
    public DefaultConfusionMatrix(
        final ConfusionMatrix other)
    {
        this();

        this.addAll(other);
    }

    @Override
    public DefaultConfusionMatrix clone()
    {
        @SuppressWarnings("unchecked")
        final DefaultConfusionMatrix clone =
            (DefaultConfusionMatrix) super.clone();

        if (this.confusions != null)
        {
            clone.confusions = new LinkedHashMap>(
                this.confusions.size());
            for (Map.Entry> outerEntry
                : this.confusions.entrySet())
            {
                final LinkedHashMap categoryMap =
                    new LinkedHashMap(
                        outerEntry.getValue().size());
                clone.confusions.put(outerEntry.getKey(), categoryMap);
                for (Map.Entry innerEntry
                    : outerEntry.getValue().entrySet())
                {
                    categoryMap.put(innerEntry.getKey(),
                        innerEntry.getValue().clone());
                }
            }
        }

        return clone;
    }

    @Override
    public void add(
        final CategoryType target,
        final CategoryType estimate,
        final double value)
    {
        Map subMap = confusions.get(target);

        if (subMap == null)
        {
            subMap = new HashMap();
            this.confusions.put(target, subMap);
        }

        MutableDouble entry = subMap.get(estimate);
        if (entry == null)
        {
            entry = new MutableDouble(value);
            subMap.put(estimate, entry);
        }
        else
        {
            entry.value += value;
        }
    }

    @Override
    public double getCount(
        final CategoryType target,
        final CategoryType estimate)
    {
        Map subMap = confusions.get(target);

        if (subMap == null)
        {
            return 0.0;
        }
        else
        {
            MutableDouble entry = subMap.get(estimate);

            if (entry == null)
            {
                return 0.0;
            }
            else
            {
                return entry.getValue();
            }
        }

    }

    @Override
    public double getActualCount(
        final CategoryType target)
    {
        Map subMap = confusions.get(target);

        if (subMap == null)
        {
            return 0.0;
        }

        double result = 0.0;
        for (MutableDouble value : subMap.values())
        {
            result += value.getValue();
        }

        return result;
    }

    @Override
    public void clear()
    {
        this.confusions.clear();
    }

    @Override
    public Set getCategories()
    {
        final LinkedHashSet result =
            new LinkedHashSet();
        result.addAll(this.getActualCategories());
        result.addAll(this.getPredictedCategories());
        return result;
    }

    @Override
    public Set getActualCategories()
    {
        return this.confusions.keySet();
    }

    @Override
    public Set getPredictedCategories()
    {
        final LinkedHashSet estimates = new LinkedHashSet(
            this.confusions.size());

        for (Map estimateCounts
            : this.confusions.values())
        {
            estimates.addAll(estimateCounts.keySet());
        }

        return estimates;
    }

    @Override
    public Set getPredictedCategories(
        final CategoryType target)
    {
        Map subMap = confusions.get(target);

        if (subMap == null)
        {
            return Collections.emptySet();
        }
        else
        {
            return subMap.keySet();
        }
    }

    @Override
    public String toString()
    {
        return this.confusions.toString();
    }

    /**
     * Creates a new {@code DefaultConfusionMatrix} from the given
     * actual-predicted pairs.
     *
     * @param   
     *      The category type.
     * @param   pairs
     *      The actual-category pairs.
     * @return
     *      A new confusion matrix populated from the given actual-category
     *      pairs.
     */
    public static  DefaultConfusionMatrix createUnweighted(
        final Collection> pairs)
    {
        final DefaultConfusionMatrix result =
            new DefaultConfusionMatrix();
        for (TargetEstimatePair item
            : pairs)
        {
            result.add(item.getTarget(), item.getEstimate());
        }
        return result;
    }

    /**
     * Creates a new {@code DefaultConfusionMatrix} from the given
     * actual-predicted pairs.
     *
     * @param   
     *      The category type.
     * @param   pairs
     *      The actual-category pairs.
     * @return
     *      A new confusion matrix populated from the given actual-category
     *      pairs.
     */
    public static  DefaultConfusionMatrix createFromActualPredictedPairs(
        final Collection> pairs)
    {
        final DefaultConfusionMatrix result =
            new DefaultConfusionMatrix();
        for (Pair pair
            : pairs)
        {
            result.add(pair.getFirst(), pair.getSecond());
        }
        return result;
    }

    /**
     * A confusion matrix summarizer that summarizes actual-predicted pairs.
     *
     * @param   
     *      The type of category of the summarizer.
     */
    public static class ActualPredictedPairSummarizer
        extends AbstractCloneableSerializable
        implements Summarizer, DefaultConfusionMatrix>
    {

        /**
         * Creates a new {@code CombineSummarizer}.
         */
        public ActualPredictedPairSummarizer()
        {
            super();
        }

        @Override
        public DefaultConfusionMatrix summarize(
            final Collection> data)
        {
            return createFromActualPredictedPairs(data);
        }

    }

    /**
     * A confusion matrix summarizer that adds together confusion matrices.
     *
     * @param   
     *      The type of category of the summarizer.
     */
    public static class CombineSummarizer
        extends AbstractCloneableSerializable
        implements Summarizer, DefaultConfusionMatrix>
    {

        /**
         * Creates a new {@code CombineSummarizer}.
         */
        public CombineSummarizer()
        {
            super();
        }

        @Override
        public DefaultConfusionMatrix summarize(
            final Collection> data)
        {
            final DefaultConfusionMatrix result =
                new DefaultConfusionMatrix();

            for (ConfusionMatrix item : data)
            {
                result.addAll(item);
            }

            return result;
        }

    }
    
    /**
     * A factory for default confusion matrices.
     * 
     * @param    
     *      The type of category that the confusion is computed over.
     */
    public static class Factory
        extends AbstractCloneableSerializable
        implements gov.sandia.cognition.factory.Factory>
    {

        /**
         * Creates a new {@code Factory}.
         */
        public Factory()
        {
            super();
        }

        @Override
        public DefaultConfusionMatrix create()
        {
            return new DefaultConfusionMatrix();
        }
        
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy