All downloads are free. The search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.algorithm.perceptron.AbstractLinearCombinationOnlineLearner Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                AbstractLinearCombinationOnlineLearner.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry Learning Core
 * 
 * Copyright March 28, 2011, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government.
 *
 */

package gov.sandia.cognition.learning.algorithm.perceptron;

import gov.sandia.cognition.learning.function.categorization.DefaultKernelBinaryCategorizer;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.util.DefaultWeightedValue;

/**
 * An abstract class for online learning of linear binary categorizers that
 * take the form of a weighted sum of inputs. It has utility methods for
 * splitting up the computation into three steps and provides both linear
 * and kernel methods for learning.
 * 
 * @author  Justin Basilico
 * @since   3.3.0
 */
public abstract class AbstractLinearCombinationOnlineLearner
    extends AbstractKernelizableBinaryCategorizerOnlineLearner
{
    /** An option controlling whether or not the bias is updated or not.
     *  Individual algorithm implementations choose the default value for this.
     */
    protected boolean updateBias;
    
    /**
     * Creates a new {@code AbstractLinearCombinationOnlineLearner} with
     * default parameters.
     *
     * @param   updateBias
     *      Whether or not the bias should be updated by the algorithm.
     */
    public AbstractLinearCombinationOnlineLearner(
        final boolean updateBias)
    {
        this(updateBias, VectorFactory.getDefault());
    }

    /**
     * Creates a new {@code AbstractLinearCombinationOnlineLearner} with
     * the given parameters.
     *
     * @param   updateBias
     *      Whether or not the bias should be updated by the algorithm.
     * @param   vectorFactory
     *      The vector factory to use.
     */
    public AbstractLinearCombinationOnlineLearner(
        final boolean updateBias,
        final VectorFactory vectorFactory)
    {
        super(vectorFactory);

        this.setUpdateBias(updateBias);
    }

    @Override
    public void update(
        final LinearBinaryCategorizer target,
        final Vector input,
        final boolean label)
    {

        Vector weights = target.getWeights();
        if (weights == null)
        {
            // This is the first example, so initialize the weight vector.
            this.initialize(target, input, label);
            weights = target.getWeights();
        }
        // else - Use the existing weights.

        // Predict the output as a double (negative values are false, positive
        // are true).
        final double predicted = target.evaluateAsDouble(input);
        final double actual = label ? +1.0 : -1.0;

        // Compute the update.
        final double update = this.computeUpdate(target, input, label,
            predicted);

        // Now compute the decay before we've applied the update.
        final double decay = this.computeDecay(
            target, input, label, predicted, update);

        // Do the decaying of old data.
        if (decay != 1.0)
        {
            if (decay == 0.0)
            {
                weights.zero();
            }
            else
            {
                weights.scaleEquals(decay);
            }
        }

        // Add the new value.
        if (update != 0.0)
        {
            if (update == 1.0)
            {
                // Special case for updating by 1 to avoid copying memory.
                if (label)
                {
                    weights.plusEquals(input);
                }
                else
                {
                    weights.minusEquals(input);
                }
            }
            else
            {
                weights.plusEquals(input.scale(update * actual));
            }

        }
        // else - Not an error.

        // Update the target.
        target.setWeights(weights);
        if (this.updateBias)
        {
            final double bias = target.getBias() * decay + actual * update;
            target.setBias(bias);
        }

        // Compute the rescaling.
        final double rescaling = this.computeRescaling(target, input, label,
            predicted, update, decay);
        if (rescaling != 1.0)
        {
            weights.scaleEquals(rescaling);
            
            target.setWeights(weights);
            if (this.updateBias)
            {
                final double bias = target.getBias() * rescaling;
                target.setBias(bias);
            }
        }
        // else - No need to rescale.

    }

    @Override
    public  DefaultKernelBinaryCategorizer createInitialLearnedObject(
        final Kernel kernel)
    {
        return new DefaultKernelBinaryCategorizer(
            kernel);
    }

    @Override
    public  void update(
        final DefaultKernelBinaryCategorizer target,
        final InputType input,
        final boolean output)
    {
        // Get the information about the example.
        final boolean label = output;

        if (target.getExamples().isEmpty())
        {
            // Target is not initialize, so initialize it.
            this.initialize(target, input, output);
        }

        // Predict the output as a double (negative values are false, positive
        // are true).
        final double predicted = target.evaluateAsDouble(input);
        final double actual = label ? +1.0 : -1.0;

        // Compute the update.
        final double update = this.computeUpdate(target, input, label,
            predicted);

        // Now compute the decay before we've applied the update.
        final double decay = this.computeDecay(
            target, input, label, predicted, update);

        // Do the decaying of old data.
        if (decay != 1.0)
        {
            if (decay == 0.0)
            {
                target.getExamples().clear();
            }
            else
            {
                for (DefaultWeightedValue weighted
                    : target.getExamples())
                {
                    weighted.setWeight(decay * weighted.getWeight());
                }
            }
        }

        // Add the new value.
        if (update != 0.0)
        {
            target.add(input, update * actual);
        }
        // else - Not an error.

        // Update the target.
        if (this.updateBias)
        {
            final double bias = target.getBias() * decay + actual * update;
            target.setBias(bias);
        }

        // Compute the rescaling.
        final double rescaling = this.computeRescaling(target, input, label,
            predicted, update, decay);
        if (rescaling != 1.0)
        {
            for (DefaultWeightedValue weighted
                : target.getExamples())
            {
                weighted.setWeight(rescaling * weighted.getWeight());
            }

            if (this.updateBias)
            {
                final double bias = target.getBias() * rescaling;
                target.setBias(bias);
            }
        }
        // else - No need to rescale.

    }

    /**
     * Initializes the linear binary categorizer. Can be overridden.
     * The default implementation just sets the weights to the zero vector.
     *
     * @param   target
     *      The categorizer to initialize.
     * @param   input
     *      The first input seen.
     * @param   actualCategory
     *      The actual category of the first input.
     */
    protected void initialize(
        final LinearBinaryCategorizer target,
        final Vector input,
        final boolean actualCategory)
    {
        Vector weights = this.getVectorFactory().createVector(
                input.getDimensionality());
        target.setWeights(weights);
    }

    /**
     * Compute the update weight in the linear case. Must be implemented by
     * subclasses.
     *
     * @param   target
     *      Target to compute the update for.
     * @param   input
     *      Input to use in computing the update.
     * @param   actualCategory
     *      The actual category of the input.
     * @param   predicted
     *      The predicted category of the input.
     * @return
     *      The update weight for how much to add the input to the target.
     *      May be zero if no update is needed.
     */
    protected abstract double computeUpdate(
        final LinearBinaryCategorizer target,
        final Vector input,
        final boolean actualCategory,
        final double predicted);

    /**
     * Computes the decay scalar for the existing weight vector. Can be
     * overridden. The default implementation just returns 1.0, which means no
     * change. Typically this will be a value between 0.0 and 1.0.
     *
     * @param   target
     *      Target to compute the update for.
     * @param   input
     *      Input to use in computing the update.
     * @param   actualCategory
     *      The actual category of the input.
     * @param   predicted
     *      The predicted category of the input.
     * @param   update
     *      The value from the computeUpdate step.
     * @return
     *      The decay to apply to the weight vector. Usually between 0.0 and
     *      1.0.
     */
    protected double computeDecay(
        final LinearBinaryCategorizer target,
        final Vector input,
        final boolean actualCategory,
        final double predicted,
        final double update)
    {
        // Default is no decay.
        return 1.0;
    }

    /**
     * Computes the rescaling for the new weight vector. Can be overridden.
     * The default implementation just returns 1.0, which means no change.
     * Typically this will be a value between 0.0 and 1.0.
     *
     * @param   target
     *      Target to compute the update for.
     * @param   input
     *      Input to use in computing the update.
     * @param   actualCategory
     *      The actual category of the input.
     * @param   predicted
     *      The predicted category of the input.
     * @param   update
     *      The value from the computeUpdate step.
     * @param   decay
     *      The value from the computeDecay step.
     * @return
     *      The rescaling to apply to the weight vector. Usually between 0.0 and
     *      1.0.
     */
    protected double computeRescaling(
        final LinearBinaryCategorizer target,
        final Vector input,
        final boolean actualCategory,
        final double predicted,
        final double update,
        final double decay)
    {
        // Default is no rescaling.
        return 1.0;
    }

    /**
     * Initializes the kernel binary categorizer. Can be overridden.
     * The default implementation does nothing.
     *
     * @param   
     *      The input value for learning.
     * @param   target
     *      The categorizer to initialize.
     * @param   input
     *      The first input seen.
     * @param   actualCategory
     *      The actual category of the first input.
     */
    protected  void initialize(
        final DefaultKernelBinaryCategorizer target,
        final InputType input,
        final boolean actualCategory)
    {
        // Nothing to initialize.
    }
    /**
     * Compute the update weight in the linear case. Must be implemented by
     * subclasses.
     *
     * @param   
     *      The input value for learning.
     * @param   target
     *      Target to compute the update for.
     * @param   input
     *      Input to use in computing the update.
     * @param   actualCategory
     *      The actual category of the input.
     * @param   predicted
     *      The predicted category of the input.
     * @return
     *      The update weight for how much to add the input to the target.
     *      May be zero if no update is needed.
     */
    protected abstract  double computeUpdate(
        final DefaultKernelBinaryCategorizer target,
        final InputType input,
        final boolean actualCategory,
        final double predicted);

    /**
     * Computes the decay scalar for the existing weights. Can be overridden.
     * The default implementation just returns 1.0, which means no change.
     * Typically this will be a value between 0.0 and 1.0.
     *
     * @param   
     *      The input value for learning.
     * @param   target
     *      Target to compute the update for.
     * @param   input
     *      Input to use in computing the update.
     * @param   actualCategory
     *      The actual category of the input.
     * @param   predicted
     *      The predicted category of the input.
     * @param   update
     *      The value from the computeUpdate step.
     * @return
     *      The decay to apply to the weights. Usually between 0.0 and
     *      1.0.
     */
    protected  double computeDecay(
        final DefaultKernelBinaryCategorizer target,
        final InputType input,
        final boolean actualCategory,
        final double predicted,
        final double update)
    {
        // Default is no decay.
        return 1.0;
    }

    /**
     * Computes the rescaling for the new weights. Can be overridden.
     * The default implementation just returns 1.0, which means no change.
     * Typically this will be a value between 0.0 and 1.0.
     *
     * @param   
     *      The input value for learning.
     * @param   target
     *      Target to compute the update for.
     * @param   input
     *      Input to use in computing the update.
     * @param   actualCategory
     *      The actual category of the input.
     * @param   predicted
     *      The predicted category of the input.
     * @param   update
     *      The value from the computeUpdate step.
     * @param   decay
     *      The value from the computeDecay step.
     * @return
     *      The rescaling to apply to the weights. Usually between 0.0 and
     *      1.0.
     */
    protected  double computeRescaling(
        final DefaultKernelBinaryCategorizer target,
        final InputType input,
        final boolean actualCategory,
        final double predicted,
        final double update,
        final double decay)
    {
        // Default is no rescaling.
        return 1.0;
    }

    /**
     * Gets whether or not the algorithm is updating the bias.
     *
     * @return
     *      True if the algorithm is updating the bias. Otherwise, false.
     */
    public boolean isUpdateBias()
    {
        return this.updateBias;
    }

    /**
     * Sets whether or not the algorithm is updating the bias.
     *
     * @param   updateBias
     *      True if the algorithm is updating the bias. Otherwise, false.
     */
    protected void setUpdateBias(
        final boolean updateBias)
    {
        this.updateBias = updateBias;
    }
    
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy