All downloads are free. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.algorithm.ensemble.AbstractCategorizerOutOfBagStoppingCriteria Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:            AbstractCategorizerOutOfBagStoppingCriteria.java
 * Authors:         Justin Basilico
 * Project:         Cognitive Foundry Learning Core
 * 
 * Copyright 2017 Cognitive Foundry. All rights reserved.
 */

package gov.sandia.cognition.learning.algorithm.ensemble;

import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.algorithm.event.AbstractIterativeAlgorithmListener;
import gov.sandia.cognition.collection.FiniteCapacityBuffer;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;

/**
 * Abstract class for implementing an out-of-bag stopping criteria for a
 * bagging-based ensemble. After each step of the learner it updates, for every
 * example not in the current bag, whether the ensemble's out-of-bag estimate
 * for that example is correct. It then records the raw out-of-bag error rate
 * and averages it over a window of the most recent
 * {@code smoothingWindowSize} steps. Once the smoothed error rate stops
 * decreasing, it stops the learner. It then removes ensemble members so that
 * the ensemble ends at the step with the lowest raw error rate inside the
 * smoothing window.
 * 
 * @param   <InputType>
 *      The type of the input for the categorizer to learn.
 * @param   <CategoryType>
 *      The type of the category that is the output for the categorizer to
 *      learn.
 * 
 * @author  Justin Basilico
 * @since   4.0.0
 */
public abstract class AbstractCategorizerOutOfBagStoppingCriteria<InputType, CategoryType>
    extends AbstractIterativeAlgorithmListener
{
    // TODO: Implement a look-ahead capability.
    // -- jdbasil (2011-02-18)

    /** The default smoothing window size is {@value}. */
    public static final int DEFAULT_SMOOTHING_WINDOW_SIZE = 25;

    /** The size of window of data to look at to determine if learning has
     *  hit a minimum. */
    protected int smoothingWindowSize;

    /** The learner the stopping criteria is for. */
    protected transient BagBasedCategorizerEnsembleLearner<InputType, CategoryType>
        learner;

    /** A boolean for each example indicating whether or not it is
     *  currently a correct or incorrect out-of-bag vote. This should be
     *  the same size as the collection of data. */
    protected transient boolean[] outOfBagCorrect;

    /** The total number of out-of-bag errors. This should equal the number
     *  of false values in the outOfBagCorrect array. */
    protected transient int outOfBagErrorCount;

    /** The raw out-of-bag error rate, per iteration. */
    protected transient ArrayList<Double> rawErrorRates;

    /** The smoothed out-of-bag error rates, per iteration. */
    protected transient ArrayList<Double> smoothedErrorRates;

    /** The buffer holding the most recent raw error rates used for
     *  smoothing. */
    protected transient FiniteCapacityBuffer<Double> smoothingBuffer;

    /** The smoothed error rate of the previous iteration. */
    protected transient double previousSmoothedErrorRate;

    /**
     * Creates a new {@code AbstractCategorizerOutOfBagStoppingCriteria} with
     * the default smoothing window size
     * ({@link #DEFAULT_SMOOTHING_WINDOW_SIZE}).
     */
    public AbstractCategorizerOutOfBagStoppingCriteria()
    {
        this(DEFAULT_SMOOTHING_WINDOW_SIZE);
    }

    /**
     * Creates a new {@code AbstractCategorizerOutOfBagStoppingCriteria} with
     * the given smoothing window size.
     *
     * @param   smoothingWindowSize
     *      The smoothing window size to use. Must be positive.
     * @throws  IllegalArgumentException
     *      If the smoothing window size is not positive.
     */
    public AbstractCategorizerOutOfBagStoppingCriteria(
        final int smoothingWindowSize)
    {
        super();

        this.setSmoothingWindowSize(smoothingWindowSize);
    }

    /**
     * Captures the learner and resets all per-run out-of-bag statistics.
     *
     * @param   algorithm
     *      The algorithm that started. It must be a
     *      {@code BagBasedCategorizerEnsembleLearner}; otherwise the cast
     *      throws a {@code ClassCastException}.
     */
    @SuppressWarnings("unchecked")
    @Override
    public void algorithmStarted(
        final IterativeAlgorithm algorithm)
    {
        // Unchecked: the type arguments cannot be verified at runtime.
        this.learner =
            (BagBasedCategorizerEnsembleLearner<InputType, CategoryType>)
            algorithm;
        final int dataSize = this.learner.getData().size();

        // No example has been judged yet, so every example starts out marked
        // incorrect and counted as an out-of-bag error.
        this.outOfBagCorrect = new boolean[dataSize];
        this.outOfBagErrorCount = dataSize;
        this.rawErrorRates = new ArrayList<>();
        this.smoothedErrorRates = new ArrayList<>();
        this.smoothingBuffer = new FiniteCapacityBuffer<>(
            this.smoothingWindowSize);

        // Sentinel so that the first step cannot trigger a stop.
        this.previousSmoothedErrorRate = Double.MAX_VALUE;
    }

    /**
     * Releases the references held during the run.
     *
     * @param   algorithm
     *      The algorithm that ended.
     */
    @Override
    public void algorithmEnded(
        final IterativeAlgorithm algorithm)
    {
        this.learner = null;
        this.outOfBagCorrect = null;
        this.rawErrorRates = null;
        this.smoothedErrorRates = null;
        this.smoothingBuffer = null;
    }

    /**
     * Gets the out-of-bag estimate distribution across categories for the
     * training example with the given index.
     * 
     * @param index
     *      The 0-based index for the training example.
     * @return 
     *      The distribution over output categories for the out-of-bag
     *      estimate for the example at that index. May be empty.
     */
    public abstract DataDistribution<CategoryType> getOutOfBagEstimate(
        final int index);

    /**
     * Updates the out-of-bag statistics after a learner step. It stops the
     * learner, and rewinds its ensemble, once the smoothed out-of-bag error
     * rate is no lower than the previous step's.
     *
     * @param   algorithm
     *      The algorithm whose step ended. The learner captured in
     *      {@link #algorithmStarted} is used rather than this parameter.
     */
    @Override
    public void stepEnded(
        final IterativeAlgorithm algorithm)
    {
        // Go through the data and update the values for the data that the
        // learner reports as not in the current bag.
        final int dataSize = this.learner.getData().size();
        final int[] dataInBag = this.learner.getDataInBag();

        for (int i = 0; i < dataSize; i++)
        {
            if (dataInBag[i] <= 0)
            {
                // Get the actual category.
                final CategoryType actual =
                    this.learner.getExample(i).getOutput();

                // Get the out-of-bag votes to determine the ensemble's
                // guess. NOTE(review): the estimate may be empty; presumably
                // getMaxValueKey() then returns null, which equalsSafe
                // treats as incorrect for a non-null actual -- confirm.
                final DataDistribution<CategoryType> outOfBagVotes =
                    this.getOutOfBagEstimate(i);
                final CategoryType ensembleGuess =
                    outOfBagVotes.getMaxValueKey();

                // Update whether or not the ensemble is getting this item
                // correct, keeping the error count in sync with the array.
                final boolean oldEnsembleCorrect = this.outOfBagCorrect[i];
                final boolean newEnsembleCorrect =
                    ObjectUtil.equalsSafe(actual, ensembleGuess);

                if (oldEnsembleCorrect != newEnsembleCorrect)
                {
                    this.outOfBagCorrect[i] = newEnsembleCorrect;

                    if (newEnsembleCorrect)
                    {
                        this.outOfBagErrorCount--;
                    }
                    else
                    {
                        this.outOfBagErrorCount++;
                    }
                }
            }
        }

        // Compute the out-of-bag error rate for the ensemble. Examples that
        // have never been out-of-bag are still counted as errors.
        final double outOfBagEnsembleErrorRate =
            (double) this.outOfBagErrorCount / dataSize;

        // Store this and compute the smoothed error rate as the mean of the
        // most recent raw error rates held in the smoothing buffer.
        this.rawErrorRates.add(outOfBagEnsembleErrorRate);
        this.smoothingBuffer.add(outOfBagEnsembleErrorRate);
        final double smoothedErrorRate =
            UnivariateStatisticsUtil.computeMean(this.smoothingBuffer);
        this.smoothedErrorRates.add(smoothedErrorRate);

        // See if the algorithm is still making progress or not. Once the
        // smoothed error rate stops improving, it's time to stop.
        if (smoothedErrorRate >= this.previousSmoothedErrorRate)
        {
            // Stop the learning since it is no longer improving.
            this.learner.stop();

            // Find the step with the lowest raw error rate among the steps
            // still in the smoothing buffer. Walking backwards with "<=" means
            // ties resolve to the earliest step, i.e. the smallest ensemble.
            final int ensembleSize = this.rawErrorRates.size();
            int bestIndex = 0;
            double bestRawErrorRate = Double.MAX_VALUE;
            for (int i = 0; i < this.smoothingBuffer.size(); i++)
            {
                final int index = ensembleSize - i - 1;
                final double rawErrorRate = this.rawErrorRates.get(index);
                if (rawErrorRate <= bestRawErrorRate)
                {
                    bestIndex = index;
                    bestRawErrorRate = rawErrorRate;
                }
            }

            // Rewind the ensemble to that point by removing all the members
            // added after it, from the end backwards. NOTE(review): this
            // assumes member i was added in step i, so the member list lines
            // up with rawErrorRates -- confirm against the learner.
            for (int i = ensembleSize - 1; i > bestIndex; i--)
            {
                this.learner.getResult().members.remove(i);
            }
        }

        // Save the smoothed error rate for the next step's comparison.
        this.previousSmoothedErrorRate = smoothedErrorRate;
    }

    /**
     * Gets the size of the smoothing window.
     *
     * @return
     *      The size of the smoothing window.
     */
    public int getSmoothingWindowSize()
    {
        return this.smoothingWindowSize;
    }

    /**
     * Sets the smoothing window size.
     *
     * @param   smoothingWindowSize
     *      The smoothing window size. Must be positive.
     * @throws  IllegalArgumentException
     *      If the smoothing window size is not positive.
     */
    public void setSmoothingWindowSize(
        final int smoothingWindowSize)
    {
        // A window of zero would give the smoothing buffer no capacity,
        // leaving nothing to average; reject it along with negatives.
        if (smoothingWindowSize <= 0)
        {
            throw new IllegalArgumentException(
                "smoothingWindowSize must be positive.");
        }

        this.smoothingWindowSize = smoothingWindowSize;
    }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy