gov.sandia.cognition.learning.algorithm.perceptron.Perceptron Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: Perceptron.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright July 18, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.perceptron;
import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
/**
* The Perceptron
class implements the standard Perceptron learning
* algorithm that learns a binary classifier based on vector input. This
* implementation also allows for margins to be defined in learning in order to
* find a hyperplane.
*
* @author Justin Basilico
* @since 2.0
*/
@CodeReview(
reviewer="Kevin R. Dixon",
date="2008-07-23",
changesNeeded=false,
comments={
"Added PublicationReference to Wikiepedia article.",
"Minor changes to javadoc.",
"Looks fine."
}
)
@PublicationReference(
author="Wikipedia",
title="Perceptron Learning algorithm",
type=PublicationType.WebPage,
year=2008,
url="http://en.wikipedia.org/wiki/Perceptron#Learning_algorithm"
)
public class Perceptron
extends AbstractAnytimeSupervisedBatchLearner
implements MeasurablePerformanceAlgorithm, VectorFactoryContainer
{
/** The default maximum number of iterations, {@value}. */
public static final int DEFAULT_MAX_ITERATIONS = 100;
/** The default positive margin, {@value}. */
public static final double DEFAULT_MARGIN_POSITIVE = 0.0;
/** The default negative margin, {@value}. */
public static final double DEFAULT_MARGIN_NEGATIVE = 0.0;
/** The positive margin to enforce. */
private double marginPositive;
/** The negative margin to enforce. */
private double marginNegative;
/** The VectorFactory to use to create vectors. */
private VectorFactory> vectorFactory;
/** The result categorizer. */
private LinearBinaryCategorizer result;
/** The number of errors on the most recent iteration. */
private int errorCount;
/**
* Creates a new instance of Perceptron.
*/
public Perceptron()
{
this(DEFAULT_MAX_ITERATIONS);
}
/**
* Creates a new instance of Perceptron with the given maximum number of
* iterations.
*
* @param maxIterations The maximum number of iterations.
*/
public Perceptron(
final int maxIterations)
{
this(maxIterations, DEFAULT_MARGIN_POSITIVE, DEFAULT_MARGIN_NEGATIVE);
}
/**
* Creates a new instance of Perceptron with the given parameters
*
* @param maxIterations The maximum number of iterations.
* @param marginPositive The positive margin to enforce.
* @param marginNegative The negative margin to enforce.
*/
public Perceptron(
final int maxIterations,
final double marginPositive,
final double marginNegative)
{
this(maxIterations, marginPositive, marginNegative,
VectorFactory.getDefault());
}
/**
* Creates a new instance of Perceptron with the given parameters
*
* @param maxIterations The maximum number of iterations.
* @param marginPositive The positive margin to enforce.
* @param marginNegative The negative margin to enforce.
* @param vectorFactory The VectorFactory to use to create the weight
* vector.
*/
public Perceptron(
final int maxIterations,
final double marginPositive,
final double marginNegative,
final VectorFactory> vectorFactory)
{
super(maxIterations);
this.setMarginPositive(marginPositive);
this.setMarginNegative(marginNegative);
this.setVectorFactory(vectorFactory);
}
@Override
public Perceptron clone()
{
final Perceptron clone = (Perceptron) super.clone();
clone.result = null;
clone.errorCount = 0;
return clone;
}
@Override
protected boolean initializeAlgorithm()
{
if (this.getData() == null)
{
// Error: No data to learn on.
return false;
}
// Computer the dimensionality of the data and ensure it is correct.
int dimensionality = DatasetUtil.getInputDimensionality(this.getData());
if (dimensionality < 0)
{
// There was no data.
return false;
}
DatasetUtil.assertInputDimensionalitiesAllEqual(this.getData());
// Initialize the result object.
this.setResult(new LinearBinaryCategorizer(
this.getVectorFactory().createVector(dimensionality),
0.0));
return true;
}
@Override
protected boolean step()
{
// Reset the number of errors for the new iteration.
this.setErrorCount(0);
// Loop over all the training instances.
for (InputOutputPair extends Vectorizable, ? extends Boolean> example
: this.getData())
{
if (example == null)
{
continue;
}
// Compute the predicted classification and get the actual
// classification.
final Vector input = example.getInput().convertToVector();
final boolean actual = example.getOutput();
final double prediction = this.result.evaluateAsDouble(input);
if ( (actual && prediction <= this.marginPositive)
|| (!actual && prediction >= -this.marginNegative))
{
// The classification was incorrect so we need to update
// the perceptron.
this.setErrorCount(this.getErrorCount() + 1);
final Vector weights = this.result.getWeights();
double bias = this.result.getBias();
if (actual)
{
// Update for a positive example so add to the
// weights and the bias.
weights.plusEquals(input);
bias += 1.0;
}
else
{
// Update for a negative example so subtract from
// the weights and the bias.
weights.minusEquals(input);
bias -= 1.0;
}
// The weights are updated by side-effect.
// Update the bias directly.
this.result.setBias(bias);
}
// else - The classification was correct, no need to update.
}
// Keep going while the error count is positive.
return this.getErrorCount() > 0;
}
@Override
protected void cleanupAlgorithm()
{
// Nothing to clean up.
}
/**
* Sets both the positive and negative margin to the same value.
*
* @param margin The new value for both the positive and negative margins.
*/
public void setMargin(
final double margin)
{
this.setMarginPositive(margin);
this.setMarginNegative(margin);
}
/**
* Gets the positive margin that is enforced.
*
* @return The positive margin that is enforced.
*/
public double getMarginPositive()
{
return this.marginPositive;
}
/**
* Sets the positive margin that is enforced.
*
* @param marginPositive The positive margin that is enforced.
*/
public void setMarginPositive(
final double marginPositive)
{
this.marginPositive = marginPositive;
}
/**
* Gets the negative margin that is enforced.
*
* @return The negative margin that is enforced.
*/
public double getMarginNegative()
{
return this.marginNegative;
}
/**
* Sets the negative margin that is enforced.
*
* @param marginNegative The negative margin that is enforced.
*/
public void setMarginNegative(
final double marginNegative)
{
this.marginNegative = marginNegative;
}
@Override
public VectorFactory> getVectorFactory()
{
return this.vectorFactory;
}
/**
* Sets the VectorFactory used to create the weight vector.
*
* @param vectorFactory The VectorFactory used to create the weight vector.
*/
public void setVectorFactory(
final VectorFactory> vectorFactory)
{
this.vectorFactory = vectorFactory;
}
@Override
public LinearBinaryCategorizer getResult()
{
return this.result;
}
/**
* Sets the object currently being result.
*
* @param result The object currently being result.
*/
protected void setResult(
final LinearBinaryCategorizer result)
{
this.result = result;
}
/**
* Gets the error count of the most recent iteration.
*
* @return The current error count.
*/
public int getErrorCount()
{
return this.errorCount;
}
/**
* Sets the error count of the most recent iteration.
*
* @param errorCount The current error count.
*/
protected void setErrorCount(
final int errorCount)
{
this.errorCount = errorCount;
}
@Override
public NamedValue getPerformance()
{
return new DefaultNamedValue("error count", this.getErrorCount());
}
}