All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.perceptron.OnlineVotedPerceptron Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                OnlineVotedPerceptron.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry Learning Core
 * 
 * Copyright October 20, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government.
 *
 */

package gov.sandia.cognition.learning.algorithm.perceptron;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractSupervisedBatchAndIncrementalLearner;
import gov.sandia.cognition.learning.algorithm.ensemble.WeightedBinaryEnsemble;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.DefaultWeightedValue;

/**
 * An online version of the Voted-Perceptron algorithm. It is similar to the
 * typical Perceptron algorithm except that it creates multiple Perceptrons,
 * and combines them together in a weighted vote. Whenever a mistake is made,
 * a new Perceptron is created by modifying the previous one and given a weight
 * of 1. When it gets an example correct, it simply increments the weight on
 * the most recent one.
 * 
 * @author  Justin Basilico
 * @since   3.1
 */
@PublicationReference(
    title="Large Margin Classification Using the Perceptron Algorithm",
    author={"Yoav Freund", "Robert E. Schapire" },
    year=1999,
    type=PublicationType.Journal,
    publication="Machine Learning",
    pages={277, 296},
    url="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.48.8200")
public class OnlineVotedPerceptron
    extends AbstractSupervisedBatchAndIncrementalLearner>
    implements VectorFactoryContainer
{

    /** The factory to create weight vectors. */
    protected VectorFactory vectorFactory;

    /**
     * Creates a new {@code OnlinePerceptron}.
     */
    public OnlineVotedPerceptron()
    {
        this(VectorFactory.getDenseDefault());
    }

    /**
     * Creates a new {@code OnlinePerceptron} with the given vector factory.
     *
     * @param   vectorFactory
     *      The vector factory to use to create the weight vectors.
     */
    public OnlineVotedPerceptron(
        final VectorFactory vectorFactory)
    {
        super();

        this.setVectorFactory(vectorFactory);
    }

    @Override
    public WeightedBinaryEnsemble createInitialLearnedObject()
    {
        return new WeightedBinaryEnsemble();
    }

    @Override
    public void update(
        final WeightedBinaryEnsemble target,
        final Vectorizable input,
        final Boolean output)
    {
        if (input != null && output != null)
        {
            this.update(target, input.convertToVector(), (boolean) output);
        }
    }
    
    /**
     * The {@code update} method updates an object of {@code ResultType} using
     * the given a new supervised input-output pair, using some form of
     * "learning" algorithm.
     *
     * @param   target
     *      The object to update.
     * @param   input
     *      The supervised input vector to learn from.
     * @param   actual
     *      The supervised output label to learn from.
     */
    public void update(
        final WeightedBinaryEnsemble target,
        final Vector input,
        final boolean actual)
    {
        // Predict the output as a double (negative values are false, positive
        // are true).
        final double prediction = target.evaluateAsDouble(input);

        // The computation that we do is based on using the last member in
        // the ensemble.
        final DefaultWeightedValue lastMember =
            getLastMember(target);

        // Make an update if there was an error.
        final boolean correct =
               (actual && prediction > 0.0)
            || (!actual && prediction < 0.0);
        if (correct)
        {
            // There was no error made, so increase the weight on the latest
            // member of the ensemble.
            // Note: It should never reach here when lastMember is null because
            // then the prediction has to be zero.
            lastMember.setWeight(lastMember.getWeight() + 1.0);
        }
        else
        {
            final LinearBinaryCategorizer next;
            if (lastMember == null)
            {
                // This is the very first data point we've seen, so we need
                // to create an initial categorizer.
                next = new LinearBinaryCategorizer(
                    this.getVectorFactory().createVector(
                        input.getDimensionality()), 0.0);
            }
            else
            {
                // Clone the previous member.
                next = lastMember.getValue().clone();
            }

            if (actual)
            {
                // An error with the true (positive) category.
                next.getWeights().plusEquals(input);
                next.setBias(next.getBias() + 1.0);
            }
            else
            {
                // An error with the false (negative) category.
                next.getWeights().minusEquals(input);
                next.setBias(next.getBias() - 1.0);
            }

            // Add the new member to the ensemble.
            target.add(next, 1.0);
        }
    }

    /**
     * Gets the last member in the ensemble. This is the one used by the
     * algorithm.
     *
     * @param   ensemble
     *      The ensemble to get the last member from.
     * @return
     *      The last member in the ensemble, or null if it is empty.
     */
    public static DefaultWeightedValue getLastMember(
        final WeightedBinaryEnsemble ensemble)
{
        final int ensembleSize = ensemble.getMembers().size();
        if (ensembleSize <= 0)
        {
            return null;
        }
        else
        {
            return (DefaultWeightedValue)
                ensemble.getMembers().get(ensembleSize - 1);}
    }

    /**
     * Gets the VectorFactory used to create the weight vector.
     *
     * @return The VectorFactory used to create the weight vector.
     */
    public VectorFactory getVectorFactory()
    {
        return this.vectorFactory;
    }

    /**
     * Sets the VectorFactory used to create the weight vector.
     *
     * @param  vectorFactory The VectorFactory used to create the weight vector.
     */
    public void setVectorFactory(
        final VectorFactory vectorFactory)
    {
        this.vectorFactory = vectorFactory;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy