All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.ensemble.AdaBoost Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                AdaBoost.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Framework Lite
 *
 * Copyright September 19, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.ensemble;

import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.BatchLearnerContainer;
import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.WeightedInputOutputPair;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;

/**
 * The {@code AdaBoost} class implements the Adaptive Boosting (AdaBoost)
 * algorithm formulated by Yoav Freund and Robert Shapire. It creates a 
 * weighted binary ensemble as output from the algorithm.
 *
 * @param  The algorithm operates on a
 * {@code Collection>}.  The
 * {@code InputType} will be something like Vector or String.
 * @author Justin Basilico
 * @since  2.0
 */
@CodeReview(
    reviewer="Kevin R. Dixon",
    date="2008-07-23",
    changesNeeded=false,
    comments={
        "Added PublicationReference",
        "Cleaned up javadoc a little bit with code annotations.",
        "Otherwise, looks fine."
    }
)
@PublicationReference(
    author={
        "Yoav Freund",
        "Robert E.Schapire"
    },
    title="A decision-theoretic generalization of on-line learning and an application to boosting",
    publication="Journal of Computer and System Sciences",
    notes="Volume 55, Number 1",
    year=1997,
    pages={119,139},
    type=PublicationType.Journal,
    url="http://www.cse.ucsd.edu/~yfreund/papers/adaboost.pdf"
)
public class AdaBoost
    extends AbstractAnytimeSupervisedBatchLearner>>
    implements BatchLearnerContainer>, ? extends Evaluator>>
{

    /** The default maximum number of iterations is {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /**
     * The "weak learner" that must learn from the weighted input-output pairs
     * on each iteration.
     */
    protected BatchLearner>, ? extends Evaluator>
        weakLearner;

    /** The ensemble learned by the algorithm. */
    protected transient WeightedBinaryEnsemble>
        ensemble;

    /** An array list containing the weighted version of the data. */
    protected transient ArrayList> weightedData;

    /**
     * Creates a new instance of AdaBoost with no base learner.
     */
    public AdaBoost()
    {
        this(null);
    }

    /**
     * Creates a new instance of AdaBoost.
     *
     * @param  weakLearner The weak learner to apply learning to on each 
     *         iteration.
     */
    public AdaBoost(
        final BatchLearner>, ? extends Evaluator> weakLearner)
    {
        this(weakLearner, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a new instance of AdaBoost.
     *
     * @param  weakLearner 
     *      The weak learner to apply learning to on each iteration.
     * @param  maxIterations 
     *      The maximum number of iterations to run for, which is also the 
     *      upper bound on the number of learners to create.
     */
    public AdaBoost(
        final BatchLearner>, ? extends Evaluator> weakLearner,
        final int maxIterations)
    {
        super(maxIterations);

        this.setWeakLearner(weakLearner);

        this.setEnsemble(null);
        this.setWeightedData(null);
    }

    @Override
    protected boolean initializeAlgorithm()
    {
        if (this.getData() == null || this.getData().size() <= 0)
        {
            // This is an invalid dataset.
            return false;
        }

        // Create the ensemble.
        this.setEnsemble(new WeightedBinaryEnsemble>(
            new LinkedList>>()));

        // We initialize the weighted training examples and count them up
        // as we go so that we can initialize the weights in the next step.
        int numExamples = 0;
        this.setWeightedData(
            new ArrayList>(
            this.getData().size()));

        for (InputOutputPair example
            : this.getData())
        {
            if (example != null)
            {
                this.weightedData.add(
                    new DefaultWeightedInputOutputPair(
                        example, 1.0));
                numExamples++;
            }
        }

        // The initial weight is the same for all examples. We use this
        // value to ensure that they all sum to one.
        double initialWeight = 1.0 / numExamples;
        for (DefaultWeightedInputOutputPair example
            : this.weightedData)
        {
            example.setWeight(initialWeight);
        }

        return true;
    }

    @Override
    protected boolean step()
    {
        // First perform the weak learning algorithm.
        Evaluator learned =
            this.weakLearner.learn(this.weightedData);

        if (learned == null)
        {
            // Nothing was learned.
            return false;
        }

        // Now compute the weighted error for the weak learner.
        // Also, while we compute the error we will save the predictions
        // that are made by the learner, which we will need when we update
        // the weights.
        double error = 0.0;
        int numExamples = this.weightedData.size();
        double[] predictions = new double[numExamples];
        // Note: The predictions array could be set as a class member and
        // reused if it is thrashing memory.

        for (int i = 0; i < numExamples; i++)
        {
            WeightedInputOutputPair example =
                this.weightedData.get(i);

            // Compute the prediction of the learner for this instance.
            Boolean prediction = learned.evaluate(example.getInput());

            double predictionDouble = 0.0;
            if (prediction == null)
            {
                predictionDouble = 0.0;
            }
            else if (prediction)
            {
                predictionDouble = +1.0;
            }
            else // ( !prediction )
            {
                predictionDouble = -1.0;
            }

            predictions[i] = predictionDouble;

            // Get the actual output.
            double actual = example.getOutput() ? +1.0 : -1.0;

            if (actual * predictionDouble < 0.0)
            {
                // The prediction was incorrect so add it to the weighted
                // error.
                error += example.getWeight();
            }
        // else - The prediction is correct or it was abstained (0.0)
        }

        if (error >= 0.5)
        {
            // The error rate of this classifier is worse than random, so we
            // stop the algorithm.
            return false;
        }

        // Alpha is fixed at 1.0 for this implementation because the
        // output of the weak learner has no constraints beyond requiring
        // that it is a real number.
        // Note: we could revisit this value of alpha in the future.
        double alpha = 1.0;
        if (error > 0.0)
        {
            alpha = 0.5 * Math.log((1 - error) / error);
        }

        // Add this to the ensemble.
        this.ensemble.add(learned, alpha);

        // Now update the weight assigned to each training example.
        // Also, we need to make it so the weights sum to one. We do this
        // by summing them up in this loop and then dividing them in
        // second pass.
        double sum = 0.0;
        for (int i = 0; i < numExamples; i++)
        {
            DefaultWeightedInputOutputPair example =
                this.weightedData.get(i);
            double predicted = predictions[i];
            double actual = example.getOutput() ? +1.0 : -1.0;
            double oldWeight = example.getWeight();

            // Compute the new weight.
            double newWeight =
                oldWeight * Math.exp(-alpha * actual * predicted);

            // Set the weight and update the sum.
            example.setWeight(newWeight);
            sum += newWeight;
        }

        // Normalize the weights to be a distribution.
        for (DefaultWeightedInputOutputPair example
            : this.weightedData)
        {
            example.setWeight(example.getWeight() / sum);
        }

        return true;
    }

    @Override
    protected void cleanupAlgorithm()
    {
        // We no longer need the weighted data.
        this.setWeightedData(null);
    }

    @Override
    public WeightedBinaryEnsemble> getResult()
    {
        return this.getEnsemble();
    }

    @Override
    public BatchLearner>, ? extends Evaluator> getLearner()
    {
        return this.getWeakLearner();
    }

    /**
     * Gets the weak learner that is passed the weighted training data on each
     * step of the algorithm. 
     *
     * @return The weak learner for the algorithm to use.
     */
    public BatchLearner>, ? extends Evaluator> getWeakLearner()
    {
        return this.weakLearner;
    }

    /**
     * Sets the weak learner that is passed the weighted training data on each
     * step of the algorithm. 
     *
     * @param   weakLearner The weak learner for the algorithm to use.
     */
    public void setWeakLearner(
        final BatchLearner>, ? extends Evaluator> weakLearner)
    {
        this.weakLearner = weakLearner;
    }

    /**
     * Gets the ensemble created by this learner.
     *
     * @return The ensemble created by this learner.
     */
    public WeightedBinaryEnsemble> getEnsemble()
    {
        return this.ensemble;
    }

    /**
     * Sets the ensemble created by this learner.
     *
     * @param  ensemble The ensemble created by this learner.
     */
    protected void setEnsemble(
        final WeightedBinaryEnsemble> ensemble)
    {
        this.ensemble = ensemble;
    }

    /**
     * Gets the weighted version of the data.
     *
     * @return The weighted version of the data.
     */
    public ArrayList> getWeightedData()
    {
        return this.weightedData;
    }

    /**
     * Sets the weighted version of the data.
     *
     * @param  weightedData The weighted version of the data.
     */
    protected void setWeightedData(
        final ArrayList> weightedData)
    {
        this.weightedData = weightedData;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy