All downloads are free. The search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.function.categorization.BinaryVersusCategorizer Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                BinaryVersusCategorizer.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright April 08, 2009, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 * See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.learning.function.categorization;

import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.AbstractBatchLearnerContainer;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/**
 * An adapter that allows binary categorizers to be adapted for multi-category
 * problems by applying a binary categorizer to each pair of categories.
 *
 * @param   
 *      The type of the input to categorize.
 * @param   
 *      The type of the output categories.
 * @author  Justin Basilico
 * @since   3.0
 */
public class BinaryVersusCategorizer
    extends AbstractDiscriminantCategorizer
{

    /** Maps false-true category pairs . */
    protected Map, Evaluator>
        categoryPairsToEvaluatorMap;

    /**
     * Creates a new {@code BinaryVersusCategorizer}.
     */
    public BinaryVersusCategorizer()
    {
        this(new LinkedHashSet(),
            new LinkedHashMap, Evaluator>());
    }

    /**
     * Creates a new {@code BinaryVersusCategorizer} with the given
     * categories and an empty set of evaluators.
     *
     * @param   categories
     *      The possible output categories.
     */
    public BinaryVersusCategorizer(
        final Set categories)
    {
        this(categories,
            new LinkedHashMap, Evaluator>(
            (categories.size() * categories.size() / 2)));
    }

    /**
     * Creates a new {@code BinaryVersusCategorizer}.
     *
     * @param   categories
     *      The possible output categories.
     * @param   categoryPairsToEvaluatorMap
     *      The mapping of category pairs to their binary categorizer.
     */
    public BinaryVersusCategorizer(
        final Set categories,
        final Map, Evaluator> categoryPairsToEvaluatorMap)
    {
        super(categories);

        this.setCategoryPairsToEvaluatorMap(categoryPairsToEvaluatorMap);
    }

    @Override
    public BinaryVersusCategorizer clone()
    {
        BinaryVersusCategorizer result = (BinaryVersusCategorizer)
            super.clone();

        result.categoryPairsToEvaluatorMap =
            new LinkedHashMap, Evaluator>(
                this.categoryPairsToEvaluatorMap.size());
        for (Map.Entry, Evaluator> entry
            : this.categoryPairsToEvaluatorMap.entrySet())
        {
            result.categoryPairsToEvaluatorMap.put(
                ObjectUtil.cloneSmart(entry.getKey()),
                ObjectUtil.cloneSmart(entry.getValue()));
        }

        return result;
    }

    @Override
    public DefaultWeightedValueDiscriminant evaluateWithDiscriminant(
        final InputType input)
    {
        final int categoryCount = this.categories.size();

        if (categoryCount <= 0)
        {
            // No categories.
            return null;
        }
        else if (categoryCount == 1)
        {
            // There is only one category.
            return DefaultWeightedValueDiscriminant.create(
                CollectionUtil.getFirst(this.categories), 1.0);
        }

        // We are going to count the number of votes for each category.
        final DefaultDataDistribution results =
            new DefaultDataDistribution(categoryCount);

        // Go through all the pairs of evaluators.
        for (Map.Entry, Evaluator> entry
            : this.categoryPairsToEvaluatorMap.entrySet())
        {
            // Evaluate the binary categorizer for the two classes on the given
            // input.
            final Boolean category = entry.getValue().evaluate(input);

            if (category == null)
            {
                // Null values do not vote.
            }
            else if (!category)
            {
                // This belongs to the false (first) category.
                results.increment(entry.getKey().getFirst());
            }
            else
            {
                // This belongs to the true (second) category.
                results.increment(entry.getKey().getSecond());
            }
        }

        // The one with the most votes is the category we use.
        final CategoryType bestCategory = results.getMaxValueKey();
        final double bestFraction = results.getFraction(bestCategory);
        return DefaultWeightedValueDiscriminant.create(
            bestCategory, bestFraction);
    }

    /**
     * Gets the mapping of false-true category pairs to the binary categorizer
     * that distinguishes them.
     *
     * @return
     *      The mapping of category pairs to their binary categorizer.
     */
    public Map, Evaluator> getCategoryPairsToEvaluatorMap()
    {
        return this.categoryPairsToEvaluatorMap;
    }

    /**
     * Sets the mapping of false-true category pairs to the binary categorizer
     * that distinguishes them.
     *
     * @param   categoryPairsToEvaluatorMap
     *      The mapping of category pairs to their binary categorizer.
     */
    public void setCategoryPairsToEvaluatorMap(
        final Map, Evaluator> categoryPairsToEvaluatorMap)
    {
        this.categoryPairsToEvaluatorMap = categoryPairsToEvaluatorMap;
    }

    /**
     * A learner for the {@code BinaryVersusCategorizer}. It learns a
     * binary categorizer for each pair of categories.
     *
     * @param   
     *      The input to learn from and the input to the learned categorizer.
     * @param   
     *      The type of categories to learn from.
     */
    public static class Learner
        extends AbstractBatchLearnerContainer>, ? extends Evaluator>>
        implements SupervisedBatchLearner>
    {

        /**
         * Creates a new  {@code BinaryVersusCategorizer.Learner} with no
         * initial binary categorizer learner.
         */
        public Learner()
        {
            this(null);
        }

        /**
         * Creates a new {@code BinaryVersusCategorizer.Learner} with an
         * binary categorizer learner to learn category versus category.
         *
         * @param   learner
         *      The binary categorizer learner to use to learn decision
         *      functions between categories.
         */
        public Learner(
            BatchLearner>, ? extends Evaluator> learner)
        {
            super(learner);
        }
                
        @Override
        public BinaryVersusCategorizer learn(
            final Collection> data)
        {
            // Find the categories. We're going to look at pairs of categories
            // so we also make a list version of the set.
            final Set categories =
                DatasetUtil.findUniqueOutputs(data);
            final int categoryCount = categories.size();
            final ArrayList categoriesList =
                new ArrayList(categories);

            // Create the object to hold the result.
            final BinaryVersusCategorizer result =
                new BinaryVersusCategorizer(categories);
            for (int i = 0; i < categoryCount; i++)
            {
                // This is the false category.
                final CategoryType falseCategory = categoriesList.get(i);
                for (int j = i + 1; j < categoryCount; j++)
                {
                    // This is the true category.
                    final CategoryType trueCategory = categoriesList.get(j);

                    final ArrayList> versusData =
                        new ArrayList>();

                    for (InputOutputPair example
                        : data)
                    {
                        // The category of the label.
                        final CategoryType category = example.getOutput();

                        if (falseCategory.equals(category))
                        {
                            // This is an example belonging to the false
                            // category.
                            versusData.add(new DefaultInputOutputPair(
                                example.getInput(), false));
                        }
                        else if (trueCategory.equals(category))
                        {
                            // This is an example belonging to the true
                            // category.
                            versusData.add(new DefaultInputOutputPair(
                                example.getInput(), true));
                        }
                        // else - The example did not belong to either category.
                    }

                    // Learn on the binary data.
                    final Evaluator learned =
                        this.getLearner().learn(versusData);

                    // Add the learned categorizer.
                    result.categoryPairsToEvaluatorMap.put( DefaultPair.create(
                        falseCategory, trueCategory ), learned );
                }
            }

            // Returns the created adapter.
            return result;
        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy