// gov.sandia.cognition.learning.algorithm.ensemble.BaggingCategorizerLearner
// Source listing from the cognitive-foundry artifact (Maven / Gradle / Ivy):
// a single jar with all the Cognitive Foundry components.
/*
* File: BaggingCategorizerLearner.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright November 26, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.ensemble;
import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;
import java.util.Set;
/**
* Learns an categorization ensemble by randomly sampling with replacement
* (duplicates allowed) some percentage of the size of the data (defaults to
* 100%) on each iteration to train a new ensemble member. The random sample is
* referred to as a bag. Each learned ensemble member is given equal weight.
* The idea here is that randomly sampling from the data and learning a
* categorizer that has high variance (such as a decision tree) with respect to
* the input data, one can improve the performance of that
*
* By default, the algorithm runs the maxIterations number of steps to create
* that number of ensemble members. However, one can also use out-of-bag (OOB)
* error on each iteration to determine a stopping criteria. The OOB error is
* determined by looking at the performance of the categorizer on the examples
* that it has not seen.
*
* @param
* The input type for supervised learning. Passed on to the internal
* learning algorithm. Also the input type for the learned ensemble.
* @param
* The output type for supervised learning. Passed on to the internal
* learning algorithm. Also the output type of the learned ensemble.
* @author Justin Basilico
* @since 3.0
*/
@PublicationReference(
title="Bagging Predictors",
author="Leo Breiman",
year=1996,
type=PublicationType.Journal,
publication="Machine Learning",
pages={123, 140},
url="http://www.springerlink.com/index/L4780124W2874025.pdf")
public class BaggingCategorizerLearner
extends AbstractBaggingLearner, WeightedVotingCategorizerEnsemble>>
implements BagBasedCategorizerEnsembleLearner
{
/**
* Creates a new instance of BaggingCategorizerLearner.
*/
public BaggingCategorizerLearner()
{
this(null);
}
/**
* Creates a new instance of BaggingCategorizerLearner.
*
* @param learner
* The learner to use to create the categorizer on each iteration.
*/
public BaggingCategorizerLearner(
final BatchLearner super Collection extends InputOutputPair extends InputType, CategoryType>>, ? extends Evaluator super InputType, ? extends CategoryType>> learner)
{
this(learner, DEFAULT_MAX_ITERATIONS, DEFAULT_PERCENT_TO_SAMPLE, new Random());
}
/**
* Creates a new instance of BaggingCategorizerLearner.
*
* @param learner
* The learner to use to create the categorizer on each iteration.
* @param maxIterations
* The maximum number of iterations to run for, which is also the
* number of learners to create.
* @param percentToSample
* The percentage of the total size of the data to sample on each
* iteration. Must be positive.
* @param random
* The random number generator to use.
*/
public BaggingCategorizerLearner(
final BatchLearner super Collection extends InputOutputPair extends InputType, CategoryType>>, ? extends Evaluator super InputType, ? extends CategoryType>> learner,
final int maxIterations,
final double percentToSample,
final Random random)
{
super(learner, maxIterations, percentToSample, random);
}
@Override
protected WeightedVotingCategorizerEnsemble> createInitialEnsemble()
{
final Set categories =
DatasetUtil.findUniqueOutputs(this.getData());
return new WeightedVotingCategorizerEnsemble>(
categories);
}
@Override
protected void addEnsembleMember(
final Evaluator super InputType, ? extends CategoryType> member)
{
// Add the categorizer to the ensemble and give it equal weight.
this.ensemble.add(member, 1.0);
}
@Override
public int[] getDataInBag()
{
return this.dataInBag;
}
@Override
public InputOutputPair extends InputType, CategoryType> getExample(
final int index)
{
return this.dataList.get(index);
}
/**
* Implements a stopping criteria for bagging that uses the out-of-bag
* error to determine when to stop learning the ensemble. It tracks the
* out-of-bag error rate of the ensemble and keeps it in a given smoothing
* window. Once the smoothed error rate stops decreasing, it stops learning
* and removes all of the ensemble members back to the one that had the
* minimal error in that window.
*
* @param
* The input type the algorithm is learning over.
* @param
* The category type the algorithm is learning over.
*/
public static class OutOfBagErrorStoppingCriteria
extends AbstractCategorizerOutOfBagStoppingCriteria
{
/** The running estimate of the ensemble for each example where an ensemble
* member can only vote on elements that were not in the bag used to train
* it. Same size as the training data. */
protected transient ArrayList> outOfBagEstimates;
/**
* Creates a new {@code OutOfBagErrorStoppingCriteria}.
*/
public OutOfBagErrorStoppingCriteria()
{
this(DEFAULT_SMOOTHING_WINDOW_SIZE);
}
/**
* Creates a new {@code OutOfBagErrorStoppingCriteria} with the given
* smoothing window size.
*
* @param smoothingWindowSize
* The smoothing window size to use. Must be positive.
*/
public OutOfBagErrorStoppingCriteria(
final int smoothingWindowSize)
{
super(smoothingWindowSize);
}
@SuppressWarnings("unchecked")
@Override
public void algorithmStarted(
final IterativeAlgorithm algorithm)
{
super.algorithmStarted(algorithm);
final int dataSize = this.learner.getData().size();
this.outOfBagEstimates = new ArrayList<>(dataSize);
for (int i = 0; i < dataSize; i++)
{
this.outOfBagEstimates.add(new DefaultDataDistribution<>(2));
}
}
@Override
public void algorithmEnded(
final IterativeAlgorithm algorithm)
{
super.algorithmEnded(algorithm);
this.outOfBagEstimates = null;
}
@Override
public DataDistribution getOutOfBagEstimate(
final int index)
{
return this.outOfBagEstimates.get(index);
}
/**
* Updates the out-of-bag estimates that this ensemble keeps.
*/
protected void updateOutOfBagEstimates()
{
final WeightedValue extends Evaluator super InputType, ? extends CategoryType>> weightedMember =
CollectionUtil.getLast(this.learner.getResult().getMembers());
final double weight = weightedMember.getWeight();
final Evaluator super InputType, ? extends CategoryType> member =
weightedMember.getValue();
final int[] dataInBag = this.learner.getDataInBag();
// Go through the data and update the values for the data that was
// not in the bag.
final int dataSize = dataInBag.length;
for (int i = 0; i < dataSize; i++)
{
if (dataInBag[i] <= 0)
{
final InputOutputPair extends InputType, CategoryType> example =
this.learner.getExample(i);
final CategoryType memberGuess = member.evaluate(
example.getInput());
this.outOfBagEstimates.get(i).increment(
memberGuess, weight);
}
}
}
@Override
public void stepEnded(
final IterativeAlgorithm algorithm)
{
// First update all the estimates since they're used by the super
// class.
this.updateOutOfBagEstimates();
super.stepEnded(algorithm);
}
}
}