All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.perceptron.kernel.KernelPerceptron Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                KernelPerceptron.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright July 18, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.perceptron.kernel;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.algorithm.perceptron.Perceptron;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.DefaultKernelBinaryCategorizer;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.NamedValue;
import java.util.ArrayList;
import java.util.LinkedHashMap;

/**
 * The KernelPerceptron class implements the kernel version of 
 * the Perceptron algorithm. That is, it replaces the inner-product used in the 
 * standard Perceptron algorithm with a kernel method. This allows the 
 * algorithm to be used with data and a kernel that would map it into a 
 * high-dimensional space but does not need to since the kernel can compute the 
 * inner-product in the high-dimensional space without actually creating the 
 * vectors for it.
 *
 * @param    Input class of the {@code InputOutputPairs}
 * @author  Justin Basilico
 * @since   2.0
 * @see     Perceptron
 */
@CodeReview(
    reviewer="Kevin R. Dixon",
    date="2008-07-23",
    changesNeeded=false,
    comments={
        "Added PublicationReference to the original article.",
        "Minor changes to javadoc.",
        "Looks fine."
    }
)
@PublicationReference(
    author={
        "Yoav Freund",
        "Robert E. Schapire"
    },
    title="Large margin classification using the perceptron algorithm",
    publication="Machine Learning",
    type=PublicationType.Journal,
    year=1999,
    notes="Volume 37, Number 3",
    pages={277,296},
    url="http://www.cs.ucsd.edu/~yfreund/papers/LargeMarginsUsingPerceptron.pdf"
)
public class KernelPerceptron
    extends AbstractAnytimeSupervisedBatchLearner>
    implements MeasurablePerformanceAlgorithm
{

    /** The default maximum number of iterations, {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS =
        Perceptron.DEFAULT_MAX_ITERATIONS;

    /** The default positive margin, {@value}. */
    public static final double DEFAULT_MARGIN_POSITIVE =
        Perceptron.DEFAULT_MARGIN_POSITIVE;

    /** The default negative margin, {@value}. */
    public static final double DEFAULT_MARGIN_NEGATIVE =
        Perceptron.DEFAULT_MARGIN_NEGATIVE;

    /** The kernel to use. */
    private Kernel kernel;

    /** The positive margin to enforce. */
    private double marginPositive;

    /** The negative margin to enforce. */
    private double marginNegative;

    /** The result categorizer. */
    private DefaultKernelBinaryCategorizer result;

    /** The number of errors on the most recent iteration. */
    private int errorCount;

    /** The mapping of weight objects to non-zero weighted examples 
     *  (support vectors). */
    private LinkedHashMap, DefaultWeightedValue> supportsMap;

    /**
     * Creates a new instance of KernelPerceptron.
     */
    public KernelPerceptron()
    {
        this(null);
    }

    /**
     * Creates a new KernelPerceptron with the given kernel.
     *
     * @param  kernel The kernel to use.
     */
    public KernelPerceptron(
        final Kernel kernel)
    {
        this(kernel, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a new KernelPerceptron with the given kernel and maximum number
     * of iterations.
     *
     * @param  kernel The kernel to use.
     * @param  maxIterations The maximum number of iterations.
     */
    public KernelPerceptron(
        final Kernel kernel,
        final int maxIterations)
    {
        this(kernel, maxIterations,
            DEFAULT_MARGIN_POSITIVE, DEFAULT_MARGIN_NEGATIVE);
    }

    /**
     * Creates a new KernelPerceptron with the given parameters.
     *
     * @param  kernel The kernel to use.
     * @param  maxIterations The maximum number of iterations.
     * @param  marginPositive The positive margin to enforce.
     * @param  marginNegative The negative margin to enforce.
     */
    public KernelPerceptron(
        final Kernel kernel,
        final int maxIterations,
        final double marginPositive,
        final double marginNegative)
    {
        super(maxIterations);

        this.setKernel(kernel);
        this.setMarginPositive(marginPositive);
        this.setMarginNegative(marginNegative);

        this.setResult(null);
        this.setErrorCount(0);
        this.setSupportsMap(null);
    }

    @Override
    protected boolean initializeAlgorithm()
    {
        if (this.getData() == null)
        {
            // Error: No data to learn on.
            return false;
        }

        // Count the number of valid examples.
        int validCount = 0;
        for (InputOutputPair example : this.getData())
        {
            if (example != null)
            {
                validCount++;
            }
        }

        if (validCount <= 0)
        {
            // Nothing to perform learning on.
            return false;
        }

        // Set up the learning variables.
        this.setErrorCount(validCount);
        this.setSupportsMap(new LinkedHashMap, DefaultWeightedValue>());
        this.setResult(new DefaultKernelBinaryCategorizer(
            this.getKernel(), this.getSupportsMap().values(), 0.0));

        return true;
    }

    @Override
    protected boolean step()
    {
        // Reset the number of errors for the new iteration.
        this.setErrorCount(0);

        // Loop over all the training instances.
        for (InputOutputPair example : this.getData())
        {
            if (example == null)
            {
                continue;
            }

            // Compute the predicted classification and get the actual
            // classification.
            final InputType input = example.getInput();
            final boolean actual = example.getOutput();
            final double prediction = this.result.evaluateAsDouble(input);

            if ((actual && prediction <= +this.marginPositive) || (!actual && prediction >= -this.marginNegative))
            {
                // The classification was incorrect so we need to update
                // the perceptron.
                this.setErrorCount(this.getErrorCount() + 1);

                // We are going to update the weight for this example and the
                // global bias.
                double weight = 0.0;
                double bias = this.result.getBias();

                // If the weight exists get it from the support for the 
                // example.
                DefaultWeightedValue support =
                    this.supportsMap.get(example);
                if (support != null)
                {
                    weight = support.getWeight();
                }

                if (actual)
                {
                    // Update for a positive example so add to the
                    // weights and the bias.
                    weight += 1.0;
                    bias += 1.0;
                }
                else
                {
                    // Update for a negative example so subtract from
                    // the weights and the bias.
                    weight -= 1.0;
                    bias -= 1.0;
                }

                if (support == null)
                {
                    // Add a support for this example.
                    support = new DefaultWeightedValue(input, weight);
                    this.supportsMap.put(example, support);
                }
                else if (weight == 0.0)
                {
                    // This example is no longer a support.
                    this.supportsMap.remove(example);
                }
                else
                {
                    // Update the weight for the support.
                    support.setWeight(weight);
                }

                // Update the bias.
                this.result.setBias(bias);
            }
        // else - The classification was correct, no need to update.
        }

        // Keep going while the error count is positive.
        return this.getErrorCount() > 0;
    }

    @Override
    protected void cleanupAlgorithm()
    {
        if (this.getSupportsMap() != null)
        {
            // Make the result object have a more efficient backing collection
            // at the end.
            this.getResult().setExamples(
                new ArrayList>(
                    this.getSupportsMap().values()));

            this.setSupportsMap(null);
        }
    }

    /**
     * Gets the kernel to use.
     *
     * @return The kernel to use.
     */
    public Kernel getKernel()
    {
        return this.kernel;
    }

    /**
     * Sets the kernel to use.
     *
     * @param  kernel The kernel to use.
     */
    public void setKernel(
        final Kernel kernel)
    {
        this.kernel = kernel;
    }

    /**
     * Sets both the positive and negative margin to the same value.
     *
     * @param  margin The new value for both the positive and negative margins.
     */
    public void setMargin(
        final double margin)
    {
        this.setMarginPositive(margin);
        this.setMarginNegative(margin);
    }

    /**
     * Gets the positive margin that is enforced.
     *
     * @return The positive margin that is enforced.
     */
    public double getMarginPositive()
    {
        return this.marginPositive;
    }

    /**
     * Sets the positive margin that is enforced.
     *
     * @param  marginPositive The positive margin that is enforced.
     */
    public void setMarginPositive(
        final double marginPositive)
    {
        this.marginPositive = marginPositive;
    }

    /**
     * Gets the negative margin that is enforced.
     *
     * @return The negative margin that is enforced.
     */
    public double getMarginNegative()
    {
        return this.marginNegative;
    }

    /**
     * Sets the negative margin that is enforced.
     *
     * @param  marginNegative The negative margin that is enforced.
     */
    public void setMarginNegative(
        final double marginNegative)
    {
        this.marginNegative = marginNegative;
    }

    @Override
    public DefaultKernelBinaryCategorizer getResult()
    {
        return this.result;
    }

    /**
     * Sets the object currently being result.
     *
     * @param  result The object currently being result.
     */
    protected void setResult(
        final DefaultKernelBinaryCategorizer result)
    {
        this.result = result;
    }

    /**
     * Gets the error count of the most recent iteration.
     *
     * @return The current error count.
     */
    public int getErrorCount()
    {
        return this.errorCount;
    }

    /**
     * Sets the error count of the most recent iteration.
     *
     * @param  errorCount The current error count.
     */
    protected void setErrorCount(
        final int errorCount)
    {
        this.errorCount = errorCount;
    }

    /**
     * Gets the mapping of examples to weight objects (support vectors).
     *
     * @return The mapping of examples to weight objects.
     */
    protected LinkedHashMap, DefaultWeightedValue> getSupportsMap()
    {
        return supportsMap;
    }

    /**
     * Gets the mapping of examples to weight objects (support vectors).
     *
     * @param  supportsMap The mapping of examples to weight objects.
     */
    protected void setSupportsMap(
        final LinkedHashMap, DefaultWeightedValue> supportsMap)
    {
        this.supportsMap = supportsMap;
    }
    
    @Override
    public NamedValue getPerformance()
    {
        return new DefaultNamedValue("error count", this.getErrorCount());
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy