![JAR search and dependency download from the Maven repository](/logo.png)
gov.sandia.cognition.learning.algorithm.ensemble.AbstractCategorizerOutOfBagStoppingCriteria Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: AbstractCategorizerOutOfBagStoppingCriteria.java
* Authors: Justin Basilico
* Project: Cognitive Foundry Learning Core
*
* Copyright 2017 Cognitive Foundry. All rights reserved.
*/
package gov.sandia.cognition.learning.algorithm.ensemble;
import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.algorithm.event.AbstractIterativeAlgorithmListener;
import gov.sandia.cognition.collection.FiniteCapacityBuffer;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
/**
* Abstract class for implementing a out-of-bag stopping criteria for a
* bagging-based ensemble.
*
* @param
* The type of the input for the categorizer to learn.
* @param
* The type of the category that is the output for the categorizer to
* learn.
*
* @author Justin Basilico
* @since 4.0.0
*/
public abstract class AbstractCategorizerOutOfBagStoppingCriteria
extends AbstractIterativeAlgorithmListener
{
// TODO: Implement a look-ahead capability.
// -- jdbasil (2011-02-18)
/** The default smoothing window size is {@value}. */
public static final int DEFAULT_SMOOTHING_WINDOW_SIZE = 25;
/** The size of window of data to look at to determine if learning has
* hit a minimum. */
protected int smoothingWindowSize;
/** The learner the stopping criteria is for. */
protected transient BagBasedCategorizerEnsembleLearner
learner;
/** A boolean for each example indicating whether or not it is
* currently a correct or incorrect out-of-bag vote. This should be
* the same size as the collection of data. */
protected transient boolean[] outOfBagCorrect;
/** The total number of out-of-bag errors. This should equal the number
* of false values in the outOfBagCorrect array. */
protected transient int outOfBagErrorCount;
/** The raw out-of-bag error rate, per iteration. */
protected transient ArrayList rawErrorRates;
/** The smoothed out-of-bag error rates, per iteration. */
protected transient ArrayList smoothedErrorRates;
/** The buffer used for smoothing. */
protected transient FiniteCapacityBuffer smoothingBuffer;
/** The smoothed error rate of the previous iteration. */
protected transient double previousSmoothedErrorRate;
/**
* Creates a new {@code OutOfBagErrorStoppingCriteria}.
*/
public AbstractCategorizerOutOfBagStoppingCriteria()
{
this(DEFAULT_SMOOTHING_WINDOW_SIZE);
}
/**
* Creates a new {@code OutOfBagErrorStoppingCriteria} with the given
* smoothing window size.
*
* @param smoothingWindowSize
* The smoothing window size to use. Must be positive.
*/
public AbstractCategorizerOutOfBagStoppingCriteria(
final int smoothingWindowSize)
{
super();
this.setSmoothingWindowSize(smoothingWindowSize);
}
@SuppressWarnings("unchecked")
@Override
public void algorithmStarted(
final IterativeAlgorithm algorithm)
{
this.learner = (BagBasedCategorizerEnsembleLearner)
algorithm;
final int dataSize = this.learner.getData().size();
this.outOfBagCorrect = new boolean[dataSize];
this.outOfBagErrorCount = dataSize;
this.rawErrorRates = new ArrayList<>();
this.smoothedErrorRates = new ArrayList<>();
this.smoothingBuffer = new FiniteCapacityBuffer<>(
this.smoothingWindowSize);
this.previousSmoothedErrorRate = Double.MAX_VALUE;
}
@Override
public void algorithmEnded(
final IterativeAlgorithm algorithm)
{
this.learner = null;
this.outOfBagCorrect = null;
this.rawErrorRates = null;
this.smoothedErrorRates = null;
this.smoothingBuffer = null;
}
/**
* Gets the out-of-bag estimate distribution across categories for the
* training example with the given index.
*
* @param index
* The 0-based index for the training example.
* @return
* The distribution over output categories for the out-of-bag
* estimate for the example at that index. May be empty.
*/
public abstract DataDistribution getOutOfBagEstimate(
final int index);
@Override
public void stepEnded(
final IterativeAlgorithm algorithm)
{
// Go through the data and update the values for the data that was
// not in the bag.
final int dataSize = learner.getData().size();
final int[] dataInBag = learner.getDataInBag();
for (int i = 0; i < dataSize; i++)
{
if (dataInBag[i] <= 0)
{
// Get the actual category.
final CategoryType actual =
learner.getExample(i).getOutput();
// Get the out-of-bag-votes to determine the ensemble's
// guess.
final DataDistribution outOfBagVotes =
this.getOutOfBagEstimate(i);
final CategoryType ensembleGuess =
outOfBagVotes.getMaxValueKey();
// Update whether or not the ensemble is getting this item
// correct.
final boolean oldEnsembleCorrect = this.outOfBagCorrect[i];
final boolean newEnsembleCorrect =
ObjectUtil.equalsSafe(actual, ensembleGuess);
if (oldEnsembleCorrect != newEnsembleCorrect)
{
// Save the new correctness.
this.outOfBagCorrect[i] = newEnsembleCorrect;
// Update the error count.
if (newEnsembleCorrect)
{
this.outOfBagErrorCount--;
}
else
{
this.outOfBagErrorCount++;
}
}
}
}
// Compute the out-of-bag error rate for the ensemble.
final double outOfBagEnsembleErrorRate =
(double) this.outOfBagErrorCount / this.learner.getData().size();
// Store this and compute the smoothed error rate.
this.rawErrorRates.add(outOfBagEnsembleErrorRate);
this.smoothingBuffer.add(outOfBagEnsembleErrorRate);
final double smoothedErrorRate =
UnivariateStatisticsUtil.computeMean(this.smoothingBuffer);
this.smoothedErrorRates.add(smoothedErrorRate);
// See if the algorithm is still making progress or not. Once the
// smoothed error rate stops improving, its time to stop.
if (smoothedErrorRate >= this.previousSmoothedErrorRate)
{
// Stop the learning since it is no longer improving.
this.learner.stop();
// Now we need to figure out where the ensemble had the best
// performance from the smoothing buffer
final int ensembleSize = this.rawErrorRates.size();
int bestIndex = 0;
double bestRawErrorRate = Double.MAX_VALUE;
for (int i = 0; i < this.smoothingBuffer.size(); i++)
{
// Walk the ensemble backwards.
final int index = ensembleSize - i - 1;
final double rawErrorRate = this.rawErrorRates.get(index);
if (rawErrorRate <= bestRawErrorRate)
{
bestIndex = index;
bestRawErrorRate = rawErrorRate;
}
}
// Now that we know which index was the best, we need to rewind
// the ensemble to that point by removing all the members
// added after it.
for (int i = ensembleSize - 1; i > bestIndex; i--)
{
learner.getResult().members.remove(i);
}
}
// Save the smoothed error rate.
this.previousSmoothedErrorRate = smoothedErrorRate;
}
/**
* Gets the size of the smoothing window.
*
* @return
* The size of the smoothing window.
*/
public int getSmoothingWindowSize()
{
return this.smoothingWindowSize;
}
/**
* Sets the smoothing window size.
*
* @param smoothingWindowSize
* The smoothing window size. Must be positive.
*/
public void setSmoothingWindowSize(
final int smoothingWindowSize)
{
if (smoothingWindowSize < 0)
{
throw new IllegalArgumentException(
"smoothingWindowSize must be positive.");
}
this.smoothingWindowSize = smoothingWindowSize;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy