All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.function.categorization.WinnerTakeAllCategorizer Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                WinnerTakeAllCategorizer.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright April 07, 2009, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 * See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.learning.function.categorization;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.AbstractBatchLearnerContainer;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Adapts an evaluator that outputs a vector to be used as a categorizer.
 * @param   The type of the input the categorizer can use.
 * @param   The type of category output by the categorizer.
 * @author  Justin Basilico
 * @since   3.0
 */
public class WinnerTakeAllCategorizer
    extends AbstractDiscriminantCategorizer
{
    
    /** The evaluator that outputs a vector to return. */
    protected Evaluator evaluator;

    /**
     * Creates a new {@code WinnerTakesAllCategorizer}.
     */
    public WinnerTakeAllCategorizer()
    {
        this(null, new LinkedHashSet());
    }

    /**
     * Creates a new {@code WinnerTakesAllCategorizer}.
     *
     * @param   evaluator
     *      The evaluator that this class adapts.
     * @param   categories
     *      The set of categories.
     */
    public WinnerTakeAllCategorizer(
        final Evaluator evaluator,
        final Set categories)
    {
        super(categories);

        this.setEvaluator(evaluator);
    }

    /**
     * Evaluates the categorizer and returns the category along with a weight.
     *
     * @param   input
     *      The input to evaluate.
     * @return
     *      The output plus its weight.
     */
    @Override
    public DefaultWeightedValueDiscriminant evaluateWithDiscriminant(
        final InputType input)
    {
        final Vector output = this.evaluator.evaluate(input).convertToVector();
        return this.findBestCategory(output);
    }
    
    /**
     * Finds the best category (and its output value) from the given vector
     * of outputs from a vector evaluator. Ties are broken by taking the first
     * category with the highest value.
     *
     * @param   output
     *      The vector of outputs, one corresponding to each category.
     * @return
     *      The category with the highest output value. Ties are broken by
     *      taking the first category with the highest value.
     */
    public DefaultWeightedValueDiscriminant findBestCategory(
        final Vector output)
    {
        output.assertDimensionalityEquals(this.categories.size());

        // Loop througn and find the best category.
        CategoryType best = null;
        double bestValue = Double.NEGATIVE_INFINITY;
        int index = 0;
        for (CategoryType category : this.categories)
        {
            final double value = output.getElement(index);
            
            if (best == null || value > bestValue)
            {
                best = category;
                bestValue = value;
            }
            index++;
        }

        // Return the best category.
        return new DefaultWeightedValueDiscriminant(best, bestValue);
    }

    /**
     * Gets the wrapped evaluator.
     *
     * @return
     *      The wrapped evaluator.
     */
    public Evaluator getEvaluator()
    {
        return this.evaluator;
    }

    /**
     * Sets the wrapped evaluator.
     *
     * @param   evaluator
     *      The wrapped evaluator.
     */
    public void setEvaluator(
        final Evaluator evaluator)
    {
        this.evaluator = evaluator;
    }

    @Override
    public void setCategories(
        final Set categories)
    {
        // Just making this publicly visiable.
        super.setCategories(categories);
    }

    /**
     * A learner for the adapter. Makes it so that learning algorithms that
     * produce evaluators whose outputs are vectors to be used for
     * categorization.
     *
     * @param   
     *      The type of the input data for the learner.
     * @param   
     *      The type of the output categories.
     */
    public static class Learner
        extends AbstractBatchLearnerContainer>, ? extends Evaluator>>
        implements SupervisedBatchLearner>,
            VectorFactoryContainer
    {

        /** The vector factory used. */
        protected VectorFactory vectorFactory;

        /**
         * Creates a new learner adapter.
         */
        public Learner()
        {
            this(null);
        }

        /**
         * Creates a new learner adapter with the given internal learner.
         *
         * @param   learner
         *      The learner to adapt.
         */
        public Learner(
            final BatchLearner>, Evaluator> learner)
        {
            super(learner);

            this.setVectorFactory(VectorFactory.getDefault());
        }

        public WinnerTakeAllCategorizer learn(
            final Collection> data)
        {
            // First figure out all of the categories.
            final LinkedHashMap categoryIndices =
                new LinkedHashMap();
            for (InputOutputPair example : data)
            {
                final CategoryType category = example.getOutput();

                if (!categoryIndices.containsKey(category))
                {
                    final int index = categoryIndices.size();
                    categoryIndices.put(category, index);
                }
            }

            // Now convert the dataset to have the output be vectors.
            final int categoryCount = categoryIndices.size();
            final ArrayList> vectorData =
                new ArrayList>(data.size());
            for (InputOutputPair example
                : data)
            {
                final CategoryType category = example.getOutput();
                final int index = categoryIndices.get(category);


// TODO: Make use of the boolean encoding code from the data package at some point.
                final Vector output = this.getVectorFactory().createVector(
                    categoryCount, -1.0);
                output.setElement(index, +1.0);
                vectorData.add(new DefaultInputOutputPair(
                    example.getInput(), output));
            }

            // Do the learning.
            final Evaluator learned =
                this.getLearner().learn(vectorData);

            final LinkedHashSet categories =
                new LinkedHashSet(categoryIndices.keySet());
            return new WinnerTakeAllCategorizer(
                learned, categories);
        }

        /**
         * Gets the vector factory.
         *
         * @return
         *      The vector factory.
         */
        public VectorFactory getVectorFactory()
        {
            return this.vectorFactory;
        }

        /**
         * Sets the vector factory.
         *
         * @param   vectorFactory
         *      The vector factory.
         */
        public void setVectorFactory(
            final VectorFactory vectorFactory)
        {
            this.vectorFactory = vectorFactory;
        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy