All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.perceptron.BatchMultiPerceptron Maven / Gradle / Ivy

/*
 * File:                BatchMultiPerceptron.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry Learning Core
 * 
 * Copyright April 21, 2011, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 */

package gov.sandia.cognition.learning.algorithm.perceptron;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.learning.function.categorization.LinearMultiCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Set;

/**
 * Implements a multi-class version of the standard batch Perceptron learning
 * algorithm. It learns over labeled examples that have a vector input and an
 * arbitrary output label. This version keeps a separate Perceptron as the
 * prototype for each category. When it makes an error, it subtracts the input
 * from the incorrect prototype and adds it to the correct prototype.
 * This implementation also allows a margin to be enforced, which
 * in the multi-class case means that the output value for the actual class
 * must be at least the given margin from the next highest class.
 *
 * @param 
 *      The type of output categories. Can be any type that has a valid
 *      equals and hashCode method.
 * @author  Justin Basilico
 * @since   3.2.0
 * @see     Perceptron
 * @see     OnlinePerceptron
 */
@PublicationReference(
    title="Ultraconservative Online Algorithms for Multiclass Problems",
    author={"Koby Crammer", "Yoram Singer"},
    year=2003,
    type=PublicationType.Journal,
    publication="Journal of Machine Learning Research",
    pages={951, 991},
    url="http://portal.acm.org/citation.cfm?id=944936")
public class BatchMultiPerceptron
    extends AbstractAnytimeSupervisedBatchLearner>
    implements MeasurablePerformanceAlgorithm, VectorFactoryContainer
{

    /** The default maximum number of iterations, {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** The default minimum margin is {@value}. */
    public static final double DEFAULT_MIN_MARGIN = 0.0;

    /** The minimum margin to enforce. Must be non-negative. */
    protected double minMargin;

    /** The factory to create weight vectors. */
    protected VectorFactory vectorFactory;

    /** The linear categorizer created by the algorithm. */
    protected transient LinearMultiCategorizer result;

    /** The number of errors on the most recent iteration. */
    protected transient int errorCount;

    /**
     * Creates a new {@code BatchMultiPerceptron} with default parameters.
     */
    public BatchMultiPerceptron()
    {
        this(DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a new {@code BatchMultiPerceptron} with the given maximum number
     * of iterations and a default margin of 0.0.
     *
     * @param   maxIterations
     *      The maximum number of iterations. Must be positive.
     */
    public BatchMultiPerceptron(
        final int maxIterations)
    {
        this(maxIterations, DEFAULT_MIN_MARGIN);
    }

    /**
     * Creates a new {@code BatchMultiPerceptron} with the given maximum
     * number of iterations and margin to enforce.
     *
     * @param   maxIterations
     *      The maximum number of iterations. Must be positive.
     * @param   minMargin
     *      The minimum margin to enforce. Must be non-negative.
     */
    public BatchMultiPerceptron(
        final int maxIterations,
        final double minMargin)
    {
        this(maxIterations, minMargin, VectorFactory.getDefault());
    }

    /**
     * Creates a new {@code BatchMultiPerceptron} with the given parameters.
     *
     * @param   maxIterations
     *      The maximum number of iterations. Must be positive.
     * @param   minMargin
     *      The minimum margin to enforce. Must be non-negative.
     * @param   vectorFactory
     *      The factory to use to create weight vectors.
     */
    public BatchMultiPerceptron(
        final int maxIterations,
        final double minMargin,
        final VectorFactory vectorFactory)
    {
        super(maxIterations);

        this.setMinMargin(minMargin);
        this.setVectorFactory(vectorFactory);
    }

    @Override
    protected boolean initializeAlgorithm()
    {
        if (CollectionUtil.isEmpty(this.getData()))
        {
            // No data to learn from.
            return false;
        }

        // Get the dimensionality of the data.
        final int dimensionality = DatasetUtil.getInputDimensionality(
            this.getData());

        // Create the categorizer we will learn and create the prototypes for
        // each category.
        this.result = new LinearMultiCategorizer();
        final Set categories = DatasetUtil.findUniqueOutputs(
            this.getData());
        for (CategoryType category : categories)
        {
            final LinearBinaryCategorizer prototype = 
                new LinearBinaryCategorizer(
                this.getVectorFactory().createVector(dimensionality), 0.0);
            this.result.getPrototypes().put(category, prototype);
        }

        // The algorithm is now initialized.
        return true;
    }

    @Override
    protected boolean step()
    {
        // Reset the number of errors for the new iteration.
        this.setErrorCount(0);

        // Loop over all the training instances.
        for (InputOutputPair example
            : this.getData())
        {
            if (example == null)
            {
                // Ignore null examples.
                continue;
            }

            // Get the input as a Vector and the actual category.
            final Vector input = example.getInput().convertToVector();
            final CategoryType actual = example.getOutput();

            // See what the predicted category is.
            CategoryType predicted = null;
            double predictedScore = Double.NEGATIVE_INFINITY;
            for (CategoryType category : this.result.getCategories())
            {
                double score = this.result.evaluateAsDouble(input, category);
                if (this.minMargin != 0.0 && actual.equals(category))
                {
                    // Enforce a margin on the correct category.
                    score -= this.minMargin;
                }

                if (score > predictedScore)
                {
                    // This is the predicted category.
                    predicted = category;
                    predictedScore = score;
                }
            }

            // See if the algorithm was correct or not.
            final boolean correct = ObjectUtil.equalsSafe(actual, predicted);
            if (!correct)
            {
                // The classification was incorrect so we need to update.
                this.setErrorCount(this.getErrorCount() + 1);

                // Increment the prototype for the actual category.
                final LinearBinaryCategorizer actualPrototype =
                    this.result.getPrototypes().get(actual);
                actualPrototype.getWeights().plusEquals(input);
                actualPrototype.setBias(actualPrototype.getBias() + 1.0);

                // Decrement the prototype for the predicted category.
                final LinearBinaryCategorizer predictedPrototype =
                    this.result.getPrototypes().get(predicted);
                predictedPrototype.getWeights().minusEquals(input);
                predictedPrototype.setBias(predictedPrototype.getBias() - 1.0);
            }
            // else - Not an error, no need to update.
        }

        // Keep going while the error count is positive.
        return this.getErrorCount() > 0;
    }

    @Override
    protected void cleanupAlgorithm()
    {
        // Nothing to clean up.
    }

    @Override
    public LinearMultiCategorizer getResult()
    {
        return this.result;
    }

    /**
     * Sets the result of the algorithm.
     *
     * @param   result
     *      The result of the algorithm.
     */
    protected void setResult(
        final LinearMultiCategorizer result)
    {
        this.result = result;
    }

    /**
     * Gets the minimum margin to enforce. Any value less than or equal to
     * this is considered an error for the algorithm.
     *
     * @return
     *      The minimum margin. Cannot be negative.
     */
    public double getMinMargin()
    {
        return this.minMargin;
    }

    /**
     * Gets the minimum margin to enforce. Any value less than or equal to
     * this is considered an error for the algorithm.
     *
     * @param   minMargin
     *      The minimum margin. Cannot be negative.
     */
    public void setMinMargin(
        final double minMargin)
    {
        ArgumentChecker.assertIsNonNegative("minMargin", minMargin);
        this.minMargin = minMargin;
    }

    /**
     * Gets the VectorFactory used to create the weight vector.
     *
     * @return The VectorFactory used to create the weight vector.
     */
    @Override
    public VectorFactory getVectorFactory()
    {
        return this.vectorFactory;
    }

    /**
     * Sets the VectorFactory used to create the weight vector.
     *
     * @param  vectorFactory The VectorFactory used to create the weight vector.
     */
    public void setVectorFactory(
        final VectorFactory vectorFactory)
    {
        this.vectorFactory = vectorFactory;
    }


    /**
     * Gets the error count of the most recent iteration.
     *
     * @return The current error count.
     */
    public int getErrorCount()
    {
        return this.errorCount;
    }

    /**
     * Sets the error count of the most recent iteration.
     *
     * @param  errorCount The current error count.
     */
    protected void setErrorCount(
        final int errorCount)
    {
        this.errorCount = errorCount;
    }

    /**
     * Gets the performance, which is the error count on the last iteration.
     *
     * @return The performance of the algorithm.
     */
    @Override
    public NamedValue getPerformance()
    {
        return new DefaultNamedValue("error count",
            this.getErrorCount());
    }
    
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy