All downloads are free. The search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.algorithm.perceptron.AggressiveRelaxedOnlineMaximumMarginAlgorithm Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                AggressiveRelaxedOnlineMaximumMarginAlgorithm.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry Learning Core
 * 
 * Copyright January 27, 2011, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government.
 *   */

package gov.sandia.cognition.learning.algorithm.perceptron;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.function.categorization.DefaultKernelBinaryCategorizer;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.learning.function.kernel.KernelUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;

/**
 * An implementation of the Aggressive Relaxed Online Maximum Margin Algorithm
 * (AROMMA). It is an online learner for a linear binary categorizer that
 * also has a kernel form.
 * 
 * @author  Justin Basilico
 * @since   3.3.0
 */
@PublicationReference(
  title="Ultraconservative online algorithms for multiclass problems",
  author={"Koby Crammer", "Yoram Singer"},
  year=2003,
  type=PublicationType.Journal,
  publication="The Journal of Machine Learning Research",
  pages={951, 991},
  url="http://portal.acm.org/citation.cfm?id=944936")
public class AggressiveRelaxedOnlineMaximumMarginAlgorithm
    extends AbstractKernelizableBinaryCategorizerOnlineLearner
{

    /**
     * Creates a new {@code AggressiveRelaxedOnlineMaximumMarginAlgorithm}.
     */
    public AggressiveRelaxedOnlineMaximumMarginAlgorithm()
    {
        this(VectorFactory.getDefault());
    }

    /**
     * Creates a new {@code AggressiveRelaxedOnlineMaximumMarginAlgorithm} with
     * the given vector factory.
     *
     * @param   vectorFactory
     *      The vector factory to use.
     */
    public AggressiveRelaxedOnlineMaximumMarginAlgorithm(
        final VectorFactory vectorFactory)
    {
        super(vectorFactory);
    }

    @Override
    public void update(
        final LinearBinaryCategorizer target,
        final Vector input,
        final boolean label)
    {
        // Get the information about the example.
        final double actual = label ? +1.0 : -1.0;

        Vector weights = target.getWeights();
        if (weights == null)
        {
            // This is the first example, so initialize the weight vector.
            final double inputNorm = input.norm2Squared();
            weights = this.getVectorFactory().copyVector(input);
            weights.scaleEquals(actual / inputNorm);
            target.setWeights(weights);
        }
        else
        {
            // Predict the output as a double (negative values are false, positive
            // are true).
            final double prediction = target.evaluateAsDouble(input);

            final double error = actual * prediction;
            final double inputNorm = input.norm2Squared();
            final double weightsNorm = weights.norm2Squared();
            
            if ((1.0 > error) && (error >= inputNorm * weightsNorm))
            {
                weights.zero();
                if (inputNorm > 0.0)
                {
                    weights.plusEquals(input);
                    weights.scaleEquals(actual / inputNorm);
                }
            }
            else if (error < 1.0)
            {
                final double denominator = inputNorm * weightsNorm
                    - prediction * prediction;
                // Compute the update value.
                final double c = (inputNorm * weightsNorm - actual * prediction)
                    / denominator;
                final double d = (weightsNorm * (actual - prediction))
                    / denominator;

                weights.scaleEquals(c);
                weights.plusEquals(input.scale(d));
            }
            // else - Passive when there is no loss.
        }

    }

    @Override
    public  void update(
        final DefaultKernelBinaryCategorizer target,
        final InputType input,
        final boolean label)
    {
       // Get the information about the example.
        final double actual = label ? +1.0 : -1.0;

        if (target.getExamples().isEmpty())
        {
            // Initialize the target on the first update.
            final double inputNorm = target.getKernel().evaluate(input, input);

            if (inputNorm > 0.0)
            {
                target.add(input, actual / inputNorm);
            }
        }
        else
        {
            // Predict the output as a double (negative values are false, positive
            // are true).
            final double prediction = target.evaluateAsDouble(input);
            final double error = actual * prediction;
            final double inputNorm = target.getKernel().evaluate(input, input);
            final double weightsNorm = KernelUtil.norm2Squared(target);

            if ((1.0 > error) && (error >= inputNorm * weightsNorm))
            {
                target.getExamples().clear();
                
                if (inputNorm > 0.0)
                {
                    target.add(input, actual / inputNorm);
                }
            }
            else if (error < 1.0)
            {

                final double denominator = inputNorm * weightsNorm
                    - prediction * prediction;
                // Compute the update value.
                final double c = (inputNorm * weightsNorm - actual * prediction)
                    / denominator;
                final double d = (weightsNorm * (actual - prediction))
                    / denominator;

                KernelUtil.scaleEquals(target, c);
                target.add(input, d);
            }
            // else - Passive when there is no loss.
        }
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy