All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.ensemble.IVotingCategorizerLearner Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                IVotingCategorizerLearner.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright November 25, 2009, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 * See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.learning.algorithm.ensemble;

import gov.sandia.cognition.algorithm.event.AbstractIterativeAlgorithmListener;
import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.collection.FiniteCapacityBuffer;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.factory.Factory;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.BatchLearnerContainer;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Randomized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Learns an ensemble in a method similar to bagging except that on each
 * iteration the bag is built from two parts, each sampled from elements from
 * disjoint sets. The two sets are the set of examples that the ensemble
 * currently gets correct and incorrect.
 *
 * In effect, ivoting has similar properties to boosting except that it does
 * not require that the learner for each ensemble member be able to use
 * weights on the examples.
 *
 * @param   
 *      The type of the input for the categorizer to learn. This is the type
 *      passed to the internal batch learner to learn each ensemble member.
 * @param   
 *      The type of the category that is the output for the categorizer to
 *      learn. It is also passed to the internal batch learner to learn each
 *      ensemble member. It must have a valid equals and hashCode method.
 * @author  Justin Basilico
 * @since   3.0
 */
@PublicationReference(
  author="Leo Breiman",
  title="Pasting small votes for classification in large databases and on-line",
  year=1999,
  type=PublicationType.Journal,
  publication="Machine Learning",
  pages={85, 103},
  url="http://www.springerlink.com/content/mnu2r28218651707/fulltext.pdf"
)
public class IVotingCategorizerLearner
    extends AbstractAnytimeSupervisedBatchLearner>>
    implements Randomized,
        BatchLearnerContainer>, ? extends Evaluator>>
{
    /** The default maximum number of iterations is {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** The default percent to sample {@value}. */
    public static final double DEFAULT_PERCENT_TO_SAMPLE = 0.10;

    /** By default use 50% incorrect (and 50%) correct in the percent to sample. */
    public static final double DEFAULT_PROPORTION_INCORRECT_IN_SAMPLE = 0.50;

    /** The default value to vote out-of-bag. */
    public static final boolean DEFAULT_VOTE_OUT_OF_BAG_ONLY = true;

    /** The learner used to produce each ensemble member. */
    protected BatchLearner>,
        ? extends Evaluator> learner;

    /** The percent to sample on each iteration. */
    protected double percentToSample;

    /** The proportion of incorrect examples in each sample. Must be between 0.0
     *  and 1.0 (inclusive). */
    protected double proportionIncorrectInSample;

    /** Controls whether or not an ensemble member can vote on items it was
     *  trained on during learning. By default, the ensemble member can only
     *  vote on out-of-bag values. */
    protected boolean voteOutOfBagOnly;

    /** Factory for counting votes. */
    protected Factory> counterFactory;

    /** The random number generator to use. */
    protected Random random;

    /** The current ensemble. */
    protected transient WeightedVotingCategorizerEnsemble> ensemble;

    /** The data represented as an array list. */
    protected transient ArrayList> dataList;

    /** The running estimate of the ensemble for each example. Updated in each
     *  iteration with the ensemble member created for that iteration. This is
     *  used instead of evaluating the ensemble in each iteration to make it so
     *  that each ensemble member is only evaluated once on each training
     *  example.
     */
    protected transient ArrayList> dataFullEstimates;

    /** The running estimate of the ensemble for each example where an ensemble
     *  member can only vote on elements that were not in the bag used to train
     * it. */
    protected transient ArrayList> dataOutOfBagEstimates;

    /** A boolean for each example indicating if the ensemble currently gets the
     *  example correct or incorrect.
     */
    protected transient boolean[] currentEnsembleCorrect;

    /** The indices of examples that the ensemble currently gets correct. */
    protected transient ArrayList currentCorrectIndices;

    /** The indices of examples that the ensemble currently gets incorrect. */
    protected transient ArrayList currentIncorrectIndices;

    /** The size of sample to create on each iteration. */
    protected transient int sampleSize;

    /** The number of correct examples to sample on each iteration. */
    protected transient int numCorrectToSample;

    /** The number of incorrect examples to sample on each iteration. */
    protected transient int numIncorrectToSample;

    /** The current bag used to train the current ensemble member. */
    protected transient ArrayList> currentBag;

    /** A counter for each example indicating how many times it exists in the
     *  current bag.
     */
    protected transient int[] dataInBag;

    /** The currently learned member of the ensemble. */
    protected transient Evaluator currentMember;

    /** The estimates of the current member for each example.
     */
    protected transient ArrayList currentMemberEstimates;

    /**
     * Creates a new {@code IVotingCategorizerLearner}.
     */
    public IVotingCategorizerLearner()
    {
        this(null, DEFAULT_MAX_ITERATIONS, DEFAULT_PERCENT_TO_SAMPLE,
            new Random());
    }

    /**
     * Creates a new {@code IVotingCategorizerLearner}.
     *
     * @param  learner
     *      The learner to use to create the categorizer on each iteration.
     * @param  maxIterations
     *      The maximum number of iterations to run for, which is also the
     *      number of learners to create.
     * @param   percentToSample
     *      The percentage of the total size of the data to sample on each
     *      iteration. Must be positive.
     * @param  random
     *      The random number generator to use.
     */
    public IVotingCategorizerLearner(
        final BatchLearner>, ? extends Evaluator> learner,
        final int maxIterations,
        final double percentToSample,
        final Random random)
    {
        this(learner, maxIterations, percentToSample,
            DEFAULT_PROPORTION_INCORRECT_IN_SAMPLE,
            DEFAULT_VOTE_OUT_OF_BAG_ONLY,
            new DefaultDataDistribution.DefaultFactory(2),
            random);
    }

    /**
     * Creates a new {@code IVotingCategorizerLearner}.
     *
     * @param  learner
     *      The learner to use to create the categorizer on each iteration.
     * @param  maxIterations
     *      The maximum number of iterations to run for, which is also the
     *      number of learners to create.
     * @param   percentToSample
     *      The percentage of the total size of the data to sample on each
     *      iteration. Must be positive.
     * @param   proportionIncorrectInSample
     *      The percentage of incorrect examples to put in each sample. Must
     *      be between 0.0 and 1.0 (inclusive).
     * @param  voteOutOfBagOnly
     *      Controls whether or not in-bag or out-of-bag votes are used to
     *      determine accuracy.
     * @param   counterFactory
     *      The factory for counting votes.
     * @param  random
     *      The random number generator to use.
     */
    public IVotingCategorizerLearner(
        final BatchLearner>, ? extends Evaluator> learner,
        final int maxIterations,
        final double percentToSample,
        final double proportionIncorrectInSample,
        final boolean voteOutOfBagOnly,
        final Factory> counterFactory,
        final Random random)
    {
        super(maxIterations);

        this.setLearner(learner);
        this.setPercentToSample(percentToSample);
        this.setProportionIncorrectInSample(proportionIncorrectInSample);
        this.setVoteOutOfBagOnly(voteOutOfBagOnly);
        this.setCounterFactory(counterFactory);
        this.setRandom(random);
    }

    @Override
    protected boolean initializeAlgorithm()
    {
        final int dataSize = this.data.size();

        if (dataSize <= 0)
        {
            // Can't run on empty data.
            return false;
        }

        if (this.random == null)
        {
            this.random = new Random();
        }

        // Initialize the ensemble.
        this.ensemble = new WeightedVotingCategorizerEnsemble>(
            DatasetUtil.findUniqueOutputs(this.data));

        // We need the data to be as a random-access list in order to
        // efficiently sample from it.
        this.dataList = CollectionUtil.asArrayList(this.data);

        // We keep a running estimate of the output of the ensemble by
        // evaluating each member once.
        this.dataFullEstimates =
            new ArrayList>(dataSize);
        this.dataOutOfBagEstimates =
            new ArrayList>(dataSize);

        // We keep track of whether or not the current ensemble gets each
        // example correct.
        this.currentEnsembleCorrect = new boolean[dataSize];

        // We keep track of the lists of the indices of the current correct
        // and incorrect examples.
        this.currentCorrectIndices = new ArrayList(dataSize);
        this.currentIncorrectIndices = new ArrayList(dataSize);

        // Initialize the data estimates and the incorrect indices. All
        // examples are incorrect to start.
        for (int i = 0; i < dataSize; i++)
        {
            this.dataFullEstimates.add(new DefaultDataDistribution(2));
            this.dataOutOfBagEstimates.add(new DefaultDataDistribution(2));
            this.dataOutOfBagEstimates.add(this.counterFactory.create());
            this.currentIncorrectIndices.add(i);
        }

        // Figure out how many total samples and then the numbers of correct
        // and incorrect samples..
        this.sampleSize =
            Math.max(1, (int) (this.percentToSample * dataSize));
        this.numIncorrectToSample =
            (int) (this.proportionIncorrectInSample * sampleSize);
        this.numCorrectToSample = sampleSize - this.numIncorrectToSample;

        // Initialize the current bag.
        this.currentBag = new ArrayList>(
            this.numCorrectToSample + this.numIncorrectToSample);
        this.dataInBag = new int[dataSize];

        // Initialize the information about the current member.
        this.currentMember = null;
        this.currentMemberEstimates = new ArrayList(dataSize);

        // We initialize the array list.
        for (int i = 0; i < dataSize; i++)
        {
            this.currentMemberEstimates.add(null);
        }

        return true;
    }

    @Override
    protected boolean step()
    {
        final int dataSize = this.dataList.size();

        // Clear out the data from the previous iteration.
        this.currentBag.clear();

        for (int i = 0; i < dataSize; i++)
        {
            this.dataInBag[i] = 0;
        }

        // We use pointers here because we may have to change the pointers
        // after we make the lists in case one of them is empty.
        ArrayList correctIndices = this.currentCorrectIndices;
        ArrayList incorrectIndices = this.currentIncorrectIndices;

        // If one of the two lists is empty, we just point both pointers to the
        // same list, which means that the sampling steps are just going to be
        // sampling from the same distribution.
        if (incorrectIndices.isEmpty())
        {
            incorrectIndices = correctIndices;
        }
        else if (correctIndices.isEmpty())
        {
            correctIndices = incorrectIndices;
        }

        // Clear out the bag from the previous iteration.
        this.currentBag.clear();

        // Sample into the bag.
        this.createBag(correctIndices, incorrectIndices);

        // Learn the ensemble member.
        this.currentMember = this.learner.learn(this.currentBag);

        // Add a new member to the ensemble. It gets a default weight of
        // 1.0.
        this.ensemble.add(this.currentMember, 1.0);


        // Clear out the correct and incorrect information.
        this.currentCorrectIndices.clear();
        this.currentIncorrectIndices.clear();

        // Go through all the data and update the estimates for it. We keep
        // track of both the estimates for the new member and for the ensemble
        // as a whole. Incrementally computing the estimates saves a lot of
        // execution time since each member only sees each data item once.
        for (int i = 0; i < dataSize; i++)
        {
            // Get the example.
            final InputOutputPair example =
                this.dataList.get(i);
            final CategoryType actual = example.getOutput();

            // Get the guess for the new member.
            final CategoryType memberGuess = 
                this.currentMember.evaluate(example.getInput());

            // Save the estimate for the current member.
            this.currentMemberEstimates.set(i, memberGuess);

            // Get the full ensemble estimate for the current item.
            final DataDistribution fullEstimate =
                this.dataFullEstimates.get(i);

            // Get the out-of-bag estimate for the current item.
            final DataDistribution outOfBagEstimate =
                this.dataOutOfBagEstimates.get(i);
            
            if (memberGuess != null)
            {
                // Update the full estimate.
                fullEstimate.increment(memberGuess);

                if (this.dataInBag[i] <= 0)
                {
                    // Add to the out-of-bag estimate for the item.
                    outOfBagEstimate.increment(memberGuess);
                }
            }
            // else - The member had nothing to contribute for the estimate.

            // Get the new ensemble guess.
            CategoryType ensembleGuess = null;

            // See if we're guessing based on out-of-bag only.
            if (this.voteOutOfBagOnly && outOfBagEstimate.getTotal() > 0)
            {
                ensembleGuess = outOfBagEstimate.getMaxValueKey();
            }
            else
            {
                // Either we're not doing out-of-bag or there are no out-of-bag
                // votes for this item, so we use the full ensemble estimate.
                ensembleGuess = fullEstimate.getMaxValueKey();
            }

            // We assume that the ensemble member is correct in the case that
            // out-of-bag-only counting is turned on and there are no votes for
            // an item. This case means that the item appears in the bag of
            // used to train every member so far. Thus, we assume that the
            // ensemble is getting it correct.
            // Otherwise, equality of the actual and the guess is used.
            final boolean ensembleCorrect = ensembleGuess == null
                || ObjectUtil.equalsSafe(actual, ensembleGuess);

            this.currentEnsembleCorrect[i] = ensembleCorrect;

            // Update the list of correct and incorrect indices.
            if (ensembleCorrect)
            {
                this.currentCorrectIndices.add(i);
            }
            else
            {
                this.currentIncorrectIndices.add(i);
            }
        }

        return true;
    }

    /**
     * Create the next sample (bag) of examples to learn the next ensemble
     * member from.
     *
     * @param   correctIndices
     *      The list of indices the ensemble is currently getting correct.
     * @param   incorrectIndices
     *      The list of indices the ensemble is currently getting incorrect.
     */
    protected void createBag(
        final ArrayList correctIndices,
        final ArrayList incorrectIndices)
    {
        sampleIndicesWithReplacementInto(correctIndices, this.dataList,
            this.numCorrectToSample, this.random,
            this.currentBag, this.dataInBag);
        sampleIndicesWithReplacementInto(incorrectIndices, this.dataList,
            this.numIncorrectToSample, this.random,
            this.currentBag, this.dataInBag);
    }
    
    /**
     * Takes the given number of samples from the given list and places them in
     * the given output list. It samples with replacement, which means that a
     * given item may appear multiple times in the bag. It also keeps track of
     * how many times each item was sampled.
     *
     * @param   
     *      The data type to sample.
     * @param   fromIndices
     *      The indices into the given base data to sample from.
     * @param   baseData
     *      The list to sample from using the given list of indices.
     * @param   numToSample
     *      The number to sample. Must be non-negative.
     * @param   random
     *      The random number generator to use.
     * @param   output
     *      The list to add the samples to.
     * @param   dataInBag
     *      The array of counters for the number of times each example is
     *      sampled.
     */
    protected static  void sampleIndicesWithReplacementInto(
        final ArrayList fromIndices,
        final ArrayList baseData,
        final int numToSample,
        final Random random,
        final ArrayList output,
        final int[] dataInBag)
    {
        // Do the sampling with replacement, which means duplicates are allowed.
        final int fromSize = fromIndices.size();
        for (int i = 0; i < numToSample; i++)
        {
            // Randomly pick an item.
            final int randomInt = random.nextInt(fromSize);

            // Get the index for the random item.
            final int index = fromIndices.get(randomInt);

            // Add the item to the bag.
            output.add(baseData.get(index));

            // Increse the bag counter for the chosen item.
            dataInBag[index] += 1;
        }
    }

    @Override
    protected void cleanupAlgorithm()
    {
        // Clear out the cached values.
        this.dataList = null;
        this.dataFullEstimates = null;
        this.dataOutOfBagEstimates = null;
        this.dataInBag = null;
        this.currentMember = null;
        this.currentCorrectIndices = null;
        this.currentIncorrectIndices = null;
        this.currentBag = null;
        this.currentEnsembleCorrect = null;
        this.currentMemberEstimates = null;
    }

    public WeightedVotingCategorizerEnsemble>
        getResult()
    {
        return this.ensemble;
    }
    
    /**
     * Gets the learner used to learn each ensemble member.
     *
     * @return
     *      The learner used for each ensemble member.
     */
    public BatchLearner>, ? extends Evaluator>
        getLearner()
    {
        return this.learner;
    }

    /**
     * Sets the learner used to learn each ensemble member. Must be a supervised
     * learning algorithm that takes in a collection of input-output pairs of
     * the given data types and produces an evaluator for those data types.
     *
     * @param   learner
     *      The learner used for each ensemble member.
     */
    public void setLearner(
        final BatchLearner>, ? extends Evaluator> learner)
    {
        this.learner = learner;
    }

    /**
     * Gets the percentage of the total data to sample on each iteration.
     *
     * @return
     *      The percentage of the total data to sample on each iteration.
     */
    public double getPercentToSample()
    {
        return this.percentToSample;
    }

    /**
     * Sets the percentage of the data to sample (with replacement) on each
     * iteration. Must be greater than zero. The percent is represented as a
     * floating point number with 1.0 representing 100%.
     *
     * @param   percentToSample
     *      The percent of the data to sample on each iteration. Must be greater
     *      than zero. Defaults to 100%.
     */
    public void setPercentToSample(
        final double percentToSample)
    {
        if (percentToSample <= 0.0)
        {
            throw new IllegalArgumentException(
                "percentToSample must be greater than zero.");
        }

        this.percentToSample = percentToSample;
    }

    /**
     * Gets the proportion of incorrect examples to place in each sample.
     *
     * @return
     *      The proportion of incorrect examples in each sample.
     */
    public double getProportionIncorrectInSample()
    {
        return this.proportionIncorrectInSample;
    }

    /**
     * Sets the proportion of incorrect examples to place in each sample. Must
     * be between 0.0 and 1.0 (inclusive). The rest of the examples in the
     * sample will be filled from the correct examples.
     *
     * @param   proportionIncorrectInSample
     *      The proportion of incorrect examples in each sample. Must be between
     *      0.0 and 1.0 (inclusive).
     */
    public void setProportionIncorrectInSample(
        final double proportionIncorrectInSample)
    {
        if (    proportionIncorrectInSample < 0.0
             || proportionIncorrectInSample > 1.0)
        {
            throw new IllegalArgumentException(
                "proportionIncorrectInSample must be between 0.0 and 1.0 (inclusive).");
        }
        this.proportionIncorrectInSample = proportionIncorrectInSample;
    }

    /**
     * Gets whether during learning ensemble members can only vote on items
     * that they are not in their bag (training set).
     *
     * @return
     *      If out-of-bag-only voting is enabled.
     */
    public boolean isVoteOutOfBagOnly()
    {
        return voteOutOfBagOnly;
    }

    /**
     * Sets whether during learning ensemble members can only vote on items
     * that they are not in their bag (training set). In the vast majority
     * of cases, this should be enabled. It is enabled by default.
     *
     * @param   voteOutOfBagOnly
     *      If out-of-bag-only voting should be enabled.
     */
    public void setVoteOutOfBagOnly(
        final boolean voteOutOfBagOnly)
    {
        this.voteOutOfBagOnly = voteOutOfBagOnly;
    }

    /**
     * Gets the factory used for creating the object for counting the votes of
     * the learned ensemble members.
     *
     * @return
     *      The factory used to create the vote counting objects.
     */
    public Factory> getCounterFactory()
    {
        return this.counterFactory;
    }
    
    /**
     * Sets the factory used for creating the object for counting the votes of
     * the learned ensemble members.
     *
     * @param   counterFactory
     *      The factory used to create the vote counting objects.
     */
    public void setCounterFactory(
        final Factory> counterFactory)
    {
        this.counterFactory = counterFactory;
    }

    @Override
    public Random getRandom()
    {
        return this.random;
    }

    @Override
    public void setRandom(
        final Random random)
    {
        this.random = random;
    }

    /**
     * Gets the current estimates for each data point. Do not modify these
     * counts as they will change the algorithm.
     *
     * @return
     *      The current estimates for each data point.
     */
    public List> getDataFullEstimates()
    {
        return Collections.unmodifiableList(this.dataFullEstimates);
    }

    /**
     * Gets the current out-of-bag estimates for each data point. Do not modify
     * these counts as they will change the algorithm.
     *
     * @return
     *      The current out-of-bag estimates for each data point.
     */
    public List> getDataOutOfBagEstimates()
    {
        return Collections.unmodifiableList(this.dataOutOfBagEstimates);
    }

    /**
     * Gets whether or not the current ensemble gets each example correct.
     * Do not modify these values, as they will change the algorithm.
     *
     * @return
     *      The array of booleans regarding whether or not the ensemble gets
     *      an example correct.
     */
    public boolean[] getCurrentEnsembleCorrect()
    {
        return this.currentEnsembleCorrect;
    }

    /**
     * Implements a stopping criteria for IVoting that uses the out-of-bag
     * error to determine when to stop learning the ensemble. It tracks the
     * out-of-bag error rate of the ensemble and keeps it in a given smoothing
     * window. Once the smoothed error rate stops decreasing, it stops learning
     * and removes all of the ensemble members back to the one that had the
     * minimal error in that window.
     *
     * @param 
     *      The input type the algorithm is learning over.
     * @param 
     *      The category type the algorithm is learning over.
     */
    public static class OutOfBagErrorStoppingCriteria
        extends AbstractIterativeAlgorithmListener
    {
// TODO: Implement a look-ahead capability.
// -- jdbasil (2011-02-18)
        /** The default smoothing window size is {@value}. */
        public static final int DEFAULT_SMOOTHING_WINDOW_SIZE = 25;

        /** The size of window of data to look at to determine if learning has
         *  hit a minimum. */
        protected int smoothingWindowSize;

        /** The learner the stopping criteria is for. */
        protected transient IVotingCategorizerLearner
            learner;

        /** A boolean for each example indicating whether or not it is
         *  currently a correct or incorrect out-of-bag vote. This should be
         *  the same size as the collection of data. */
        protected transient boolean[] outOfBagCorrect;

        /** The total number of out-of-bag errors. This should equal the number
         *  of false values in the outOfBagCorrect array. */
        protected transient int outOfBagErrorCount;

        /** The raw out-of-bag error rate, per iteration. */
        protected transient ArrayList rawErrorRates;

        /** The smoothed out-of-bag error rates, per iteration. */
        protected transient ArrayList smoothedErrorRates;

        /** The buffer used for smoothing. */
        protected transient FiniteCapacityBuffer smoothingBuffer;

        /** The smoothed error rate of the previous iteration. */
        protected transient double previousSmoothedErrorRate;

        /**
         * Creates a new {@code OutOfBagErrorStoppingCriteria}.
         */
        public OutOfBagErrorStoppingCriteria()
        {
            this(DEFAULT_SMOOTHING_WINDOW_SIZE);
        }

        /**
         * Creates a new {@code OutOfBagErrorStoppingCriteria} with the given
         * smoothing window size.
         *
         * @param   smoothingWindowSize
         *      The smoothing window size to use. Must be positive.
         */
        public OutOfBagErrorStoppingCriteria(
            final int smoothingWindowSize)
        {
            super();

            this.setSmoothingWindowSize(smoothingWindowSize);
        }

        @SuppressWarnings("unchecked")
        @Override
        public void algorithmStarted(
            final IterativeAlgorithm algorithm)
        {
            this.learner = (IVotingCategorizerLearner)
                algorithm;
            final int dataSize = this.learner.data.size();
            this.outOfBagCorrect = new boolean[dataSize];
            this.outOfBagErrorCount = dataSize;
            this.rawErrorRates = new ArrayList();
            this.smoothedErrorRates = new ArrayList();
            this.smoothingBuffer = new FiniteCapacityBuffer(
                this.smoothingWindowSize);
            this.previousSmoothedErrorRate = Double.MAX_VALUE;
        }

        @Override
        public void algorithmEnded(
            final IterativeAlgorithm algorithm)
        {
            this.learner = null;
            this.outOfBagCorrect = null;
            this.rawErrorRates = null;
            this.smoothedErrorRates = null;
            this.smoothingBuffer = null;
        }

        @Override
        public void stepEnded(
            final IterativeAlgorithm algorithm)
        {
            // Go through the data and update the values for the data that was
            // not in the bag.
            final int dataSize = learner.data.size();
            for (int i = 0; i < dataSize; i++)
            {
                if (learner.dataInBag[i] <= 0)
                {
                    // Get the actual category.
                    final CategoryType actual =
                        learner.dataList.get(i).getOutput();
                    
                    // Get the out-of-bag-votes to determine the ensemble's
                    // guess.
                    final DataDistribution outOfBagVotes =
                        learner.dataOutOfBagEstimates.get(i);
                    final CategoryType ensembleGuess =
                        outOfBagVotes.getMaxValueKey();

                    // Update whether or not the ensemble is getting this item
                    // correct.
                    final boolean oldEnsembleCorrect = this.outOfBagCorrect[i];
                    final boolean newEnsembleCorrect = 
                        ObjectUtil.equalsSafe(actual, ensembleGuess);

                    if (oldEnsembleCorrect != newEnsembleCorrect)
                    {
                        // Save the new correctness.
                        this.outOfBagCorrect[i] = newEnsembleCorrect;

                        // Update the error count.
                        if (newEnsembleCorrect)
                        {
                            this.outOfBagErrorCount--;
                        }
                        else
                        {
                            this.outOfBagErrorCount++;
                        }
                    }

                }
            }

            // Compute the out-of-bag error rate for the ensemble.
            final double outOfBagEnsembleErrorRate = 
                (double) this.outOfBagErrorCount / this.learner.data.size();

            // Store this and compute the smoothed error rate.
            this.rawErrorRates.add(outOfBagEnsembleErrorRate);
            this.smoothingBuffer.add(outOfBagEnsembleErrorRate);
            final double smoothedErrorRate =
                UnivariateStatisticsUtil.computeMean(this.smoothingBuffer);
            this.smoothedErrorRates.add(smoothedErrorRate);

            // See if the algorithm is still making progress or not. Once the
            // smoothed error rate stops improving, its time to stop.
            if (smoothedErrorRate >= this.previousSmoothedErrorRate)
            {
                // Stop the learning since it is no longer improving.
                this.learner.stop();

                // Now we need to figure out where the ensemble had the best
                // performance from the smoothing buffer
                final int ensembleSize = this.rawErrorRates.size();
                int bestIndex = 0;
                double bestRawErrorRate = Double.MAX_VALUE;
                for (int i = 0; i < this.smoothingBuffer.size(); i++)
                {
                    // Walk the ensemble backwards.
                    final int index = ensembleSize - i - 1;
                    final double rawErrorRate = this.rawErrorRates.get(index);
                    if (rawErrorRate <= bestRawErrorRate)
                    {
                        bestIndex = index;
                        bestRawErrorRate = rawErrorRate;
                    }
                }

                // Now that we know which index was the best, we need to rewind
                // the ensemble to that point by removing all the members
                // added after it.
                for (int i = ensembleSize - 1; i > bestIndex; i--)
                {
                    learner.ensemble.members.remove(i);
                }
            }

            // Save the smoothed error rate.
            this.previousSmoothedErrorRate = smoothedErrorRate;

        }

        /**
         * Gets the size of the smoothing window.
         *
         * @return
         *      The size of the smoothing window.
         */
        public int getSmoothingWindowSize()
        {
            return smoothingWindowSize;
        }

        /**
         * Sets the smoothing window size.
         *
         * @param   smoothingWindowSize
         *      The smoothing window size. Must be positive.
         */
        public void setSmoothingWindowSize(
            final int smoothingWindowSize)
        {
            if (smoothingWindowSize < 0)
            {
                throw new IllegalArgumentException(
                    "smoothingWindowSize must be positive.");
            }

            this.smoothingWindowSize = smoothingWindowSize;
        }


    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy